Attention, Multi-Head Attention, Complexity, and KV Caching — Broken Down and Explained from the Basics
Defined once here. Refer back whenever a symbol appears and you're unsure what it means.
| Symbol | What It Means | Typical Value (GPT-2 Small) |
|---|---|---|
| $B$ | Batch size — number of sequences processed simultaneously | 32 |
| $S$ | Sequence length — number of tokens in one sequence | 1024 |
| $d_{model}$ | Model dimension — size of each token's embedding vector | 768 |
| $n_{heads}$ | Number of attention heads | 12 |
| $d_{head}$ | Dimension per head — always $d_{model} / n_{heads}$ | 64 |
| $L$ | Number of transformer layers (depth) | 12 |
| $X$ | Input matrix — shape $(S, d_{model})$, all tokens stacked | — |
| $Q$ | Query matrix — shape $(S, d_{head})$ per head | — |
| $K$ | Key matrix — shape $(S, d_{head})$ per head | — |
| $V$ | Value matrix — shape $(S, d_{head})$ per head | — |
| $W_Q, W_K, W_V$ | Learned projection weight matrices — shape $(d_{model}, d_{model})$ each | — |
| $W_O$ | Output projection weight matrix — shape $(d_{model}, d_{model})$ | — |
| $t$ | Current generation step (which token is being produced) | — |
| FLOPs | Floating point operations — how we measure compute cost | — |
Before transformers, sequence modelling meant RNNs and LSTMs. They process one token at a time, left to right, updating a hidden state at each step: $h_t = f(h_{t-1}, x_t)$. Everything the model knows about the sequence so far must fit through that fixed-size $h_t$, and step $t$ cannot begin until step $t-1$ finishes, so there is no parallelism across the sequence.
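The sequential bottleneck is easy to see in code. A toy sketch, assuming a simple tanh cell for $f$ and illustrative weight shapes (not a specific RNN variant):

```python
import numpy as np

rng = np.random.default_rng(0)
d_hidden, d_in, S = 8, 4, 5

W_h = rng.normal(size=(d_hidden, d_hidden)) * 0.1  # hidden-to-hidden weights
W_x = rng.normal(size=(d_hidden, d_in)) * 0.1      # input-to-hidden weights
x = rng.normal(size=(S, d_in))                     # a toy sequence of 5 tokens

h = np.zeros(d_hidden)
for t in range(S):          # strictly sequential: step t needs h from step t-1
    h = np.tanh(W_h @ h + W_x @ x[t])

print(h.shape)  # (5 tokens compressed into one fixed-size vector)
```

The loop cannot be parallelised across tokens, and the entire history must survive inside `h`.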
Instead of compressing the whole sequence into a hidden state, let every token directly query every other token and decide how much to borrow from it.
For "The cat sat on the mat because it was tired" — when processing "it", attention scores every token directly:
The output for "it" = $0.71 \times \text{rep}(\text{"cat"}) + 0.08 \times \text{rep}(\text{"mat"}) + \ldots$ — a context-aware blend.
Every token produces three vectors via learned linear projections $W_Q, W_K, W_V$:
| Vector | Database Analogy | Role in Attention |
|---|---|---|
| Query $Q$ | "What am I looking for?" | Current token's search request |
| Key $K$ | "What do I contain?" | Each token advertising its content |
| Value $V$ | "What do I actually give you?" | Information to aggregate if selected |
In matrix form (all $S$ tokens at once): $\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_{head}}}\right)V$
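A minimal NumPy sketch of scaled dot-product attention for a single head (no masking, batch of 1):

```python
import numpy as np

def attention(Q, K, V):
    """Scaled dot-product attention for one head. Q, K, V: (S, d_head)."""
    d_head = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_head)              # (S, S): every token vs every token
    scores -= scores.max(axis=-1, keepdims=True)    # numerical stability before exp
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax: each row sums to 1
    return weights @ V                              # (S, d_head): context-aware blend

rng = np.random.default_rng(0)
S, d_head = 6, 64
out = attention(rng.normal(size=(S, d_head)),
                rng.normal(size=(S, d_head)),
                rng.normal(size=(S, d_head)))
print(out.shape)  # (6, 64)
```

Each output row is exactly the weighted blend described above: the softmax row supplies the weights, and `weights @ V` mixes the value vectors.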
"Why √d_head? Why not d_head? Why not nothing?"
Assume $Q$ and $K$ are initialised from $\mathcal{N}(0,1)$. Each element has mean 0, variance 1. The dot product $Q_i \cdot K_j = \sum_{l=1}^{d_{head}} Q_{il} \cdot K_{jl}$ is a sum of $d_{head}$ products.
If $a \sim \mathcal{N}(0,1)$ and $b \sim \mathcal{N}(0,1)$ independently: $\text{Var}(ab) = 1$.
Sum of $d_{head}$ such independent terms has variance $= d_{head}$. Standard deviation $= \sqrt{d_{head}}$.
With $d_{head}=64$: std dev $= 8$, so dot products typically land in $[-24, +24]$ (three standard deviations).
Feed $[-24, +24]$ into softmax and it saturates: a gap of just 8 between the top two scores means a probability ratio of $e^8 \approx 3000$, so one token takes essentially all the weight. The result is a near one-hot distribution with vanishing gradients.
Divide by $\sqrt{d_{head}}$ first:
Scores in $[-24,+24]$ become $[-3,+3]$.
softmax([3, 2, 0, -1]) ≈ [0.70, 0.26, 0.03, 0.01] — smooth distribution, healthy gradients.
$\sqrt{d_{head}}$ is exactly the standard deviation of the unscaled dot product — dividing by it restores unit variance.
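The variance argument is easy to check empirically. This sketch samples random $\mathcal{N}(0,1)$ query/key vectors and compares the dot-product variance before and after scaling:

```python
import numpy as np

rng = np.random.default_rng(42)
d_head, n_samples = 64, 100_000

# Pairs of N(0,1) vectors, as at initialisation
q = rng.normal(size=(n_samples, d_head))
k = rng.normal(size=(n_samples, d_head))

raw = (q * k).sum(axis=1)        # unscaled dot products
scaled = raw / np.sqrt(d_head)   # divided by sqrt(d_head)

print(round(raw.var(), 1))       # ≈ 64: variance equals d_head
print(round(scaled.var(), 2))    # ≈ 1: unit variance restored
```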
"Why split into n_heads heads? Why not just use one big attention with full d_model?"
Splitting $d_{model}$ into $n_{heads}$ heads of $d_{head} = d_{model}/n_{heads}$ each uses the same number of parameters. You're not adding capacity — you're restructuring it to encourage specialisation.
Tracing exact shapes through GPT-2 Small: $B=1$, $S=1024$, $d_{model}=768$, $n_{heads}=12$, $d_{head}=64$.
"If there are 12 heads, are there 12 separate W_Q matrices or one shared one?"
$W_Q$ is $(768, 768)$. After the projection we reshape to $(S, n_{heads}, d_{head})$ to split by head. Conceptually it's $n_{heads}$ stacked projections of $(768, 64)$. One matrix for GPU efficiency; split after multiplication.
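A quick NumPy check that one $(768, 768)$ matrix really is 12 stacked $(768, 64)$ projections: the reshape after the matmul recovers exactly each head's column block.

```python
import numpy as np

S, d_model, n_heads = 1024, 768, 12
d_head = d_model // n_heads  # 64

rng = np.random.default_rng(0)
X = rng.normal(size=(S, d_model))
W_Q = rng.normal(size=(d_model, d_model)) * 0.02   # one big projection matrix

Q = X @ W_Q                                        # (1024, 768): single matmul
Q_heads = Q.reshape(S, n_heads, d_head)            # (1024, 12, 64): split after multiplying

# Column block h of W_Q is that head's own (768, 64) projection
h = 3
per_head = X @ W_Q[:, h * d_head:(h + 1) * d_head]
print(np.allclose(per_head, Q_heads[:, h, :]))     # True
```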
$d_{model}=768$, $n_{heads}=12$, $d_{head}=64$.
$W_Q$: $768 \times 768 = 589{,}824$ | $W_K$: same | $W_V$: same | $W_O$: same
Total per layer: $\approx 2.36M$ parameters — identical for 1 head or 12 heads.
"I've heard O(S²·d_model) but I don't know how to count the cost of multiplying two matrices."
Dot product of two length-$k$ vectors: $k$ multiplications + $(k-1)$ additions $\approx 2k$ operations.
$[1,2,3]\cdot[4,5,6] = 4+10+18=32$ → 3 multiplications + 2 additions = 5 ops $\approx 2k$ for $k=3$.
Matrix multiply $(A: m \times k)$ by $(B: k \times p)$: result is $(m \times p)$. Each of the $m \times p$ entries is a length-$k$ dot product.
$(4\times3)\times(3\times5)$ → result $(4,5)$. Number of dot products: $4\times5=20$. Each length 3.
Total: $4\times5\times3 = 60$ multiplications (roughly $120$ ops once additions are counted) $= O(m\cdot k\cdot p)$.
Rule: multiply all three dimensions of the two matrices together.
$X W_Q$: $(S, d_{model}) \times (d_{model}, d_{model}) \to (S, d_{model})$. Cost $= O(S \cdot d_{model}^2)$. All four projections ($W_Q, W_K, W_V, W_O$) together: $4\times$ that, still $O(S \cdot d_{model}^2)$.
Per head: $Q$ is $(S, d_{head})$, $K^\top$ is $(d_{head}, S)$. Result: $(S, S)$.
Cost per head $= O(S \cdot d_{head} \cdot S) = O(S^2 d_{head})$.
All $n_{heads}$ heads: $O(n_{heads} \cdot S^2 d_{head}) = O(S^2 \cdot n_{heads} \cdot d_{head}) = O(S^2 \cdot d_{model})$ since $n_{heads} \cdot d_{head} = d_{model}$.
$S=1024$, $d_{head}=64$, $n_{heads}=12$.
Per head: $1024 \times 1024 \times 64 = 67M$ FLOPs.
All 12 heads: $12 \times 67M = 805M$ FLOPs for attention scores alone.
Projection cost: $1024 \times 768 \times 768 \approx 604M$ FLOPs — the same order of magnitude as the attention scores.
$(S, S) \times (S, d_{head})$ per head. Cost $= O(S^2 d_{head})$ per head $= O(S^2 d_{model})$ total. Same as $QK^\top$.
Double the sequence length → 4× the compute and memory for attention. This is why 100k-token contexts are expensive, and why FlashAttention, sparse attention, and linear attention were invented.
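The counts above can be reproduced with the multiply-all-three-dimensions rule (counting one multiply per term, as in the worked examples; additions roughly double each number):

```python
def matmul_flops(m, k, p):
    """Cost of (m, k) @ (k, p): one multiply per term of each dot product."""
    return m * k * p

S, d_model, n_heads, d_head = 1024, 768, 12, 64

scores_per_head = matmul_flops(S, d_head, S)        # Q @ K^T for one head
scores_all = n_heads * scores_per_head              # all 12 heads
one_projection = matmul_flops(S, d_model, d_model)  # X @ W_Q

print(f"{scores_per_head / 1e6:.0f}M")   # 67M
print(f"{scores_all / 1e6:.0f}M")        # 805M
print(f"{one_projection / 1e6:.0f}M")    # 604M
```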
"Compute is O(S²·d_model). Why is memory only O(S²)? Shouldn't memory scale with d_model too?"
Compute counts operations performed. Memory counts values that must be stored simultaneously.
The attention score matrix has shape $(B, n_{heads}, S, S)$. That's $B \times n_{heads} \times S^2$ numbers. The $d_{head}$ dimension was consumed by the dot product — it doesn't appear in the output matrix. So storing the score matrix costs $O(S^2)$ per batch element per head.
$S=1024$, $n_{heads}=12$, float32 (4 bytes each).
Score matrix: $12 \times 1024 \times 1024 = 12{,}582{,}912$ floats $\approx$ 48 MB per layer.
During training with backprop ($L=12$ layers): $12 \times 48 = 576$ MB just for attention scores.
Score matrix scales as $S^2$.
$S=1024$ → 48 MB/layer. $S=4096$ → 768 MB/layer. $S=32{,}768$ → 48 GB/layer.
At $S=32k$, storing attention scores is infeasible on most GPUs. This is the exact problem FlashAttention solves by recomputing scores on the fly rather than storing them.
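A small helper reproduces the per-layer score-matrix memory figures (assuming $B=1$, float32, and binary MB/GB):

```python
def score_matrix_bytes(S, n_heads, bytes_per_float=4):
    """Memory for one layer's attention scores at B=1: n_heads x S x S floats."""
    return n_heads * S * S * bytes_per_float

MiB, GiB = 1024 ** 2, 1024 ** 3
print(score_matrix_bytes(1024, 12) / MiB)    # 48.0 MB per layer
print(score_matrix_bytes(4096, 12) / MiB)    # 768.0 MB per layer
print(score_matrix_bytes(32_768, 12) / GiB)  # 48.0 GB per layer
```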
KV caching is an inference-time optimisation only — it doesn't apply during training.
During autoregressive generation, tokens are produced one at a time:
At each step, in every transformer layer, the model computes $K$ and $V$ for every token in the current input.
Generating the 100th token. Context = 99 previous tokens.
Model computes $K$ and $V$ for all 99 tokens (including "The", "cat", etc.).
Generating the 101st token. Context = 100 tokens.
Model computes $K$ and $V$ for all 100 tokens — including the same 99 from before.
$K$ of "The" at step 100 = $K$ of "The" at step 1 = identical. Recomputed for no reason.
For a 1000-token generation: token 1000 triggers recomputation of K,V for 999 prior tokens. Total redundant projection operations: $1+2+\ldots+999 \approx O(S^2)$.
"Why cache K and V specifically? What's special about them vs Q?"
$K$ and $V$ are deterministic functions of their token — same token always produces the same $K$ and $V$. $Q$ is the current token's fresh question to the world — always new. Cache the answers, not the questions.
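A toy sketch of the cached generation loop. `W_K` and `W_V` here are stand-in single-head projections, not a real model's weights; the point is that each step projects only the newest token and appends to the cache:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 16
W_K = rng.normal(size=(d_model, d_model)) * 0.1
W_V = rng.normal(size=(d_model, d_model)) * 0.1

k_cache, v_cache = [], []           # grows by one row per generated token

def step(x_new):
    """x_new: (d_model,) embedding of the newest token only."""
    k_cache.append(x_new @ W_K)     # project just the new token: O(d_model^2)
    v_cache.append(x_new @ W_V)
    K = np.stack(k_cache)           # (t, d_model): cached keys, never recomputed
    V = np.stack(v_cache)
    return K, V                     # the new token's Q attends over these

for t in range(5):
    K, V = step(rng.normal(size=d_model))

print(K.shape, V.shape)  # (5, 16) (5, 16)
```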
"I've seen KV cache size quoted as 2·S·L·d_model (roughly). Where does each term come from?"
$S=1024$, $L=12$, $d_{model}=768$, float16.
Numbers: $2 \times 1024 \times 12 \times 768 = 18{,}874{,}368$
Bytes: $\times 2 = 37{,}748{,}736 \approx$ 36 MB. Tiny.
$S=4096$, $L=80$, $d_{model}=8192$, float16.
Bytes: $4 \times 4096 \times 80 \times 8192 = 10{,}737{,}418{,}240 \approx$ 10 GB just for KV cache.
Model weights: ~130 GB additional. For $S=32{,}768$ (long context): KV cache alone → ~80 GB.
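The cache-size arithmetic as a helper (float16, so 2 bytes per number; the model shapes are the ones used above):

```python
def kv_cache_bytes(S, L, d_model, bytes_per_float=2):
    """2 (K and V) x S tokens x L layers x d_model numbers each; float16 = 2 bytes."""
    return 2 * S * L * d_model * bytes_per_float

GiB = 1024 ** 3
print(kv_cache_bytes(1024, 12, 768) / GiB)     # ≈ 0.035 GB (36 MB): GPT-2 Small
print(kv_cache_bytes(4096, 80, 8192) / GiB)    # 10.0 GB: large model
print(kv_cache_bytes(32_768, 80, 8192) / GiB)  # 80.0 GB: long context
```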
"K and V are just two matrix multiplies. How much can reusing them actually save?"
At generation step $t$, the input context has $t-1$ tokens.
Without KV cache: compute $K$ and $V$ for all $t-1$ tokens per layer. Each is $(t-1, d_{model}) \times (d_{model}, d_{model})$: cost $= O((t-1) \cdot d_{model}^2)$ per layer.
With KV cache: compute $K$ and $V$ only for the new token. That's $(1, d_{model}) \times (d_{model}, d_{model})$: cost $= O(d_{model}^2)$ per layer — independent of $t$.
$t=500$, $d_{model}=768$, $L=12$.
Without cache — K,V projections all layers: $499 \times 768 \times 768 \times 12 \times 2 \approx 7.1B$ FLOPs
With cache — K,V for 1 new token all layers: $1 \times 768 \times 768 \times 12 \times 2 \approx 14M$ FLOPs
Savings: ~500× fewer FLOPs for K,V projections at step 500 (the ratio is exactly $t-1 = 499$).
Total savings across all $S$ generation steps: projection cost falls from $O(S^2 \cdot d_{model}^2)$ (up to $S$ tokens recomputed at each of $S$ steps) to $O(S \cdot d_{model}^2)$ (one new token per step), a factor of roughly $S/2$.
"But we still run attention over the full cached K and V every step. Doesn't that stay O(S²)?"
Correct. At step $t$, the new token's $Q$ attends over all $t-1$ cached $K$ vectors: $(1, d_{head}) \times (d_{head}, t-1) = O(t \cdot d_{head})$ per head. Summed across all steps and heads: $O(S^2 \cdot d_{model})$ — still quadratic. KV caching eliminates the redundant projection cost. The attention over the growing context remains $O(S^2)$.
Masking = set certain attention scores to $-\infty$ before softmax so they get zero weight. Three distinct situations need it:
| Mask | Problem It Solves | Where Applied | Implementation |
|---|---|---|---|
| Padding mask | Batch sequences are padded to equal length — PAD tokens shouldn't influence attention | Attention scores | attention_mask (1=real, 0=pad) |
| Loss mask | PAD tokens shouldn't contribute to the loss | Cross-entropy loss | ignore_index=0 |
| Causal mask | Autoregressive models can't see future tokens during training | Attention scores | Upper triangular $-\infty$, built into model |
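A NumPy sketch of the causal mask: upper-triangular $-\infty$ added to the scores before softmax, so each token's weights over future positions are exactly zero.

```python
import numpy as np

S = 4
scores = np.zeros((S, S))                      # stand-in raw attention scores

# Causal mask: -inf strictly above the diagonal, so token i only sees tokens <= i
causal = np.triu(np.full((S, S), -np.inf), k=1)
masked = scores + causal

weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
print(np.round(weights, 2))
# Row 0 attends only to token 0; row 3 attends uniformly to tokens 0-3:
# [[1.   0.   0.   0.  ]
#  [0.5  0.5  0.   0.  ]
#  [0.33 0.33 0.33 0.  ]
#  [0.25 0.25 0.25 0.25]]
```

The padding mask works the same way, except the $-\infty$ goes into PAD columns instead of future positions.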
Bug 1 — No attention mask: Real tokens attend to PAD tokens, pulling garbage into their representations.
Bug 2 — No loss mask: Model penalised for predicting PAD token positions — meaningless noise.
"Is ignore_index=0 the position index or the token value?"
"Skip any position where the target tensor value equals 0." If your PAD token ID is 99, use ignore_index=99. Think of it as ignore_token_id.
Symptom of both bugs: model performs poorly on short sequences (more padding = more noise), well on long ones. Length-correlated degradation is the tell.
| # | What | Effect if Missing | Severity |
|---|---|---|---|
| 2 | zero_grad() | Gradients accumulate — wrong update every step | 🔴 Immediate failure |
| 4 | ignore_index | PAD tokens add noise to loss — silent degradation | 🟠 Silent |
| 3 | attention_mask | PAD pollutes representations — silent degradation | 🟠 Silent |
| 8 | model.train() after eval | Dropout/BN in wrong mode — conditional on having called eval() | 🟡 Conditional |
| 5 | grad clipping | Gradient explosion on long sequences | 🟡 Rare for transformers |
"Why do transformers need LR warmup? CNNs work fine without it."
Reason 1 — Adam's early estimates are unreliable. Adam maintains $m_t$ (EMA of gradients) and $v_t$ (EMA of squared gradients). At step 1, $m_1 = 0.1 \times g_1$ — based on one sample, highly noisy. A large LR multiplies this noise into a large bad update.
100k total steps, target LR $= 1\times10^{-4}$:
Steps 0–1000: LR increases linearly $0 \to 1\times10^{-4}$ (warmup — 1% of training).
Steps 1000–100k: cosine decay $1\times10^{-4} \to 1\times10^{-5}$.
The 1000-step warmup often makes the difference between stable and divergent early training.
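The schedule above as a function; the cosine decay is one common choice, not the only one:

```python
import math

def lr_at(step, total_steps=100_000, warmup_steps=1_000,
          max_lr=1e-4, min_lr=1e-5):
    """Linear warmup to max_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps        # 0 -> max_lr over warmup
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

print(lr_at(0))        # 0.0
print(lr_at(1_000))    # 0.0001 (warmup done, at peak)
print(lr_at(100_000))  # 1e-05  (end of decay)
```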
Reason 2 — Transformers amplify instability. Multi-head attention + layer norm + residual connections across many layers means a bad early update propagates through all $L$ layers. CNNs are locally connected and more forgiving.
The 2017 "Attention is All You Need" transformer was a breakthrough — but it had real problems. Here's how each was addressed, in order. Each entry below will become its own deep-dive post.
Introduced self-attention, multi-head attention, positional encoding. Encoder-decoder architecture for translation. Replaced RNNs entirely for sequence-to-sequence tasks.
GPT: decoder-only, causal masking, autoregressive pretraining. BERT: encoder-only, masked language modelling, bidirectional context. Both showed that large-scale pretraining + fine-tuning transfers to nearly any NLP task.
First real attempt at extending context without quadratic memory blowup.
Enabled processing of long documents (legal, scientific papers) that were previously impossible.
RoPE is now the standard in most modern LLMs — LLaMA, Mistral, GPT-4 all use it.
Same O(S²) compute — but 5–20× faster in practice due to memory access patterns. Became the standard attention implementation. FlashAttention-2 (2023) and FlashAttention-3 (2024) pushed further.
GQA reduces KV cache by 8× vs multi-head attention with minimal quality loss. Now standard in production LLMs.
Competitive with transformers on language tasks at medium scale. Not yet clearly dominant — hybrid Mamba+attention models (Jamba) show promise. Active research area.
Combined with MoE and FP8 training, DeepSeek-V3 achieved GPT-4 level performance at a fraction of the training cost, drawing significant industry attention.
Every innovation above is attacking one of three constraints: (1) O(S²) compute (sparse attention, linear attention, Mamba), (2) O(S²) memory (FlashAttention, KV caching, MQA/GQA, MLA), or (3) fixed context/positional generalisation (Transformer-XL, RoPE). Each post in this series will go deep on one of these rows.
"RNNs process sequentially and compress everything into a hidden state — causing vanishing gradients and preventing parallelisation.
Transformers replace this with attention: $\text{softmax}(QK^\top/\sqrt{d_{head}})V$. Every token directly queries every other. We scale by $\sqrt{d_{head}}$ because dot product variance grows as $d_{head}$, and large values saturate softmax — dividing restores unit variance.
Multi-head attention runs $n_{heads}$ independent attention operations in parallel, each specialising in different relationship types, using the same total parameters as single-head attention ($n_{heads} \times d_{head} = d_{model}$).
Attention is $O(S^2 \cdot d_{model})$ compute because $QK^\top$ is $(S, d_{head}) \times (d_{head}, S) = O(S^2 d_{head})$ per head, times $n_{heads}$ gives $O(S^2 d_{model})$. Memory is $O(S^2)$ because the score matrix is $(S, S)$ — $d_{head}$ was consumed by the dot product.
KV caching saves inference compute by storing previous tokens' $K$ and $V$ — they're deterministic functions of those tokens and never change. $Q$ is always the fresh new token's question, so it's never cached. Without caching, total projection cost across a generation is $O(S^2 \cdot d_{model}^2)$. With caching, $O(S \cdot d_{model}^2)$ — linear in $S$. Cache size is $2SLd_{model}$ numbers = $4SLd_{model}$ bytes in float16.
Three masking types: attention mask for padding, ignore_index for loss, causal mask for autoregressive models. Forgetting the loss mask causes silent length-correlated degradation. LR warmup stabilises early Adam estimates and prevents cascading instability in deep networks."
| Concept | One Line | Key Number |
|---|---|---|
| Attention formula | $\text{softmax}(QK^\top/\sqrt{d_{head}})V$ | — |
| Why $\sqrt{d_{head}}$ | Dot product std dev = $\sqrt{d_{head}}$; divide to restore unit variance | $d_{head}=64$ → divide by 8 |
| Multi-head | Same params ($n_{heads} \times d_{head} = d_{model}$), $n_{heads}$ specialised patterns | GPT-2: 12 heads × 64 = 768 |
| Matrix multiply cost | $(m,k)\times(k,p) = O(m\cdot k\cdot p)$ — multiply all three dims | Always |
| Attention compute | $O(S^2 \cdot d_{model})$ from $QK^\top$ being $(S,S)$ times $n_{heads}$ heads | Double $S$ → 4× FLOPs |
| Attention memory | $O(S^2)$ — score matrix is $(S,S)$, $d_{head}$ consumed by dot product | $S=32k$ → 48 GB/layer |
| KV cache what | Store K,V of prev tokens; Q is always fresh so never cached | — |
| KV cache saves | Projection cost: $O(S^2 \cdot d_{model}^2) \to O(S \cdot d_{model}^2)$ across all steps | Step 500 → ~500× fewer K,V FLOPs |
| KV cache size | $2SLd_{model}$ numbers, $4SLd_{model}$ bytes float16 | LLaMA 70B, $S=4k$ → 10 GB |
| Padding mask | Add $-\infty$ to PAD positions before softmax | attention_mask |
| Loss mask | Skip PAD token IDs in cross-entropy | ignore_index=0 |
| Causal mask | Upper triangular $-\infty$ prevents future token leakage | GPT-style only |
Papers referenced: Attention Is All You Need (2017) · FlashAttention (2022) · GQA (2023) · Mamba (2023) · DeepSeek-V3 (2024)