MHA vs MQA vs GQA: The Memory Math Behind Every Modern Model
Multi-Head Attention uses one K/V pair per head. Multi-Query uses one for all. Grouped Query splits the difference. The exact KV cache memory formula, a production comparison across LLaMA-3/Mistral/Falcon, and why GQA became the standard for every model released after mid-2023.
**Prerequisite: Step 2 (Attention from Scratch).** After this post you'll understand why LLMs split attention across multiple heads, what MQA and GQA trade away for speed, and why this architectural choice directly determines inference cost and KV cache size.
LLaMA-3, Mistral, Gemma, and most other modern open-weight models specify a number of key-value heads that is smaller than the number of query heads. This is Grouped Query Attention. Understanding why it exists requires understanding the memory arithmetic of the KV cache and exactly how Multi-Head Attention was changed to reduce it.
Multi-Head Attention: the baseline
In standard Multi-Head Attention (MHA), you have h heads, and each head has its own Q, K, and V projection matrices. During generation, you cache the K and V tensors for every layer and every head. With 32 heads and 32 layers, that is 32×32 = 1,024 key tensors and 1,024 value tensors per sequence, each of shape seq_len × d_head. Memory grows linearly with sequence length, number of layers, number of heads, and head dimension.
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, dtype_bytes=2):
    """Bytes of KV cache for one sequence; the leading 2 counts K and V."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * dtype_bytes

GIB = 2**30  # report in GiB so the totals match the conclusion printed below

# (name, n_layers, n_q_heads, n_kv_heads, d_head)
configs = [
    ("LLaMA-2-7B (MHA)", 32, 32, 32, 128),
    ("LLaMA-3-8B (GQA-8)", 32, 32, 8, 128),
    ("Mistral-7B (GQA-8)", 32, 32, 8, 128),
    ("LLaMA-3-70B (GQA-8)", 80, 64, 8, 128),
    ("Falcon-7B (MQA)", 32, 71, 1, 64),
]

print("KV cache memory per request at different context lengths (FP16):\n")
print(f"{'Model':30s} {'n_q':>5} {'n_kv':>5} {'4k ctx':>12} {'32k ctx':>12} {'128k ctx':>12}")
print("-" * 80)
for name, layers, n_q, n_kv, dh in configs:
    # n_q is shown for reference only; the cache size depends on n_kv.
    b4k = kv_cache_bytes(layers, n_kv, dh, 4_096) / GIB
    b32k = kv_cache_bytes(layers, n_kv, dh, 32_768) / GIB
    b128k = kv_cache_bytes(layers, n_kv, dh, 131_072) / GIB
    print(f"{name:30s} {n_q:>5} {n_kv:>5} {b4k:>9.2f}GiB {b32k:>9.2f}GiB {b128k:>9.2f}GiB")

print("\nConclusion: with 32 query heads, GQA-8 reduces KV cache by 4x vs MHA with minimal quality loss.")
print("At 128k context, MHA LLaMA-2-7B would use 64 GiB, most of an 80GB A100. GQA-8 uses 16 GiB.")
Multi-Query Attention: one K/V head for all Q heads
Multi-Query Attention (MQA, Shazeer 2019) uses a single K and V head shared across all query heads. For 32 query heads but 1 K/V head, the KV cache is 32× smaller. Quality degrades somewhat — each head is querying the same key and value representations, which limits the diversity of things different heads can attend to. MQA was adopted by Falcon and early PaLM variants where inference cost dominated quality concerns.
Grouped Query Attention: the middle ground that won
GQA (Ainslie et al., 2023) splits the query heads into G groups, and each group shares one K/V head. With 32 query heads and G=8 groups, you have 8 K/V heads in total and 4 query heads per K/V head, so the KV cache is 4× smaller than with MHA. Ainslie et al. report quality close to MHA at decoding speed comparable to MQA, which is why most major open-weight models released after mid-2023 use GQA, including LLaMA-3, Mistral, Qwen, and Gemma.
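The grouping can be made concrete with a small sketch. The function name `query_groups` is hypothetical, and the contiguous layout (query heads 0-3 share K/V head 0, and so on) is an assumption for illustration; a given implementation may order heads differently.

```python
def query_groups(n_q_heads, n_kv_heads):
    """Map each K/V head index to the query head indices that share it.

    Assumes contiguous grouping, which is an illustrative choice.
    """
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be a multiple of n_kv_heads")
    group_size = n_q_heads // n_kv_heads
    return {
        g: list(range(g * group_size, (g + 1) * group_size))
        for g in range(n_kv_heads)
    }

gqa = query_groups(32, 8)   # GQA-8: 8 K/V heads, 4 query heads each
mha = query_groups(32, 32)  # MHA: every query head has its own K/V head
mqa = query_groups(32, 1)   # MQA: one K/V head serves all query heads

for g, heads in gqa.items():
    print(f"K/V head {g}: query heads {heads}")
print(f"KV cache reduction vs MHA: GQA-8 {len(mha) // len(gqa)}x, MQA {len(mha) // len(mqa)}x")
```

Seen this way, MHA and MQA are the two extremes of the same scheme, with G equal to the number of query heads and G equal to 1 respectively.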
The attention computation with GQA
For each group of 4 query heads sharing 1 K/V head, the 4 Q projections are different (each head learns to query differently), but all 4 score against the same K and take their weighted sum over the same V. Different query heads can therefore attend to different positions, because their Q projections differ, but they read the same information from those positions, because V is shared. The diversity is in what you attend to, not in what you read. This is the quality trade-off GQA makes.
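A minimal pure-Python sketch of one decode step makes this visible. The name `gqa_decode_step`, the list-of-lists tensor layout, and the contiguous head-to-group mapping are all assumptions chosen for readability; a real implementation would use batched tensor operations.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def gqa_decode_step(q, k_cache, v_cache):
    """One decode step of grouped-query attention (illustrative sketch).

    q:       [n_q_heads][d_head]            queries for the new token
    k_cache: [n_kv_heads][seq_len][d_head]  cached keys
    v_cache: [n_kv_heads][seq_len][d_head]  cached values
    Returns (outputs, weights), both indexed by query head.
    """
    n_q_heads, n_kv_heads = len(q), len(k_cache)
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be a multiple of n_kv_heads")
    group_size = n_q_heads // n_kv_heads
    d_head = len(q[0])
    scale = 1.0 / math.sqrt(d_head)

    outputs, weights = [], []
    for h in range(n_q_heads):
        g = h // group_size  # assumed layout: contiguous query heads share K/V head g
        keys, values = k_cache[g], v_cache[g]
        # Scaled dot product of this head's query with every cached key.
        scores = [scale * sum(a * b for a, b in zip(q[h], key)) for key in keys]
        w = softmax(scores)
        # Weighted sum over the shared values.
        out = [sum(w_t * value[d] for w_t, value in zip(w, values)) for d in range(d_head)]
        outputs.append(out)
        weights.append(w)
    return outputs, weights

# Two query heads sharing one K/V head, three cached positions, d_head = 2.
q = [[4.0, 0.0], [0.0, 4.0]]
k_cache = [[[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]]
v_cache = [[[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]]

outputs, weights = gqa_decode_step(q, k_cache, v_cache)
for h, w in enumerate(weights):
    print(f"query head {h}: attention weights = {[round(x, 3) for x in w]}")
```

In this toy setup the two query heads put most of their weight on different positions, yet both outputs are mixtures of the same three value vectors.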
Calculate: for your production model, what is the KV cache footprint at your target context length? Use the formula: 2 × n_layers × n_kv_heads × d_head × seq_len × dtype_bytes, where the leading 2 counts K and V and dtype_bytes is 2 for FP16. For LLaMA-3-8B at 8k context: 2×32×8×128×8192×2 = 1,073,741,824 bytes, about 1.07GB (1 GiB) per request. With an 80GB A100 and ~16GB of model weights: (80-16)GB / 1.07GB ≈ 59 concurrent requests at this context length, before accounting for activations and allocator overhead. This calculation determines your serving infrastructure.
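The same capacity estimate can be scripted so it is easy to rerun for other models and context lengths. This is a rough sketch: `max_concurrent_requests` is a hypothetical helper, and it assumes every request fills its full context and that nothing besides weights and KV cache occupies GPU memory, so treat the result as an upper bound.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, dtype_bytes=2):
    """Bytes of KV cache for one sequence; the leading 2 counts K and V."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * dtype_bytes

def max_concurrent_requests(gpu_bytes, weight_bytes, per_request_bytes):
    """Upper bound on requests whose KV caches fit beside the weights.

    Ignores activations, fragmentation, and framework overhead (assumption).
    """
    free = gpu_bytes - weight_bytes
    if free <= 0:
        return 0
    return free // per_request_bytes

GB = 10**9

# LLaMA-3-8B figures from the text: 32 layers, 8 K/V heads, d_head 128, FP16.
per_request = kv_cache_bytes(n_layers=32, n_kv_heads=8, d_head=128, seq_len=8192)
capacity = max_concurrent_requests(80 * GB, 16 * GB, per_request)

print(f"KV cache per request: {per_request / GB:.2f} GB")
print(f"Max concurrent requests on an 80GB GPU: {capacity}")
```

Swapping in your own layer count, K/V head count, head dimension, and target context length gives a first-order sizing number for a deployment.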
- Ainslie et al. (2023), GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
- Shazeer (2019), Fast Transformer Decoding: One Write-Head is All You Need (MQA)