The Transformer FFN Block: What It Does, Why 4x, and Why SwiGLU Won
The FFN contains most of a transformer's parameters and acts as a key-value memory over learned concepts. ReLU vs SwiGLU vs GeGLU: the gating mechanism, the dead neuron problem, the parameter count math, and why most modern LLMs switched to SwiGLU or a close variant.
The transformer block has two sublayers: multi-head attention and a position-wise feed-forward network. Attention gets all the academic attention (pun intended). The FFN is treated as a black box. But in large models the FFN holds the majority of the parameters: each layer's FFN has about 8 × d_model² weights against about 4 × d_model² for attention, and in GPT-3 175B roughly 65% of parameters are in FFN layers. The FFN also does something distinct from attention. This post is about what it actually does.
The architecture
The FFN in the original transformer: FFN(x) = max(0, x·W1 + b1)·W2 + b2. Two linear layers with a ReLU in between. W1 has shape (d_model, d_ff) and W2 has shape (d_ff, d_model). The hidden dimension d_ff is typically 4 × d_model. For a 4096-dimensional model, d_ff = 16384. The FFN expands the representation by 4×, applies a non-linearity, then projects back. It is applied to every position independently, with the same weights.
import torch
import torch.nn as nn
import torch.nn.functional as F


class FFN_ReLU(nn.Module):
    """Original transformer FFN (Vaswani et al., 2017)."""

    def __init__(self, d_model, d_ff=None):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.W1 = nn.Linear(d_model, d_ff)
        self.W2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.W2(F.relu(self.W1(x)))


class FFN_SwiGLU(nn.Module):
    """
    SwiGLU FFN (Shazeer 2020), used in LLaMA, Mistral, PaLM (Gemma uses the GeGLU variant).
    Gated linear unit: output = (swish(xW1) ⊙ xW3) W2
    Uses d_ff = 2/3 × 4 × d_model to keep parameter count similar.
    """

    def __init__(self, d_model, d_ff=None):
        super().__init__()
        d_ff = d_ff or int(2 / 3 * 4 * d_model)
        self.W1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection
        self.W2 = nn.Linear(d_ff, d_model, bias=False)  # down projection
        self.W3 = nn.Linear(d_model, d_ff, bias=False)  # up projection

    def forward(self, x):
        # Gate controls which "slots" of the expanded space pass through
        gate = F.silu(self.W1(x))  # silu(z) = z * sigmoid(z) = swish
        up = self.W3(x)
        return self.W2(gate * up)  # element-wise gating


d_model = 512
ffn_relu = FFN_ReLU(d_model)
ffn_swiglu = FFN_SwiGLU(d_model)
params_relu = sum(p.numel() for p in ffn_relu.parameters())
params_swiglu = sum(p.numel() for p in ffn_swiglu.parameters())
print(f"ReLU FFN parameters: {params_relu:,} (d_ff = {4*d_model})")
print(f"SwiGLU FFN parameters: {params_swiglu:,} (d_ff = {int(2/3*4*d_model)}, 3 matrices)")

# Forward pass
x = torch.randn(4, 16, d_model)  # batch=4, seq=16, d_model=512
print(f"\nInput shape: {x.shape}")
print(f"ReLU FFN output: {ffn_relu(x).shape}")
print(f"SwiGLU FFN output: {ffn_swiglu(x).shape}")

# Zero-activation analysis for ReLU
with torch.no_grad():
    activated = F.relu(ffn_relu.W1(x))  # (4, 16, d_ff)
    # Fraction of individual activations that are zero (about half at random init)
    zero_pct = (activated == 0).float().mean().item()
    # A neuron only counts as dead if it is zero for every token in the batch
    dead_pct = (activated == 0).flatten(0, 1).all(dim=0).float().mean().item()
print(f"\nReLU: {zero_pct:.1%} of activations are zero on this batch")
print(f"ReLU: {dead_pct:.1%} of neurons are dead (zero for every token) on this batch")
What the FFN actually computes
The attention sublayer handles where: which positions to pull information from. The FFN sublayer handles what: given the aggregated information from attention, what should this representation become? The 4× expansion is the opposite of a bottleneck: the input is projected into a higher-dimensional space where the non-linearity can carve out complex decision boundaries, then projected back down. Think of it as: first write many competing hypotheses about what this token might mean given its context (expand), pick the best ones (the non-linearity kills weak hypotheses), then synthesise (compress).
Research suggests that FFN layers function as key-value memories (Geva et al., 2021). With W1 of shape (d_model, d_ff), each column of W1 is a 'key' that activates when the input matches a certain pattern. The corresponding row of W2 is the 'value', which is added to the output in proportion to how strongly its key fired. In this view the FFN implements a soft lookup table over learned concepts.
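The key-value reading can be checked numerically. The sketch below is not from Geva et al.; it assumes the x·W1 convention used in this post (W1 of shape (d_model, d_ff), so key i is column i of W1 and value i is row i of W2), omits biases, and uses small made-up dimensions. It shows that the usual matrix form and an explicit "sum of weighted values" loop give the same output.

```python
import torch

torch.manual_seed(0)
d_model, d_ff = 8, 32  # toy sizes chosen for illustration only

# x·W1 convention: W1 is (d_model, d_ff), W2 is (d_ff, d_model). Biases omitted.
W1 = torch.randn(d_model, d_ff)
W2 = torch.randn(d_ff, d_model)
x = torch.randn(d_model)

# Matrix form: FFN(x) = relu(x·W1)·W2
out_matrix = torch.relu(x @ W1) @ W2

# Memory form: each hidden unit i contributes its value, weighted by its key match.
out_memory = torch.zeros(d_model)
for i in range(d_ff):
    key, value = W1[:, i], W2[i]
    coeff = torch.relu(x @ key)  # how strongly the input matches key i
    out_memory += coeff * value  # add value i, scaled by that match

matches = torch.allclose(out_matrix, out_memory, atol=1e-4)
print("matrix form equals memory form:", matches)
```

Note that PyTorch's `nn.Linear` stores its weight transposed, as (out_features, in_features), so in the `FFN_ReLU` class the same keys are the rows of `W1.weight`.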
SwiGLU: why almost everyone switched
SwiGLU replaces ReLU with a gating mechanism: FFN(x) = (swish(x·W1) ⊙ (x·W3))·W2, where swish(z) = z·sigmoid(z) (also called SiLU) is smooth and largely avoids the dead neuron problem, because its gradient is not stuck at zero for negative inputs. The gate branch, swish(x·W1), learns which dimensions of the expanded representation x·W3 to pass through versus suppress. This learnable, input-dependent gating is more expressive than a fixed threshold. LLaMA, Mistral, Qwen, Gemma, PaLM all use SwiGLU or a variant (GeGLU uses GELU instead of SiLU).
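To make the dead-neuron point concrete, here is a small autograd check on a handful of hand-picked pre-activation values (the values are arbitrary, chosen only for illustration). It compares the gradient that ReLU and SiLU pass back at negative inputs.

```python
import torch
import torch.nn.functional as F

# Hand-picked pre-activations: three negative, two positive.
z = torch.tensor([-3.0, -1.0, -0.1, 0.5, 2.0], requires_grad=True)

# d(activation)/dz for each element, via autograd.
relu_grad, = torch.autograd.grad(F.relu(z).sum(), z)
silu_grad, = torch.autograd.grad(F.silu(z).sum(), z)

print("pre-activation:", z.tolist())
print("ReLU gradient: ", relu_grad.tolist())
print("SiLU gradient: ", silu_grad.tolist())
```

ReLU passes exactly zero gradient for every negative input, so a unit whose pre-activation is negative on all training data stops updating. SiLU's gradient at these negative values is small but nonzero, which gives such a unit a path back.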
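The parameter count stays similar despite adding a third matrix: the ReLU FFN has 2 × d_model × d_ff weights and SwiGLU has 3 × d_model × d_ff, so SwiGLU uses d_ff = 2/3 × 4 × d_model (rounded to a multiple of 64 for GPU efficiency) and the third matrix (W3) is offset by the smaller d_ff. In Shazeer's experiments the GLU variants reached better perplexity than ReLU or GELU FFNs at matched parameter and compute budgets, and SwiGLU is now the usual default for new models.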
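The arithmetic behind that parameter-count claim can be worked through in a few lines. This is a sketch with helper names invented for this post; it counts weight matrices only (no biases) and rounds d_ff up to a multiple, which is one common convention. The multiple is a parameter because different codebases pick different values.

```python
def relu_ffn_params(d_model: int, expansion: int = 4) -> int:
    """Weights in a two-matrix FFN: W1 (d_model x d_ff) + W2 (d_ff x d_model)."""
    d_ff = expansion * d_model
    return 2 * d_model * d_ff


def swiglu_d_ff(d_model: int, multiple_of: int = 64) -> int:
    """2/3 of the 4x hidden size, rounded up to a multiple of `multiple_of`."""
    d_ff = (2 * 4 * d_model) // 3
    return multiple_of * ((d_ff + multiple_of - 1) // multiple_of)


def swiglu_ffn_params(d_model: int, multiple_of: int = 64) -> int:
    """Weights in a three-matrix gated FFN: W1, W3 (d_model x d_ff) + W2."""
    return 3 * d_model * swiglu_d_ff(d_model, multiple_of)


d_model = 4096
base = relu_ffn_params(d_model)
print(f"ReLU FFN:   d_ff = {4 * d_model}, weights = {base:,}")
for multiple_of in (64, 256):
    d_ff = swiglu_d_ff(d_model, multiple_of)
    params = swiglu_ffn_params(d_model, multiple_of)
    print(f"SwiGLU FFN: d_ff = {d_ff} (multiple of {multiple_of}), "
          f"weights = {params:,}, ratio = {params / base:.3f}")
```

Without rounding, both layouts come to 8 × d_model² weights; the rounding step is what pushes the SwiGLU count slightly above the ReLU count.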
Visualise FFN activations: run a real LLaMA model with access to intermediate layers (via forward hooks in PyTorch). Collect the FFN hidden activations across a batch of sentences. Measure what fraction are near zero: these are the zeroed units after a ReLU, or the units strongly suppressed by the SwiGLU gate. Neurons that are consistently zero across all inputs contribute nothing to the output. This is a useful sanity check for fine-tuned models that may have overfit.
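The hook mechanics for that exercise can be sketched without downloading a model. The example below uses a small hypothetical `ToyFFN` with random inputs as a stand-in; with a real LLaMA checkpoint you would attach the same kind of hook to the model's own FFN submodules (their names depend on the implementation and need to be looked up) and feed tokenised sentences instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ToyFFN(nn.Module):
    """Hypothetical stand-in for one FFN block of a real model."""

    def __init__(self, d_model=64, d_ff=256):
        super().__init__()
        self.up = nn.Linear(d_model, d_ff)
        self.act = nn.ReLU()
        self.down = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.down(self.act(self.up(x)))


ffn = ToyFFN()
captured = []


def save_hidden(module, inputs, output):
    # Flatten (batch, seq, d_ff) to (tokens, d_ff) and keep a detached copy.
    captured.append(output.detach().reshape(-1, output.shape[-1]))


# Hook the activation module so we see the FFN's hidden activations.
handle = ffn.act.register_forward_hook(save_hidden)
with torch.no_grad():
    for _ in range(8):  # 8 batches of random "sentences"
        ffn(torch.randn(4, 16, 64))
handle.remove()  # always remove hooks once the data is collected

hidden = torch.cat(captured)  # (tokens, d_ff)
eps = 1e-6  # threshold for "near zero"; an arbitrary choice
near_zero = hidden.abs() < eps
frac_near_zero = near_zero.float().mean().item()
n_always_zero = int(near_zero.all(dim=0).sum())  # zero on every token seen

print(f"collected activations: {tuple(hidden.shape)}")
print(f"near-zero fraction:    {frac_near_zero:.1%}")
print(f"always-zero neurons:   {n_always_zero} of {hidden.shape[1]}")
```

For a gated FFN there are no exact zeros, so the threshold matters more; a looser `eps`, or a histogram of activation magnitudes, is likely to be more informative than a single fraction.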
- GLU Variants Improve Transformer — Noam Shazeer (2020)
- Transformer Feed-Forward Layers Are Key-Value Memories — Geva et al. (2021)