Deep Dive Series

Transformers: The Complete Story

Attention, Multi-Head Attention, Complexity, KV Caching — broken down and explained from the basics

📐 Notation Reference — Every Variable Used In This Post

Defined once here. Refer back whenever a symbol appears and you're unsure what it means.

| Symbol | What It Means | Typical Value (GPT-2 Small) |
|---|---|---|
| $B$ | Batch size — number of sequences processed simultaneously | 32 |
| $S$ | Sequence length — number of tokens in one sequence | 1024 |
| $d_{model}$ | Model dimension — size of each token's embedding vector | 768 |
| $n_{heads}$ | Number of attention heads | 12 |
| $d_{head}$ | Dimension per head — always $d_{model} / n_{heads}$ | 64 |
| $L$ | Number of transformer layers (depth) | 12 |
| $X$ | Input matrix — shape $(S, d_{model})$, all tokens stacked | |
| $Q$ | Query matrix — shape $(S, d_{head})$ per head | |
| $K$ | Key matrix — shape $(S, d_{head})$ per head | |
| $V$ | Value matrix — shape $(S, d_{head})$ per head | |
| $W_Q, W_K, W_V$ | Learned projection weight matrices — shape $(d_{model}, d_{model})$ each | |
| $W_O$ | Output projection weight matrix — shape $(d_{model}, d_{model})$ | |
| $t$ | Current generation step (which token is being produced) | |
| FLOPs | Floating point operations — how we measure compute cost | |

Contents

Part I

Why Transformers?

1 The Problem with RNNs

Before transformers, sequence modelling meant RNNs and LSTMs. They process one token at a time, left to right, updating a hidden state $h_t$ at each step:

```
h_0 → h_1 → h_2 → h_3 → h_4 → h_5
 ↑     ↑     ↑     ↑     ↑     ↑
"The" "cat" "sat" "on"  "the" "mat"
```

To understand "mat" using "The", information about "The" must survive 5 hidden-state compressions. By step 5 it's a faint echo.
Problem 1 — Vanishing Gradients: Backpropagating through 500 timesteps means gradients shrink exponentially. Early tokens contribute almost nothing to learning.
Problem 2 — Sequential Processing: Step $t$ depends on step $t-1$. Can't parallelise. Slow on long sequences.
Problem 3 — Fixed Bottleneck: All information must pass through one $d_{model}$-dimensional vector. Lossy by design.
What if every token could directly look at every other token?

2 The Core Idea: Attention

One Line Intuition

Instead of compressing the whole sequence into a hidden state, let every token directly query every other token and decide how much to borrow from it.

For "The cat sat on the mat because it was tired" — when processing "it", attention scores every token directly:

Attention weights computed by "it":

```
 The    cat    sat    on     the    mat    because  it     was    tired
 0.02   0.71   0.05   0.02   0.03   0.08   0.04     0.01   0.02   0.02
         ↑
```

"cat" gets 71% of the attention weight. Coreference resolved directly. No compression needed. No forgetting.

The output for "it" = $0.71 \times \text{rep}(\text{"cat"}) + 0.08 \times \text{rep}(\text{"mat"}) + \ldots$ — a context-aware blend.

Part II

The Attention Mechanism

3 Q, K, V — What They Are

Every token produces three vectors via learned linear projections $W_Q, W_K, W_V$:

Q, K, V Intuition

| Vector | Database Analogy | Role in Attention |
|---|---|---|
| Query $Q$ | "What am I looking for?" | Current token's search request |
| Key $K$ | "What do I contain?" | Each token advertising its content |
| Value $V$ | "What do I actually give you?" | Information to aggregate if selected |

```
Resolving "it":

Query from "it":  "find me: animate subject noun, introduced earlier"

Key from "cat":   "I am: animate, subject, early position"  → HIGH match
Key from "mat":   "I am: inanimate, location object"        → low match
Key from "sat":   "I am: a verb"                            → low match

match = Q_it · K_cat = high → attend more
match = Q_it · K_mat = low  → attend less

Output = 0.71 × V_cat + 0.08 × V_mat + ... = mostly "cat"'s information
```

In matrix form (all $S$ tokens at once):

$$Q = X W_Q, \quad K = X W_K, \quad V = X W_V$$ $$\text{where } X \in \mathbb{R}^{S \times d_{model}}, \quad W_Q, W_K, W_V \in \mathbb{R}^{d_{model} \times d_{model}}$$

4 Scaled Dot-Product Attention

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_{head}}}\right) V$$
Step 1 — Raw scores: QKᵀ → shape (S, S)
$(S, d_{head}) \times (d_{head}, S) = (S, S)$. Entry $(i, j)$ = dot product of token $i$'s query and token $j$'s key = how much token $i$ wants to attend to $j$.
Step 2 — Scale by √d_head
Divide every score by $\sqrt{d_{head}}$. Prevents large values from breaking softmax. Chapter 5 explains why with actual numbers.
Step 3 — Softmax (row-wise)
Each row (token $i$'s scores over all $S$ tokens) becomes a probability distribution summing to 1. These are the attention weights.
Step 4 — Weighted sum: weights × V → shape (S, d_head)
$(S, S) \times (S, d_{head}) = (S, d_{head})$. Each token gets a weighted mix of all value vectors. Token representations now contain context.

5 Why Scale by √d_head? (With Numbers)

🤔 Question

"Why √d_head? Why not d_head? Why not nothing?"

Assume $Q$ and $K$ are initialised from $\mathcal{N}(0,1)$. Each element has mean 0, variance 1. The dot product $Q_i \cdot K_j = \sum_{l=1}^{d_{head}} Q_{il} \cdot K_{jl}$ is a sum of $d_{head}$ products.

📊 Worked Example — Variance of one dot product

If $a \sim \mathcal{N}(0,1)$ and $b \sim \mathcal{N}(0,1)$ independently: $\text{Var}(ab) = 1$.

Sum of $d_{head}$ such independent terms has variance $= d_{head}$. Standard deviation $= \sqrt{d_{head}}$.

With $d_{head}=64$: std dev = 8, so dot products are typically in $[-24, +24]$.

Feed $[-24, +24]$ into softmax:

softmax([24, 23, 0, -1]) ≈ [0.73, 0.27, 0.00, 0.00]

One token absorbs nearly all attention weight. The softmax gradient is ≈ 0, and the model stops learning.

Divide by $\sqrt{d_{head}}$ first:

$$\text{Var}\!\left(\frac{Q_i \cdot K_j}{\sqrt{d_{head}}}\right) = \frac{d_{head}}{d_{head}} = 1$$
📊 After Scaling — d_head=64, divide by √64=8

Scores in $[-24,+24]$ become $[-3,+3]$.

softmax([3, 2, 0, -1]) ≈ [0.70, 0.26, 0.03, 0.01] — every entry keeps non-zero weight, so gradients stay healthy.

$\sqrt{d_{head}}$ is exactly the standard deviation of the unscaled dot product — dividing by it restores unit variance.
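A quick simulation confirms the variance argument. The seed and sample count below are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(42)
d_head, n = 64, 100_000

# n independent query/key pairs, each entry drawn from N(0, 1)
q = rng.standard_normal((n, d_head))
k = rng.standard_normal((n, d_head))

raw = np.einsum("nd,nd->n", q, k)   # n unscaled dot products
scaled = raw / np.sqrt(d_head)      # divide by sqrt(64) = 8

print(raw.var())     # ≈ 64: variance grows linearly with d_head
print(scaled.var())  # ≈ 1:  unit variance restored
```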

Part III

Multi-Head Attention

6 Why One Head Isn't Enough

🤔 Question

"Why split into n_heads heads? Why not just use one big attention with full d_model?"

"The cat sat on the mat because it was tired"

Relationships present simultaneously:

- Coreference: "it" → "cat"
- Syntactic: "cat" is the subject of "sat"
- Positional: "on" relates "sat" to "mat"
- Semantic: "tired" is a state of "cat"

One head: a single weight matrix W_Q captures one mixture of all these types. It learns a compromise, mediocre at everything.

n_heads = 12: each head independently learns one type (head 1 → coreference, head 2 → syntactic roles, head 3 → positional, ...). Concatenating all heads gives a rich, multi-perspective representation.
Key Insight — Same Parameters, Better Structure

Splitting $d_{model}$ into $n_{heads}$ heads of $d_{head} = d_{model}/n_{heads}$ each uses the same number of parameters. You're not adding capacity — you're restructuring it to encourage specialisation.

7 The Full Dimension Flow

Tracing exact shapes through GPT-2 Small: $B=1$, $S=1024$, $d_{model}=768$, $n_{heads}=12$, $d_{head}=64$.

```
Input X:                          (1, 1024, 768)   = (B, S, d_model)

Project with W_Q, W_K, W_V (each 768×768):
  Q = X @ W_Q:                    (1, 1024, 768)
  K = X @ W_K:                    (1, 1024, 768)
  V = X @ W_V:                    (1, 1024, 768)

Reshape to split into heads:
  Q:                              (1, 1024, 12, 64) = (B, S, n_heads, d_head)

Transpose for batched matmul:
  Q:                              (1, 12, 1024, 64) = (B, n_heads, S, d_head)
  K, V: same

Attention scores Q @ K.T:         (1, 12, 1024, 1024) = (B, n_heads, S, S)
                                   ↑ THIS is the S² matrix — 12.5M numbers per batch

Divide by √64 = 8, softmax row-wise.

Weighted sum @ V:                 (1, 12, 1024, 64) = (B, n_heads, S, d_head)

Transpose + reshape (concatenate heads):
                                  (1, 1024, 768)   = (B, S, d_model)

Output projection W_O (768×768):  (1, 1024, 768)   = same shape as input ✓
```
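The same flow can be traced with NumPy shape assertions. Random weights, Q and K only, purely a sketch:

```python
import numpy as np

B, S, d_model, n_heads = 1, 1024, 768, 12
d_head = d_model // n_heads  # 64

rng = np.random.default_rng(0)
X = rng.standard_normal((B, S, d_model))
W_Q = rng.standard_normal((d_model, d_model)) * 0.02
W_K = rng.standard_normal((d_model, d_model)) * 0.02

def split_heads(M):
    """(B, S, d_model) -> (B, n_heads, S, d_head)"""
    return M.reshape(B, S, n_heads, d_head).transpose(0, 2, 1, 3)

Q = split_heads(X @ W_Q)
K = split_heads(X @ W_K)
assert Q.shape == (1, 12, 1024, 64)

scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_head)
assert scores.shape == (1, 12, 1024, 1024)   # the S² matrix
print(scores.size)                           # 12,582,912 numbers per batch
```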

8 Weight Matrices: Still d_model × d_model

🤔 Question

"If there are 12 heads, are there 12 separate W_Q matrices or one shared one?"

✓ One matrix (d_model × d_model) that implicitly contains all heads

$W_Q$ is $(768, 768)$. After the projection we reshape to $(S, n_{heads}, d_{head})$ to split by head. Conceptually it's $n_{heads}$ stacked projections of $(768, 64)$. One matrix for GPU efficiency; split after multiplication.

📊 Parameter Count — One Attention Layer, GPT-2 Small

$d_{model}=768$, $n_{heads}=12$, $d_{head}=64$.

$W_Q$: $768 \times 768 = 589{,}824$  |  $W_K$: same  |  $W_V$: same  |  $W_O$: same

Total per layer: $\approx 2.36M$ parameters — identical for 1 head or 12 heads.
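The count is quick to verify, and it makes the head-independence obvious since $n_{heads}$ never appears:

```python
d_model = 768
per_matrix = d_model * d_model   # one of W_Q, W_K, W_V, W_O
total = 4 * per_matrix           # all four projection matrices

print(per_matrix)  # 589,824
print(total)       # 2,359,296 ≈ 2.36M, regardless of the number of heads
```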

Part IV

Complexity — Where O(S²·d_model) Comes From

9 How Matrix Multiply Cost Works (From Scratch)

🤔 Starting from scratch

"I've heard O(S²·d_model) but I don't know how to count the cost of multiplying two matrices."

Dot product of two length-$k$ vectors: $k$ multiplications + $(k-1)$ additions $\approx 2k$ operations.

📊 Dot Product Cost

$[1,2,3]\cdot[4,5,6] = 4+10+18=32$  →  3 multiplications + 2 additions = 5 ops $\approx 2k$ for $k=3$.

Matrix multiply $(A: m \times k)$ by $(B: k \times p)$: result is $(m \times p)$. Each of the $m \times p$ entries is a length-$k$ dot product.

$$\text{Cost of } (m,k) \times (k,p) = O(m \cdot k \cdot p)$$
📊 Matrix Multiply Cost

$(4\times3)\times(3\times5)$ → result $(4,5)$. Number of dot products: $4\times5=20$. Each length 3.

Total: $4\times5\times3=60$ ops = $O(m\cdot k\cdot p)$.

Rule: multiply all three dimensions of the two matrices together.

10 Compute: Why O(S²·d_model)

Operation 1 — Linear Projections (Q, K, V, O)

$X W_Q$: $(S, d_{model}) \times (d_{model}, d_{model}) \to (S, d_{model})$. Cost $= O(S \cdot d_{model}^2)$. Four projections total: $O(S \cdot d_{model}^2)$.

Operation 2 — Attention Scores QKᵀ

Per head: $Q$ is $(S, d_{head})$, $K^\top$ is $(d_{head}, S)$. Result: $(S, S)$.

Cost per head $= O(S \cdot d_{head} \cdot S) = O(S^2 d_{head})$.

All $n_{heads}$ heads: $O(n_{heads} \cdot S^2 d_{head}) = O(S^2 \cdot n_{heads} \cdot d_{head}) = O(S^2 \cdot d_{model})$ since $n_{heads} \cdot d_{head} = d_{model}$.

📊 QKᵀ Cost — GPT-2 Small

$S=1024$, $d_{head}=64$, $n_{heads}=12$.

Per head: $1024 \times 1024 \times 64 = 67M$ FLOPs.

All 12 heads: $12 \times 67M = 805M$ FLOPs for attention scores alone.

Projection cost: $1024 \times 768 \times 768 \approx 604M$ FLOPs. Both terms are the same order of magnitude.
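Using the rule from the previous chapter, both costs can be tallied in a few lines. The sketch drops the ~2× multiply-add factor, matching the $O(m \cdot k \cdot p)$ convention above:

```python
def matmul_cost(m, k, p):
    """Cost of (m, k) @ (k, p): m*p output entries, each a
    length-k dot product -> O(m * k * p) operations."""
    return m * k * p

S, d_model, d_head, n_heads = 1024, 768, 64, 12

one_projection = matmul_cost(S, d_model, d_model)       # X @ W_Q
scores_all_heads = n_heads * matmul_cost(S, d_head, S)  # Q K^T, 12 heads

print(one_projection)    # 603,979,776 (~604M)
print(scores_all_heads)  # 805,306,368 (~805M)
```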

Operation 3 — Weighted Sum (weights × V)

$(S, S) \times (S, d_{head})$ per head. Cost $= O(S^2 d_{head})$ per head $= O(S^2 d_{model})$ total. Same as $QK^\top$.

Total

$$\text{Total Compute} = O(S \cdot d_{model}^2) + O(S^2 \cdot d_{model})$$
Which term dominates?

- S ≪ d_model (short sequences, large model): S·d_model² dominates → projections are the bottleneck
- S ≫ d_model (very long sequences): S²·d_model dominates → attention scores are the bottleneck
- S ~ d_model (typical): both matter. Written as O(S²·d_model).
Why the S² Hurts

Double the sequence length → 4× the compute and memory for attention. This is why 100k-token contexts are expensive, and why FlashAttention, sparse attention, and linear attention were invented.

11 Memory: Why O(S²)

🤔 Question

"Compute is O(S²·d_model). Why is memory only O(S²)? Shouldn't memory scale with d_model too?"

Compute counts operations performed. Memory counts values that must be stored simultaneously.

The attention score matrix has shape $(B, n_{heads}, S, S)$. That's $B \times n_{heads} \times S^2$ numbers. The $d_{head}$ dimension was consumed by the dot product — it doesn't appear in the output matrix. So storing the score matrix costs $O(S^2)$ per batch element per head.

📊 Attention Score Memory — GPT-2 Small

$S=1024$, $n_{heads}=12$, float32 (4 bytes each).

Score matrix: $12 \times 1024 \times 1024 = 12{,}582{,}912$ floats $\approx$ 48 MB per layer.

During training with backprop ($L=12$ layers): $12 \times 48 = 576$ MB just for attention scores.

📊 Why Long Contexts Hit a Wall

Score matrix scales as $S^2$.

$S=1024$ → 48 MB/layer.   $S=4096$ → 768 MB/layer.   $S=32{,}768$ → 48 GB/layer.

At $S=32k$, storing attention scores is infeasible on most GPUs. This is the exact problem FlashAttention solves by recomputing scores on the fly rather than storing them.
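A small helper makes the wall obvious. MiB units here, float32 (4 bytes) assumed, matching the figures above:

```python
def score_matrix_mib(S, n_heads=12, bytes_per_float=4):
    """One layer's (n_heads, S, S) attention-score matrix, in MiB."""
    return n_heads * S * S * bytes_per_float / 2**20

print(score_matrix_mib(1024))    # 48.0 MiB/layer
print(score_matrix_mib(4096))    # 768.0 MiB/layer
print(score_matrix_mib(32_768))  # 49152.0 MiB = 48 GiB/layer
```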

Part V

KV Caching — Full Deep Dive

12 The Problem KV Caching Solves

KV caching is an inference-time optimisation only — it doesn't apply during training.

During autoregressive generation, tokens are produced one at a time:

```
Step 1: Input = ["The", "cat"]                → generate "sat"
Step 2: Input = ["The", "cat", "sat"]         → generate "on"
Step 3: Input = ["The", "cat", "sat", "on"]   → generate "the"
...
```

At each step, in every transformer layer, the model computes $K$ and $V$ for every token in the current input.

📊 Redundant Computation Without Caching

Generating the 100th token. Context = 99 previous tokens.

Model computes $K$ and $V$ for all 99 tokens (including "The", "cat", etc.).

Generating the 101st token. Context = 100 tokens.

Model computes $K$ and $V$ for all 100 tokens — including the same 99 from before.

$K$ of "The" at step 100 = $K$ of "The" at step 1 = identical. Recomputed for no reason.

For a 1000-token generation: token 1000 triggers recomputation of K,V for 999 prior tokens. Total redundant token projections: $1+2+\ldots+999 \approx 500{,}000$ — waste that grows as $O(S^2)$.

13 Why K and V but not Q?

🤔 Question

"Why cache K and V specifically? What's special about them vs Q?"

```
Generating token t (the NEW token):

Q of new token:
  = "What does THIS token need from the context?"
  → Computed from the new token's embedding.
  → DIFFERENT every step — new token, new Q.
  → Nothing to cache. Always fresh.

K of token 1 ("The"):
  = "What does 'The' contain?" = X_The @ W_K
  → Computed from 'The's embedding, which never changes.
  → K of "The" at step 100 = K of "The" at step 1. IDENTICAL.
  → CACHE IT.

V of token 1 ("The"):
  = "What does 'The' contribute if attended to?" = X_The @ W_V
  → Same logic. Cache it.

At step t:
  1. Retrieve K, V for tokens 1..t-1 from the cache (free)
  2. Compute K, V only for new token t (one vector each)
  3. Compute Q only for new token t (one vector)
  4. Run attention: Q_t over all K_1..K_t (cheap)
  5. Append K_t, V_t to the cache (for the next step)
```
Key Insight

$K$ and $V$ are deterministic functions of their token — same token always produces the same $K$ and $V$. $Q$ is the current token's fresh question to the world — always new. Cache the answers, not the questions.
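A toy single-head sketch (so $d_{head} = d_{model}$ here; random weights, hypothetical names) shows the cached path producing exactly the same output as full recomputation:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 16
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_model)) * 0.1 for _ in range(3))

def attend(q, K, V):
    """One new token's query q (d,) over all cached keys/values (t, d)."""
    w = np.exp(q @ K.T / np.sqrt(d_model))
    return (w / w.sum()) @ V

tokens = rng.standard_normal((5, d_model))  # embeddings of a 5-token context

# Without cache: re-project K, V for the whole context at the current step.
full_K, full_V = tokens @ W_K, tokens @ W_V
out_full = attend(tokens[-1] @ W_Q, full_K, full_V)

# With cache: K, V rows were appended one per step; only the newest row
# is ever freshly computed.
K_cache, V_cache = [], []
for x in tokens:
    K_cache.append(x @ W_K)   # one (d_model,) vector per step
    V_cache.append(x @ W_V)
out_cached = attend(tokens[-1] @ W_Q, np.stack(K_cache), np.stack(V_cache))

assert np.allclose(out_full, out_cached)  # identical output, far fewer FLOPs
```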

14 Cache Size: Where 2·S·L·d_model Comes From

🤔 Question

"I've seen KV cache size quoted as 2·S·L·d_model (roughly). Where does each term come from?"

Term: d_head
One $K$ vector for one token in one head has $d_{head}$ numbers. One $V$ vector also has $d_{head}$. Per token, per head: $2 \times d_{head}$ numbers.
Term: n_heads
There are $n_{heads}$ heads. Each has its own $K$ and $V$. Per token, per layer: $2 \times n_{heads} \times d_{head} = 2 \times d_{model}$ numbers (since $n_{heads} \times d_{head} = d_{model}$).
Term: L
There are $L$ transformer layers. Each layer has its own $K$ and $V$ cache. Per token: $2 \times L \times d_{model}$ numbers.
Term: S
We cache up to $S$ tokens (full context). Total: $2 \times S \times L \times d_{model}$ numbers.
$$\text{KV Cache} = 2 \cdot S \cdot L \cdot d_{model} \text{ numbers}$$ $$\text{In bytes (float16, 2 bytes per number):} \quad 4 \cdot S \cdot L \cdot d_{model} \text{ bytes}$$
📊 GPT-2 Small KV Cache

$S=1024$, $L=12$, $d_{model}=768$, float16.

Numbers: $2 \times 1024 \times 12 \times 768 = 18{,}874{,}368$

Bytes: $\times 2 = 37{,}748{,}736 \approx$ 36 MB. Tiny.

📊 LLaMA 70B KV Cache — Why This Matters

$S=4096$, $L=80$, $d_{model}=8192$, float16.

Bytes: $4 \times 4096 \times 80 \times 8192 = 10{,}737{,}418{,}240 \approx$ 10 GB just for KV cache.

Model weights: ~130 GB additional. For $S=32{,}768$ (long context): KV cache alone → ~80 GB.
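All three examples reduce to one formula; a sketch in Python, float16 assumed (2 bytes per number):

```python
def kv_cache_bytes(S, L, d_model, bytes_per_number=2):
    """2 (K and V) x S tokens x L layers x d_model numbers each."""
    return 2 * S * L * d_model * bytes_per_number

print(kv_cache_bytes(1024, 12, 768) / 2**20)     # 36.0 MiB  (GPT-2 Small)
print(kv_cache_bytes(4096, 80, 8192) / 2**30)    # 10.0 GiB  (LLaMA-70B scale)
print(kv_cache_bytes(32_768, 80, 8192) / 2**30)  # 80.0 GiB  (long context)
```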

15 Does It Actually Save That Much? (With Numbers)

🤔 Question

"K and V are just two matrix multiplies. How much can reusing them actually save?"

At generation step $t$, the input context has $t-1$ tokens.

Without KV cache: compute $K$ and $V$ for all $t-1$ tokens per layer. Each is $(t-1, d_{model}) \times (d_{model}, d_{model})$: cost $= O((t-1) \cdot d_{model}^2)$ per layer.

With KV cache: compute $K$ and $V$ only for the new token. That's $(1, d_{model}) \times (d_{model}, d_{model})$: cost $= O(d_{model}^2)$ per layer — independent of $t$.

📊 Operations Saved at Step t=500, GPT-2 Small

$t=500$, $d_{model}=768$, $L=12$.

Without cache — K,V projections, all layers: $499 \times 768 \times 768 \times 12 \times 2 \approx 7.1B$ FLOPs

With cache — K,V for 1 new token, all layers: $1 \times 768 \times 768 \times 12 \times 2 \approx 14M$ FLOPs

Savings: ~500× fewer FLOPs for K,V projections at step 500.
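The step-500 arithmetic is easy to check directly (counting one operation per weight visited, as in the formulas above):

```python
d_model, L, t = 768, 12, 500

per_token = 2 * d_model * d_model * L   # K and V projections across all L layers
without_cache = (t - 1) * per_token     # re-project the whole 499-token context
with_cache = per_token                  # project only the one new token

print(round(without_cache / 1e9, 2))  # ≈ 7.06 billion ops
print(round(with_cache / 1e6, 2))     # ≈ 14.16 million ops
print(without_cache // with_cache)    # 499 → ~500x fewer ops at step 500
```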

Total savings across all $S$ generation steps:

```
Total K,V projection FLOPs across generating S tokens:

Without KV cache:
  Step 1: K,V for 1 token  → O(d_model²)
  Step 2: K,V for 2 tokens → O(2·d_model²)
  ...
  Step S: K,V for S tokens → O(S·d_model²)
  Total: (1+2+...+S)·d_model² = O(S²·d_model²)   ← quadratic in S

With KV cache:
  Every step: K,V for 1 token → O(d_model²)
  Total: S·d_model² = O(S·d_model²)              ← linear in S

Speedup ratio: O(S²·d_model²) / O(S·d_model²) = O(S)
At S=10,000: projections are ~5,000× cheaper
(the speedup at step t is ~t, so it averages S/2 over the run).
```
🤔 Follow-up

"But we still run attention over the full cached K and V every step. Doesn't that stay O(S²)?"

✓ Yes — and this is the remaining bottleneck

Correct. At step $t$, the new token's $Q$ attends over all $t-1$ cached $K$ vectors: $(1, d_{head}) \times (d_{head}, t-1) = O(t \cdot d_{head})$ per head. Summed across all steps and heads: $O(S^2 \cdot d_{model})$ — still quadratic. KV caching eliminates the redundant projection cost. The attention over the growing context remains $O(S^2)$.

Part VI

Masking

16 The Three Types of Masking

Masking = set certain attention scores to $-\infty$ before softmax so they get zero weight. Three distinct situations need it:

| Mask | Problem It Solves | Where Applied | Implementation |
|---|---|---|---|
| Padding mask | Batch sequences are padded to equal length — PAD tokens shouldn't influence attention | Attention scores | attention_mask (1=real, 0=pad) |
| Loss mask | PAD tokens shouldn't contribute to the loss | Cross-entropy loss | ignore_index=0 |
| Causal mask | Autoregressive models can't see future tokens during training | Attention scores | Upper triangular $-\infty$, built into the model |
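A NumPy sketch of how the causal and padding masks combine before softmax. Toy sizes, and −1e9 stands in for −∞:

```python
import numpy as np

S = 5
NEG_INF = -1e9  # large negative stand-in for -inf

# Causal mask: upper triangle (j > i) blocked, so token i can't see the future.
causal = np.triu(np.full((S, S), NEG_INF), k=1)

# Padding mask: suppose the last 2 positions are PAD (attention_mask: 1=real, 0=pad).
attention_mask = np.array([1, 1, 1, 0, 0])
padding = np.where(attention_mask == 0, NEG_INF, 0.0)  # broadcast over key columns

scores = np.zeros((S, S))            # stand-in for Q K^T / sqrt(d_head)
masked = scores + causal + padding   # both masks added before softmax

weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
assert np.allclose(weights[0], [1, 0, 0, 0, 0])  # token 0 sees only itself
```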

17 The Padding Bug

```
"hello world"    → [5, 3, 0, 0, 0]   2 real + 3 PADs
"good morning"   → [2, 7, 1, 4, 0]   4 real + 1 PAD
"I love pandas"  → [6, 3, 9, 2, 1]   5 real + 0 PADs
```

Bug 1 — No attention mask: Real tokens attend to PAD tokens, pulling garbage into their representations.

```python
# Broken ❌ — real tokens attend to PAD tokens
output = model(x)

# Fixed ✅ — PAD positions get -inf before softmax → 0 weight
output = model(x, attention_mask=mask)  # mask: 1=real, 0=pad
```

Bug 2 — No loss mask: Model penalised for predicting PAD token positions — meaningless noise.

```python
# Broken ❌
loss = nn.CrossEntropyLoss()(logits, targets)

# Fixed ✅
loss = nn.CrossEntropyLoss(ignore_index=0)(logits, targets)
```
🤔 Question

"Is ignore_index=0 the position index or the token value?"

✓ Token value — not position

"Skip any position where the target tensor value equals 0." If your PAD token ID is 99, use ignore_index=99. Think of it as ignore_token_id.
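A minimal NumPy illustration of the token-value semantics (hypothetical 10-word vocabulary, PAD id 0):

```python
import numpy as np

PAD_ID = 0  # the token VALUE that ignore_index skips — not a position index
targets = np.array([5, 3, PAD_ID, PAD_ID])  # last two positions are padding

rng = np.random.default_rng(0)
logits = rng.standard_normal((4, 10))       # (positions, vocab size)

# Manual cross-entropy: log-softmax, then pick out each target's log-prob.
log_probs = logits - np.log(np.exp(logits).sum(-1, keepdims=True))
per_token_loss = -log_probs[np.arange(4), targets]

keep = targets != PAD_ID            # mask by token value, like ignore_index
loss = per_token_loss[keep].mean()  # averaged over the 2 real tokens only
print(keep)                         # [ True  True False False]
```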

Symptom of both bugs: model performs poorly on short sequences (more padding = more noise), well on long ones. Length-correlated degradation is the tell.

Part VII

Training

18 The Correct Training Loop

```python
import torch
import torch.nn as nn
from transformers import get_cosine_schedule_with_warmup

model = TransformerModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=1000, num_training_steps=100_000
)

for epoch in range(num_epochs):
    # ── Training ──────────────────────────────────────────────────
    model.train()                                       # (1) train mode
    for x, y, mask in train_loader:
        optimizer.zero_grad()                           # (2) clear gradients
        output = model(x, attention_mask=mask)          # (3) pass padding mask
        loss = nn.CrossEntropyLoss(ignore_index=0)(output, y)  # (4) mask loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # (5) clip gradients
        optimizer.step()
        scheduler.step()                                # (6) update LR

    # ── Validation ────────────────────────────────────────────────
    model.eval()                                        # (7) eval mode
    with torch.no_grad():
        for x, y, mask in val_loader:
            output = model(x, attention_mask=mask)
            val_loss = nn.CrossEntropyLoss(ignore_index=0)(output, y)
    model.train()                                       # (8) switch back!
```
| # | What | Effect if Missing | Severity |
|---|---|---|---|
| 2 | zero_grad() | Gradients accumulate — wrong update every step | 🔴 Immediate failure |
| 4 | ignore_index | PAD tokens add noise to loss — silent degradation | 🟠 Silent |
| 3 | attention_mask | PAD pollutes representations — silent degradation | 🟠 Silent |
| 8 | model.train() after eval | Dropout/BN in wrong mode — conditional on having called eval() | 🟡 Conditional |
| 5 | grad clipping | Gradient explosion on long sequences | 🟡 Rare for transformers |

19 Why Warmup LR for Transformers?

🤔 Question

"Why do transformers need LR warmup? CNNs work fine without it."

Reason 1 — Adam's early estimates are unreliable. Adam maintains $m_t$ (EMA of gradients) and $v_t$ (EMA of squared gradients). At step 1, $m_1 = 0.1 \times g_1$ — based on one sample, highly noisy. A large LR multiplies this noise into a large bad update.

📊 Typical Warmup Schedule

100k total steps, target LR $= 1\times10^{-4}$:

Steps 0–1000: LR increases linearly $0 \to 1\times10^{-4}$ (warmup — 1% of training).

Steps 1000–100k: cosine decay $1\times10^{-4} \to 1\times10^{-5}$.

The 1000-step warmup often makes the difference between stable and divergent early training.
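The schedule itself is a few lines. This is a sketch matching the numbers above, not the exact Hugging Face implementation:

```python
import math

def lr_at(step, warmup=1_000, total=100_000, peak=1e-4, floor=1e-5):
    """Linear warmup to `peak`, then cosine decay down to `floor`."""
    if step < warmup:
        return peak * step / warmup
    progress = (step - warmup) / (total - warmup)   # 0 → 1 over decay phase
    return floor + 0.5 * (peak - floor) * (1 + math.cos(math.pi * progress))

print(lr_at(0))        # 0.0     (start of warmup)
print(lr_at(500))      # ≈ 5e-5  (halfway through warmup)
print(lr_at(1_000))    # ≈ 1e-4  (peak)
print(lr_at(100_000))  # ≈ 1e-5  (floor at the end of training)
```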

Reason 2 — Transformers amplify instability. Multi-head attention + layer norm + residual connections across many layers means a bad early update propagates through all $L$ layers. CNNs are locally connected and more forgiving.

Part VIII

What Came After — The Timeline

20 Every Problem the Vanilla Transformer Had, and Who Fixed It

The 2017 "Attention is All You Need" transformer was a breakthrough — but it had real problems. Here's how each was addressed, in order. Each entry below will become its own deep-dive post.

2017

Vanilla Transformer

Introduced self-attention, multi-head attention, positional encoding. Encoder-decoder architecture for translation. Replaced RNNs entirely for sequence-to-sequence tasks.

Problems: O(S²) memory, fixed positional encoding doesn't generalise, no way to run long contexts.
2018

GPT-1 & BERT — The Pretraining Era Begins

GPT: decoder-only, causal masking, autoregressive pretraining. BERT: encoder-only, masked language modelling, bidirectional context. Both showed that large-scale pretraining + fine-tuning transfers to nearly any NLP task.

Problem: Still O(S²), fixed context windows (512 tokens for both), expensive to train.
2019

Transformer-XL — Breaking the Fixed Context Window

Problem: Transformer processes each context window independently. No memory across windows. Long-range dependencies get truncated.
Fix: Segment-level recurrence. Cache hidden states from previous segments; attend to them in the current segment. Also introduced relative positional encodings.

First real attempt at extending context without quadratic memory blowup.

2020

Longformer & BigBird — Sparse Attention

Problem: Full attention is O(S²). At S=4096, each head's score matrix already holds 16.8M entries (64 MB in float32), multiplied across every head and layer. Infeasible for long documents.
Fix: Replace full attention with sparse patterns. Longformer: each token attends to a local window + a few global tokens (e.g. [CLS]). BigBird: local + global + random attention. Reduces to O(S) or O(S·window_size).

Enabled processing of long documents (legal, scientific papers) that were previously impossible.

2021

RoPE — Rotary Positional Embeddings

Problem: Absolute positional encodings (sinusoidal or learned) don't generalise to sequence lengths longer than seen during training.
Fix: Encode position as a rotation in the Q and K vectors. Relative position between tokens is naturally captured in the dot product. Generalises to longer sequences than training length.

RoPE is now the standard in most modern LLMs — LLaMA and Mistral both use it.

2022

FlashAttention — Fixing the Memory Wall

Problem: The $(S, S)$ attention score matrix must be written to and read from GPU HBM (slow memory). At S=4096 this is 64MB per head per layer — a massive memory-bandwidth bottleneck even when compute is fine.
Fix: Tiled computation. Split Q, K, V into blocks; compute attention scores, softmax, and weighted sum block-by-block within fast SRAM (on-chip memory). Never materialise the full $(S, S)$ matrix in HBM. Recompute during backprop instead of storing.

Same O(S²) compute — but 2–4× faster in practice due to memory access patterns. Became the standard attention implementation. FlashAttention-2 (2023) and FlashAttention-3 (2024) pushed further.

2022

Multi-Query Attention (MQA) & Grouped-Query Attention (GQA)

Problem: KV cache grows as $2 \cdot S \cdot L \cdot d_{model}$ bytes — at 10GB+ for large models, inference is memory-bound. Can't serve many users simultaneously.
MQA: All $n_{heads}$ query heads share a single K and V head. KV cache shrinks by $n_{heads}\times$. GQA (compromise): group $n_{heads}$ into $G$ groups; each group shares one K,V. LLaMA 2/3 uses GQA with $G=8$.

GQA reduces KV cache by 8× vs multi-head attention with minimal quality loss. Now standard in production LLMs.

2023

Sliding Window Attention & Mixture of Experts — Mistral

Problem: Full attention over long contexts is expensive even with FlashAttention. Most tokens don't need to attend to tokens far away.
Sliding window attention: each token attends only to the $W$ most recent tokens (e.g. $W=4096$). Combined with GQA and RoPE. Mixtral adds sparse MoE layers — only 2 of 8 expert FFN layers active per token, cutting compute while keeping parameters high.
2024

Linear Attention & State Space Models — Mamba

Problem: Attention is fundamentally O(S²). Even with all the tricks, processing 1M-token contexts remains expensive. Can we get rid of the quadratic entirely?
Mamba: replaces attention with selective state space models (SSMs). Key innovation: input-dependent (selective) state transitions. Achieves O(S) compute and O(1) memory per step during inference. No attention matrix at all.

Competitive with transformers on language tasks at medium scale. Not yet clearly dominant — hybrid Mamba+attention models (Jamba) show promise. Active research area.

2025

DeepSeek-V3 & Multi-Head Latent Attention (MLA)

Problem: Even GQA has a large KV cache. MQA sacrifices quality. Can we get small KV cache without quality loss?
Multi-Head Latent Attention (MLA): compress K and V into a low-rank latent vector $c \in \mathbb{R}^{d_c}$ where $d_c \ll d_{model}$. Cache only $c$ per token instead of full K,V. Reconstruct K,V from $c$ at attention time via learned up-projection. KV cache shrinks by $(2 \cdot n_{heads} \cdot d_{head}) / d_c \approx 5$–$10\times$ vs MHA with no quality loss.

Combined with MoE and FP8 training, DeepSeek-V3 reported GPT-4-level performance at a fraction of the usual training cost, drawing significant industry attention.

The Through-Line

Every innovation above is attacking one of three constraints: (1) O(S²) compute (sparse attention, linear attention, Mamba), (2) O(S²) memory (FlashAttention, KV caching, MQA/GQA, MLA), or (3) fixed context/positional generalisation (Transformer-XL, RoPE). Each post in this series will go deep on one of these rows.

21 Interview Summary

🎯 The Full Narrative

"RNNs process sequentially and compress everything into a hidden state — causing vanishing gradients and preventing parallelisation.

Transformers replace this with attention: $\text{softmax}(QK^\top/\sqrt{d_{head}})V$. Every token directly queries every other. We scale by $\sqrt{d_{head}}$ because dot product variance grows as $d_{head}$, and large values saturate softmax — dividing restores unit variance.

Multi-head attention runs $n_{heads}$ independent attention operations in parallel, each specialising in different relationship types, using the same total parameters as single-head attention ($n_{heads} \times d_{head} = d_{model}$).

Attention is $O(S^2 \cdot d_{model})$ compute because $QK^\top$ is $(S, d_{head}) \times (d_{head}, S) = O(S^2 d_{head})$ per head, times $n_{heads}$ gives $O(S^2 d_{model})$. Memory is $O(S^2)$ because the score matrix is $(S, S)$ — $d_{head}$ was consumed by the dot product.

KV caching saves inference compute by storing previous tokens' $K$ and $V$ — they're deterministic functions of those tokens and never change. $Q$ is always the fresh new token's question, so it's never cached. Without caching, projection cost across a full generation is $O(S^2 \cdot d_{model}^2)$. With caching, $O(S \cdot d_{model}^2)$ — linear in $S$. Cache size is $2SLd_{model}$ numbers = $4SLd_{model}$ bytes in float16.

Three masking types: attention mask for padding, ignore_index for loss, causal mask for autoregressive models. Forgetting the loss mask causes silent length-correlated degradation. LR warmup stabilises early Adam estimates and prevents cascading instability in deep networks."

TL;DR Cheatsheet
| Concept | One Line | Key Number |
|---|---|---|
| Attention formula | $\text{softmax}(QK^\top/\sqrt{d_{head}})V$ | |
| Why $\sqrt{d_{head}}$ | Dot product std dev = $\sqrt{d_{head}}$; divide to restore unit variance | $d_{head}=64$ → divide by 8 |
| Multi-head | Same params ($n_{heads} \times d_{head} = d_{model}$), $n_{heads}$ specialised patterns | GPT-2: 12 heads × 64 = 768 |
| Matrix multiply cost | $(m,k)\times(k,p) = O(m\cdot k\cdot p)$ — multiply all three dims | Always |
| Attention compute | $O(S^2 \cdot d_{model})$ from $QK^\top$ being $(S,S)$ times $n_{heads}$ heads | Double $S$ → 4× FLOPs |
| Attention memory | $O(S^2)$ — score matrix is $(S,S)$, $d_{head}$ consumed by dot product | $S=32k$ → 48 GB/layer |
| KV cache what | Store K,V of prev tokens; Q is always fresh so never cached | |
| KV cache saves | Projection cost: $O(S^2 d_{model}^2) \to O(S \cdot d_{model}^2)$ across all steps | $S=10k$ → ~5,000× cheaper projections |
| KV cache size | $2SLd_{model}$ numbers, $4SLd_{model}$ bytes float16 | LLaMA 70B, $S=4k$ → 10 GB |
| Padding mask | Add $-\infty$ to PAD positions before softmax | attention_mask |
| Loss mask | Skip PAD token IDs in cross-entropy | ignore_index=0 |
| Causal mask | Upper triangular $-\infty$ prevents future token leakage | GPT-style only |

Papers referenced: Attention Is All You Need (2017) · FlashAttention (2022) · GQA (2023) · Mamba (2023) · DeepSeek-V3 (2024)