Dot-product attention enhancements: MHA, MQA, GQA, and MHLA

Starting from dot-product attention, I present and give an intuitive understanding of the main variants of the attention mechanism, namely: Multi-Head, Multi-Query, Grouped-Query, and Multi-Head Latent attention.
Transformers explained
Author

Oleguer Canal

Published

November 5, 2024

Hi! Today we have a very shallow post. We’ll go over the main “wrappers” around dot-product attention and we’ll call it a day 🙃 The main purpose is for me to have a repository of these structures that aligns with my mental model of them1.

1 So I don’t need to exert a trifle of thought making sense of someone else’s diagrams.

This post focuses on the attention parts of the transformer (circled parts).

Dot-Product Attention

At the core of the circled multi-head attention layers we find the dot-product attention. I talk extensively2 about this in the sequence post. Here you can find an over-simplified diagram of how it works3:

2 Probably too extensively

3 I focus on cross-attention because it is more general.

Train-time

The formula goes like this, where \(f\) is the embedding dimension: \[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{ Q K^T }{\sqrt{f}}\right)V \]

Which can be represented as:

1. We compute the attention matrix: Each row indicates how related its corresponding query is to each key.
2. Each row linearly combines the values according to the previous correlations. Image by author.
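The two steps above can be written in a few lines of NumPy (a minimal sketch; the shapes and variable names are my own, not from the post's diagrams):

```python
import numpy as np

def dot_product_attention(Q, K, V):
    """Scaled dot-product attention, matching the formula above.

    Q: (t_q, f), K: (t_k, f), V: (t_k, f) -- f is the embedding dim."""
    f = Q.shape[-1]
    logits = Q @ K.T / np.sqrt(f)  # (t_q, t_k): query-key correlations
    # Row-wise softmax (subtracting the max for numerical stability).
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V  # each row linearly combines the values

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((t, 8)) for t in (3, 5, 5))
out = dot_product_attention(Q, K, V)
print(out.shape)  # (3, 8): one output row per query
```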

Let’s assume Q, K are distributed normally:

\[ Q_{t_q, i}, K_{t_k, j} \sim \mathcal{N} (0, 1) \]

We then have that:

\[ \text{var} \left( Q_{t_q, i} \cdot K_{t_k, j} \right) = 1 \]

Thus, when computing the dot product \(\vec{Q}_{t_q} \cdot \vec{K}_{t_k}\), the variance grows linearly:

\[ \text{var} \left( \sum_{i=1:f} Q_{t_q, i} \cdot K_{t_k, i} \right) = f \]

Meaning we get a \(\text{std}\left( \vec{Q}_{t_q} \cdot \vec{K}_{t_k} \right) = \sqrt{f}\).

This means that as the embedding dimension increases, the attention matrix values typically get larger (in absolute value). For instance, if \(f = 1024\), we have \(\text{std} = \sqrt{1024} = 32\), meaning typical logits lie in the \([-32, 32]\) range.

This dependence on \(f\) can be problematic when applying the \(\text{softmax}\) operation: very large values dominate the exponential and collapse the rest (\(e^{32} \approx 10^{14}\)), effectively turning the softmax into a one-hot encoder (and a gradient killer). Dividing by the \(\text{std}\), i.e. \(\sqrt{f}\), significantly reduces this problem :)
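We can sanity-check the \(\sqrt{f}\) claim empirically (a quick simulation of my own, not from the derivation above):

```python
import numpy as np

rng = np.random.default_rng(0)
f = 1024
# Many independent query/key pairs with unit-normal entries.
Q = rng.standard_normal((10_000, f))
K = rng.standard_normal((10_000, f))
dots = (Q * K).sum(axis=1)  # one dot product per pair

print(dots.std())                 # ~ sqrt(f) = 32, as derived above
print((dots / np.sqrt(f)).std())  # ~ 1 after the scaling
```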

Inference-time

This one is for self-attention, but the idea is the same.

1. We project the input into q, k, v
2. We compare q with all previous k’s (cached)
3. We do the weighted average of all previous v’s (also cached).
Notice it is the same as before but having only 1 query. Image by author.
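The three steps above can be sketched like this (a toy single-head version; the projection matrices and the list-based cache are my own simplifications of a real KV-cache):

```python
import numpy as np

def decode_step(x, W_q, W_k, W_v, k_cache, v_cache):
    """One self-attention decoding step for a single new token x of shape (f,).

    The new k, v are appended to the caches; q attends over all cached keys."""
    q, k, v = x @ W_q, x @ W_k, x @ W_v  # step 1: project the input
    k_cache.append(k)
    v_cache.append(v)
    K, V = np.stack(k_cache), np.stack(v_cache)  # (t, f)
    logits = K @ q / np.sqrt(q.shape[-1])        # step 2: compare q with all k's
    w = np.exp(logits - logits.max())
    w /= w.sum()
    return w @ V                                 # step 3: weighted average of v's

f = 16
rng = np.random.default_rng(0)
W_q, W_k, W_v = (rng.standard_normal((f, f)) / np.sqrt(f) for _ in range(3))
k_cache, v_cache = [], []
for token in rng.standard_normal((5, f)):  # decode 5 tokens one by one
    out = decode_step(token, W_q, W_k, W_v, k_cache, v_cache)
print(len(k_cache), out.shape)  # 5 (16,)
```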

Attention extensions

Here I go over the most influential variants of the dot-product attention.

2017. MHA: Multi-Head Attention

The main motivation is to broadcast the dot-product attention operation in parallel so we can attend to different things at the same time.

After projecting, we simply slice Q, K, V and broadcast the previous operation. Image by author.
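A sketch of the slicing idea (single sequence, no batch dimension; in practice the heads run as one batched matmul rather than a Python loop):

```python
import numpy as np

def multi_head_attention(Q, K, V, n_heads):
    """Slice the already-projected Q, K, V into n_heads chunks along the
    feature dim and run dot-product attention independently on each slice."""
    outs = []
    for Qh, Kh, Vh in zip(*(np.split(M, n_heads, axis=-1) for M in (Q, K, V))):
        d = Qh.shape[-1]  # per-head dimension f / n
        logits = Qh @ Kh.T / np.sqrt(d)
        w = np.exp(logits - logits.max(axis=-1, keepdims=True))
        w /= w.sum(axis=-1, keepdims=True)
        outs.append(w @ Vh)
    return np.concatenate(outs, axis=-1)  # heads re-assembled: (t_q, f)

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((t, 32)) for t in (3, 5, 5))
out = multi_head_attention(Q, K, V, n_heads=4)
print(out.shape)  # (3, 32)
```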

2019. MQA: Multi-Query Attention

I’d call this single KV-attention rather than multi-query, but whatever. MQA’s main idea4 is to reduce memory footprint by reducing the dimensionality of the KV-cache5.

4 And all the following variants.

5 To learn more how to compress in the temporal dimension, check out my sequence modelling post

K, V are directly projected into a \(\frac{f}{n}\)-dimensional space, while we still project Q into \(f\) dimensions and split it into different heads. Image by author.
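Compared to the MHA sketch above, only the loop changes: every query head reuses the same single K and V (shapes are my own assumptions):

```python
import numpy as np

def multi_query_attention(Q, K, V, n_heads):
    """All n query heads share a single K and V of head dimension f / n.

    Q: (t_q, f); K, V: (t_k, f // n_heads) -- one shared KV head."""
    d = K.shape[-1]
    outs = []
    for Qh in np.split(Q, n_heads, axis=-1):  # one (t_q, f/n) slice per head
        logits = Qh @ K.T / np.sqrt(d)
        w = np.exp(logits - logits.max(axis=-1, keepdims=True))
        w /= w.sum(axis=-1, keepdims=True)
        outs.append(w @ V)  # the same K, V are reused by every head
    return np.concatenate(outs, axis=-1)

rng = np.random.default_rng(0)
f, n = 32, 4
Q = rng.standard_normal((3, f))
K = rng.standard_normal((5, f // n))  # KV-cache is n times smaller
V = rng.standard_normal((5, f // n))
out = multi_query_attention(Q, K, V, n_heads=n)
print(out.shape)  # (3, 32)
```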

2023. GQA: Grouped-Query Attention

GQA is a middle-ground between multi-head and multi-query. We still have n heads for queries, but we split the KV-cache into g groups. Thus, we have n/g heads per group.

In this example n=4 heads and g=2 groups. Image by author.

I made the poor decision of swapping Q and V colors. I am not repeating them. Figure from GQA paper.
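A sketch of the grouping logic (the head-to-group assignment is my own; real implementations repeat the KV heads with a batched matmul instead of looping):

```python
import numpy as np

def grouped_query_attention(Q, K, V, n_heads, g_groups):
    """n query heads, g KV groups: each run of n/g query heads shares one KV head.

    Q: (t_q, f); K, V: (t_k, g_groups * (f // n_heads))."""
    d = Q.shape[-1] // n_heads
    Qs = np.split(Q, n_heads, axis=-1)
    Ks = np.split(K, g_groups, axis=-1)
    Vs = np.split(V, g_groups, axis=-1)
    outs = []
    for h, Qh in enumerate(Qs):
        g = h * g_groups // n_heads  # which group this query head belongs to
        logits = Qh @ Ks[g].T / np.sqrt(d)
        w = np.exp(logits - logits.max(axis=-1, keepdims=True))
        w /= w.sum(axis=-1, keepdims=True)
        outs.append(w @ Vs[g])
    return np.concatenate(outs, axis=-1)

rng = np.random.default_rng(0)
f, n, g = 32, 4, 2  # as in the example figure: n=4 heads, g=2 groups
Q = rng.standard_normal((3, f))
K = rng.standard_normal((5, g * (f // n)))  # KV-cache is n/g times smaller
V = rng.standard_normal((5, g * (f // n)))
out = grouped_query_attention(Q, K, V, n_heads=n, g_groups=g)
print(out.shape)  # (3, 32)
```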

2024. MHLA: Multi-Head Latent Attention

The DeepSeek-v2 paper presents a variant that attempts to decrease memory footprint using two different techniques:

  1. Linearly down-projecting the input into a lower-dimensional space, from which one can recover K and V. The main advantage is that we can just cache this projection \(c_{KV}\), which is much smaller than the naive complete KV-cache at inference-time.

  2. Applying the same principle to Q to save memory at train-time.
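A rough sketch of technique 1 (the dimensions and the projection names W_down, W_uk, W_uv are illustrative, not the paper's notation):

```python
import numpy as np

f, f_latent = 1024, 128  # assumed dims: latent space much smaller than f
rng = np.random.default_rng(0)
W_down = rng.standard_normal((f, f_latent)) / np.sqrt(f)       # compress input
W_uk = rng.standard_normal((f_latent, f)) / np.sqrt(f_latent)  # recover K
W_uv = rng.standard_normal((f_latent, f)) / np.sqrt(f_latent)  # recover V

x = rng.standard_normal((10, f))  # 10 cached tokens
c_kv = x @ W_down                 # (10, 128): this is all we cache
K, V = c_kv @ W_uk, c_kv @ W_uv   # (10, 1024): recovered on the fly

print(c_kv.size / (K.size + V.size))  # 0.0625: cache is 16x smaller than K+V
```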

It also uses RoPE, which I talk about in my positional encoding post.

RoPE + down-projecting inputs before applying attention to be able to compress KV-cache at inference time. Image by author.