Self-Attention: From Dot Products to What the Model Focuses On
Query, Key, Value matrices explained. Why multi-head attention sees different things at once, and what that means for long-range dependencies.
Self-attention is the operation at the heart of the transformer. The other components (positional encodings, residual connections, layer norms) are scaffolding around it. If you understand self-attention deeply, you understand most of what makes a transformer distinctive.
[Video: Andrej Karpathy — Let's build GPT from scratch (builds a full GPT in Python, attention included)]
The core question attention answers
For each token in a sequence, self-attention asks: which other tokens should I borrow information from, and how much? The answer is computed dynamically from the token representations rather than fixed in advance for each position. This is what makes attention so powerful: the same token can attend to completely different things depending on context.
In "The animal didn't cross the street because it was tired", what does "it" refer to? In trained models, attention heads have been observed in which "it" attends strongly to "animal" rather than "street". This behaviour resembles coreference resolution, and it emerges from training rather than being programmed in.
Query, Key, Value: the retrieval metaphor
The QKV formulation is a learned soft retrieval system. Think of it like a search engine: your Query is what you're searching for, the Keys are what each document is indexed under, and the Values are the actual document content returned.
- Query (Q): the current token asking "what do I need?"
- Key (K): every token broadcasting "here is what I contain"
- Value (V): every token's actual information, returned if selected
- Attention weight: how well a query matches a key, computed as a scaled dot product and then normalised with softmax
```python
import torch
import torch.nn.functional as F

def self_attention(X, W_Q, W_K, W_V, mask=None):
    """
    Single-head scaled dot-product self-attention.

    X:        (seq_len, d_model)
    W_Q, W_K: (d_model, d_k)
    W_V:      (d_model, d_v)
    mask:     optional (seq_len, seq_len); entries equal to 0 are blocked
    """
    Q = X @ W_Q  # (seq_len, d_k)
    K = X @ W_K  # (seq_len, d_k)
    V = X @ W_V  # (seq_len, d_v)
    d_k = Q.shape[-1]
    scores = Q @ K.T / d_k**0.5  # (seq_len, seq_len)
    if mask is not None:
        # A large negative score becomes a near-zero weight after softmax
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = F.softmax(scores, dim=-1)  # rows sum to 1
    return weights @ V  # (seq_len, d_v)
```
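To make the retrieval metaphor concrete, here is a minimal sketch with hand-picked numbers. The keys, values, and query are hypothetical toy values chosen for illustration, not learned projections. The query points along the first key's direction, so the output should be pulled mostly towards the first value.

```python
import torch
import torch.nn.functional as F

# Hypothetical toy setup: three "documents", each with a 2-d key and a 2-d value.
K = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [0.5, 0.5]])
V = torch.tensor([[10.0, 0.0],
                  [0.0, 10.0],
                  [5.0, 5.0]])

# One query that points along the first key's direction.
q = torch.tensor([[4.0, 0.0]])

d_k = K.shape[-1]
scores = q @ K.T / d_k**0.5          # (1, 3): similarity of the query to each key
weights = F.softmax(scores, dim=-1)  # (1, 3): soft selection, sums to 1
output = weights @ V                 # (1, 2): weighted blend of the values

print(weights)
print(output)
```

Unlike a hard lookup, every value contributes a little. The best-matching key gets the largest share rather than all of it.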
Why scale by √d_k?
Without scaling, dot products grow with dimensionality. If the components of a query and a key are independent with mean 0 and variance 1, their dot product has expected value 0 and variance d_k, so for d_k = 64 the standard deviation is √64 = 8. Large scores push softmax into saturation: one weight approaches 1, gradients shrink towards zero, and learning slows. Dividing by √d_k brings the variance of the scores back to 1 regardless of dimension.
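The effect can be checked numerically. The sketch below assumes the components are i.i.d. standard normal, which is an idealisation of what a real model produces. The sample size and seed are arbitrary choices.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 64
n = 10_000

# Assumption: components are i.i.d. standard normal (mean 0, variance 1).
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)

dots = (q * k).sum(dim=-1)        # n independent dot products
scaled = dots / d_k**0.5

raw_std = dots.std().item()       # should be close to sqrt(64) = 8
scaled_std = scaled.std().item()  # should be close to 1

# Effect on softmax: one query against 8 keys, with and without scaling.
query = torch.randn(1, d_k)
keys = torch.randn(8, d_k)
raw_scores = query @ keys.T
p_raw = F.softmax(raw_scores, dim=-1)
p_scaled = F.softmax(raw_scores / d_k**0.5, dim=-1)

print(f"std unscaled: {raw_std:.2f}, std scaled: {scaled_std:.2f}")
print("max weight unscaled:", p_raw.max().item())
print("max weight scaled:  ", p_scaled.max().item())
```

The unscaled softmax concentrates its weight on fewer keys, which is the saturation the scaling is meant to avoid.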
The attention matrix: what it reveals
The attention weight matrix is (seq_len × seq_len). Each row is a probability distribution over all positions: how much each output token draws from each input token. In trained models, these patterns are often interpretable:
- Syntactic heads often form diagonal patterns (attending to the previous token) or induction patterns
- Coreference heads create off-diagonal spikes connecting pronouns to their antecedents
- Positional heads form banded patterns, attending to fixed relative distances
- Many heads in GPT-2 place a large share of their attention on the first token, which acts as an "attention sink": it absorbs attention from most positions but contributes little information
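Patterns like these can be measured directly from an attention matrix. The sketch below uses synthetic matrices rather than weights from a trained model. `previous_token_score` is a hypothetical diagnostic written for this illustration, not a library function.

```python
import torch

def previous_token_score(weights: torch.Tensor) -> float:
    """Mean attention that each token (from position 1 on) puts on the token before it.

    weights: (seq_len, seq_len) attention matrix whose rows sum to 1.
    Hypothetical diagnostic for illustration only.
    """
    # The sub-diagonal holds weights[i, i-1] for i = 1 .. seq_len-1.
    return torch.diagonal(weights, offset=-1).mean().item()

seq_len = 5

# Synthetic "previous-token head": row i puts almost all of its weight on i-1.
prev_head = torch.full((seq_len, seq_len), 0.01)
for i in range(1, seq_len):
    prev_head[i, i - 1] = 1.0
prev_head = prev_head / prev_head.sum(dim=-1, keepdim=True)  # renormalise rows

# Synthetic "uniform head": every row spreads its weight evenly.
uniform_head = torch.full((seq_len, seq_len), 1.0 / seq_len)

print(previous_token_score(prev_head))     # high: strong sub-diagonal pattern
print(previous_token_score(uniform_head))  # low: no positional preference
```

The same idea extends to other patterns, for example averaging the first column to estimate how much weight a head sends to the first token.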
Multi-head: why one isn't enough
A single attention head produces one attention distribution per token, so it can capture only one pattern of relationships at a time. Multi-head attention runs H attention functions in parallel, each with its own learned W_Q, W_K, W_V projections (typically of width d_model / H). Different heads can specialise in different syntactic, semantic, or positional relationships. The outputs are concatenated and projected back to d_model.
Mechanistic interpretability research has identified individual attention heads that implement specific algorithms. Anthropic's work on transformer circuits described induction heads, which complete patterns seen earlier in the context. The indirect object identification study of GPT-2 small (Wang et al., 2022) found name-mover heads that copy names and backup name-mover heads that become active when the primary name movers are ablated.
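A minimal sketch of the multi-head computation follows. It assumes the common layout in which each projection is a single (d_model, d_model) matrix split into H slices of width d_model / H. The function name, the random weights, and the sizes are illustrative choices, not taken from any particular model.

```python
import torch
import torch.nn.functional as F

def multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads):
    """
    X:             (seq_len, d_model)
    W_Q, W_K, W_V: (d_model, d_model), split into num_heads slices of width d_head
    W_O:           (d_model, d_model) output projection
    """
    seq_len, d_model = X.shape
    d_head = d_model // num_heads  # assumes d_model is divisible by num_heads

    def split_heads(T):
        # (seq_len, d_model) -> (num_heads, seq_len, d_head)
        return T.view(seq_len, num_heads, d_head).transpose(0, 1)

    Q, K, V = split_heads(X @ W_Q), split_heads(X @ W_K), split_heads(X @ W_V)

    scores = Q @ K.transpose(-2, -1) / d_head**0.5  # (num_heads, seq_len, seq_len)
    weights = F.softmax(scores, dim=-1)             # one attention matrix per head
    heads = weights @ V                             # (num_heads, seq_len, d_head)

    # Concatenate the heads along the feature axis, then project back to d_model.
    concat = heads.transpose(0, 1).reshape(seq_len, d_model)
    return concat @ W_O, weights

torch.manual_seed(0)
seq_len, d_model, num_heads = 6, 16, 4
X = torch.randn(seq_len, d_model)
W_Q, W_K, W_V, W_O = (torch.randn(d_model, d_model) / d_model**0.5 for _ in range(4))

out, weights = multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads)
print(out.shape)      # torch.Size([6, 16])
print(weights.shape)  # torch.Size([4, 6, 6])
```

Each of the four heads gets its own (6 × 6) attention matrix, so the heads are free to attend to different positions for the same token.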
Causal (masked) self-attention
In decoder models (GPT-style), each token can attend only to itself and earlier tokens, never to future ones. This is enforced by a causal mask that sets the scores of future positions to −∞ before softmax, so their weights become exactly zero. Without this, the model would "cheat" during training by reading ahead.
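A causal mask can be sketched as a lower-triangular boolean matrix. The example below uses random queries and keys with arbitrary small sizes, purely to show the shape of the resulting weights.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_k = 5, 8
Q = torch.randn(seq_len, d_k)
K = torch.randn(seq_len, d_k)

# Lower-triangular mask: entry (i, j) is True when token i may attend to token j (j <= i).
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

scores = Q @ K.T / d_k**0.5
scores = scores.masked_fill(~causal_mask, float("-inf"))  # block future positions
weights = F.softmax(scores, dim=-1)

print(weights)
```

Everything above the diagonal is zero, and the first token can attend only to itself, so its single weight is 1.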
Flash Attention — how modern systems make it fast
Standard attention is O(n²) in memory: for a 100K-token context, the attention matrix alone is 100K × 100K = 10 billion elements per head, about 20 GB at 16-bit precision. Flash Attention (Dao et al., 2022) reorders the computation to avoid materialising this matrix in GPU HBM, using tiling to compute attention block by block in on-chip SRAM. The result is exact attention with memory that grows linearly in sequence length and a reported speedup of roughly 2–4× over standard implementations; the arithmetic cost is still O(n²). Flash Attention 2 and 3 push this further, and most major training and inference frameworks now support it.
You rarely need to implement Flash Attention yourself: PyTorch's torch.nn.functional.scaled_dot_product_attention() dispatches to a Flash Attention kernel automatically when the hardware and inputs support it. But understanding why it exists helps explain why long-context models became practical after 2022: it removed the quadratic memory cost that made attending over very long sequences impractical.
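The key idea behind the tiling is an online softmax: keys and values are processed one block at a time while a running maximum and a running denominator are kept per query row. The sketch below illustrates only that arithmetic in plain PyTorch. It is not the actual Flash Attention kernel, which also tiles the queries and runs as a fused GPU kernel. The function names and block size are illustrative.

```python
import torch
import torch.nn.functional as F

def naive_attention(Q, K, V):
    # Materialises the full (n, n) score matrix.
    scores = Q @ K.T / Q.shape[-1]**0.5
    return F.softmax(scores, dim=-1) @ V

def tiled_attention(Q, K, V, block_size=4):
    """Process K and V block by block, keeping only running statistics per query row."""
    n, d_k = Q.shape
    out = torch.zeros(n, V.shape[-1])            # running unnormalised output
    row_max = torch.full((n, 1), float("-inf"))  # running max score per row
    row_sum = torch.zeros(n, 1)                  # running softmax denominator per row

    for start in range(0, K.shape[0], block_size):
        K_blk = K[start:start + block_size]
        V_blk = V[start:start + block_size]
        scores = Q @ K_blk.T / d_k**0.5          # (n, block): one tile of scores only

        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)  # rescales earlier partial results
        p = torch.exp(scores - new_max)

        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ V_blk
        row_max = new_max

    return out / row_sum

torch.manual_seed(0)
n, d_k, d_v = 10, 8, 8
Q, K, V = torch.randn(n, d_k), torch.randn(n, d_k), torch.randn(n, d_v)

print(torch.allclose(tiled_attention(Q, K, V), naive_attention(Q, K, V), atol=1e-5))
```

Because the rescaling is exact, the tiled result matches the naive one up to floating-point error. This is why Flash Attention is described as exact rather than approximate attention.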
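As a minimal usage sketch, the built-in function can be compared with an explicit computation. This assumes PyTorch 2.0 or later. Which kernel actually runs depends on device, dtype, and input shapes, so on a CPU this exercises the API rather than a Flash Attention kernel. The tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, num_heads, seq_len, d_head = 1, 2, 6, 8
Q = torch.randn(batch, num_heads, seq_len, d_head)
K = torch.randn(batch, num_heads, seq_len, d_head)
V = torch.randn(batch, num_heads, seq_len, d_head)

# Built-in fused attention; is_causal=True applies the causal mask internally.
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

# Reference computation that materialises the full attention matrix.
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
scores = Q @ K.transpose(-2, -1) / d_head**0.5
scores = scores.masked_fill(~mask, float("-inf"))
reference = F.softmax(scores, dim=-1) @ V

print(torch.allclose(fused, reference, atol=1e-4))
```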
- Attention Is All You Need (Vaswani et al., 2017)
- The Illustrated GPT-2 — Jay Alammar
- Flash Attention: Fast and Memory-Efficient Exact Attention (Dao et al., 2022)