GenAI Systems Lab
Foundations & Architecture

The Transformer FFN Block: What It Does, Why 4x, and Why SwiGLU Won

The FFN contains most of a transformer's parameters and can be read as a key-value memory over learned concepts. ReLU vs SwiGLU vs GeGLU: the gating mechanism, the dead neuron problem, the parameter count math, and why most modern LLMs switched to SwiGLU.

The transformer block has two sublayers: multi-head attention and a position-wise feed-forward network (FFN). Attention gets most of the academic attention (pun intended), while the FFN is usually treated as a black box. Yet the FFN holds the majority of the parameters in large models: in GPT-3 175B, roughly two-thirds of the parameters sit in FFN layers. It plays a role distinct from attention, and this post is about what that role is.
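The two-thirds figure follows from simple arithmetic. The sketch below counts parameters for a GPT-3-175B-shaped decoder, assuming the configuration published in the GPT-3 paper (d_model = 12288, 96 layers, d_ff = 4 × d_model, a 50257-token vocabulary, 2048 learned positions) and ignoring LayerNorm weights. Per layer, attention holds about 4·d_model² weights and the FFN about 8·d_model², which is where the FFN's share comes from.

```python
# Back-of-the-envelope parameter split for a GPT-3-175B-shaped decoder.
# Assumed config: d_model = 12288, 96 layers, d_ff = 4 * d_model,
# vocab = 50257, context = 2048. LayerNorm parameters are ignored.
d_model, n_layers, vocab, n_ctx = 12288, 96, 50257, 2048
d_ff = 4 * d_model

# Attention: Q, K, V and output projections, each d_model x d_model plus bias
attn_per_layer = 4 * (d_model * d_model + d_model)
# FFN: W1 (d_model x d_ff) + b1, then W2 (d_ff x d_model) + b2
ffn_per_layer = d_model * d_ff + d_ff + d_ff * d_model + d_model
# Token embeddings + learned position embeddings (output head assumed tied)
embeddings = vocab * d_model + n_ctx * d_model

total = n_layers * (attn_per_layer + ffn_per_layer) + embeddings
ffn_share = n_layers * ffn_per_layer / total
print(f"total ≈ {total / 1e9:.1f}B parameters, FFN share ≈ {ffn_share:.1%}")
```

The embedding tables are under 1% of the total at this scale, so the split is essentially 8 : 4 between FFN and attention.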

The architecture

The FFN in the original transformer: FFN(x) = max(0, x·W1 + b1)·W2 + b2. Two linear layers with a ReLU in between. W1 has shape (d_model, d_ff) and W2 has shape (d_ff, d_model). The hidden dimension d_ff is typically 4 × d_model; for a 4096-dimensional model, d_ff = 16384. The FFN expands the representation by 4×, applies a non-linearity, then projects back down to d_model.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FFN_ReLU(nn.Module):
    """Original transformer FFN (Vaswani et al., 2017)."""
    def __init__(self, d_model, d_ff=None):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.W1 = nn.Linear(d_model, d_ff)
        self.W2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.W2(F.relu(self.W1(x)))


class FFN_SwiGLU(nn.Module):
    """
    SwiGLU FFN (Shazeer 2020), used in LLaMA, Mistral, PaLM (Gemma uses the GeGLU variant).
    Gated linear unit: output = (swish(xW1) ⊙ xW3) W2
    Uses d_ff = 2/3 × 4 × d_model to keep parameter count similar.
    """
    def __init__(self, d_model, d_ff=None):
        super().__init__()
        d_ff = d_ff or int(2/3 * 4 * d_model)
        self.W1 = nn.Linear(d_model, d_ff, bias=False)   # gate projection
        self.W2 = nn.Linear(d_ff,    d_model, bias=False) # down projection
        self.W3 = nn.Linear(d_model, d_ff,    bias=False) # up projection

    def forward(self, x):
        # Gate controls which "slots" of the expanded space pass through
        gate = F.silu(self.W1(x))    # silu = x * sigmoid(x) = swish
        up   = self.W3(x)
        return self.W2(gate * up)    # element-wise gating

d_model = 512
ffn_relu  = FFN_ReLU(d_model)
ffn_swiglu = FFN_SwiGLU(d_model)

params_relu   = sum(p.numel() for p in ffn_relu.parameters())
params_swiglu = sum(p.numel() for p in ffn_swiglu.parameters())
print(f"ReLU FFN parameters:   {params_relu:,}  (d_ff = {4*d_model})")
print(f"SwiGLU FFN parameters: {params_swiglu:,}  (d_ff = {int(2/3*4*d_model)}, 3 matrices)")

# Forward pass
x = torch.randn(4, 16, d_model)  # batch=4, seq=16, d_model=512
print(f"\nInput shape:       {x.shape}")
print(f"ReLU FFN output:   {ffn_relu(x).shape}")
print(f"SwiGLU FFN output: {ffn_swiglu(x).shape}")

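# Activation sparsity vs. dead neurons for ReLU
with torch.no_grad():
    activated = F.relu(ffn_relu.W1(x))  # (4, 16, d_ff)
    zero_frac = (activated == 0).float().mean().item()
    # A neuron is dead only if it outputs zero for every token in the batch
    dead_frac = (activated == 0).flatten(0, 1).all(dim=0).float().mean().item()
    print(f"\nReLU: {zero_frac:.1%} of activations are zero on this batch")
    print(f"ReLU: {dead_frac:.1%} of neurons are zero for every token in this batch")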

What the FFN actually computes

The attention sublayer handles where: which positions to pull information from. The FFN sublayer handles what: given the information attention has aggregated, what should this token's representation become? It operates on each position independently. The 4× expansion is the opposite of a bottleneck: the input is projected into a higher-dimensional space where the non-linearity can carve out complex decision boundaries, then projected back down. One intuition: write out many competing hypotheses about what this token means in its context (expand), keep the ones that fire strongly (the non-linearity suppresses weak ones), and synthesise them into a new representation (compress).
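The "position-wise" property is easy to check directly. This sketch builds a small ReLU FFN with the same structure as FFN_ReLU (the sizes are arbitrary choices for illustration), perturbs a single token, and confirms that only that token's output changes. Attention would spread the perturbation to other positions; the FFN never mixes information across tokens.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_ff, seq = 64, 256, 8

# Plain ReLU FFN: expand, non-linearity, project back
W1 = nn.Linear(d_model, d_ff)
W2 = nn.Linear(d_ff, d_model)

def ffn(h):
    return W2(F.relu(W1(h)))

x = torch.randn(1, seq, d_model)
x_perturbed = x.clone()
x_perturbed[0, 3] += torch.randn(d_model)  # change only token 3

with torch.no_grad():
    # Largest absolute output change at each position, shape (seq,)
    diff = (ffn(x) - ffn(x_perturbed)).abs().amax(dim=-1)[0]

changed = (diff > 1e-6).nonzero().flatten().tolist()
print("positions whose output changed:", changed)
```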

Geva et al. (2021) showed that FFN layers can be interpreted as key-value memories. With W1 of shape (d_model, d_ff), each column of W1 is a 'key' that activates when the input matches a certain pattern, and the corresponding row of W2 is a 'value' that is added to the output in proportion to how strongly its key fired. The FFN thus behaves like a soft lookup table over learned concepts: its output is a sum of value vectors weighted by key activations.
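The key-value reading is an exact rewrite of the forward pass, not an approximation. The sketch below (bias-free layers and small sizes are assumptions for readability) computes the output as an explicit weighted sum of value vectors and checks it against the usual forward pass. One PyTorch detail: nn.Linear stores its weight as (out_features, in_features), so key i is row i of W1.weight and value i is column i of W2.weight.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_ff = 32, 128
W1 = nn.Linear(d_model, d_ff, bias=False)
W2 = nn.Linear(d_ff, d_model, bias=False)

x = torch.randn(d_model)  # one token's representation

with torch.no_grad():
    keys = W1.weight        # (d_ff, d_model): one key vector per hidden unit
    values = W2.weight.T    # (d_ff, d_model): one value vector per hidden unit

    coeffs = F.relu(keys @ x)                           # how strongly each key matches x
    memory_out = (coeffs[:, None] * values).sum(dim=0)  # weighted sum of values
    standard_out = W2(F.relu(W1(x)))                    # the usual forward pass

    print("same output:", torch.allclose(memory_out, standard_out, atol=1e-5))
    top = coeffs.topk(5)
    print("most active memory slots:", top.indices.tolist())
```

Inspecting which slots fire for which inputs (the top.indices above) is the basic move behind this line of interpretability work.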

SwiGLU: why almost everyone switched

SwiGLU replaces the ReLU with a gating mechanism: FFN(x) = (swish(x·W1) ⊙ (x·W3))·W2, where swish(x) = x·sigmoid(x), also called SiLU. Swish is smooth and does not output a hard zero for negative inputs, so gradients keep flowing and units do not die the way ReLU units can. Because the gate swish(x·W1) is computed from the input, the model learns, per token, which dimensions of the expanded representation to pass through and which to suppress. This learned, input-dependent gating is more expressive than ReLU's fixed threshold at zero. LLaMA, Mistral, Qwen and PaLM use SwiGLU; Gemma uses GeGLU, the same design with GELU in place of SiLU.
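Two things from the paragraph above can be checked in a few lines: ReLU's gradient is exactly zero for negative pre-activations while SiLU's and GELU's are not, and SwiGLU and GeGLU differ only in the gate activation. The GatedFFN class below is a hypothetical helper written for this illustration (d_ff = 170 is an arbitrary choice), not a library API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Gradient of each activation at a few pre-activation values
z = torch.tensor([-4.0, -1.0, -0.1, 0.5, 2.0], requires_grad=True)
grads = {}
for name, act in [("relu", F.relu), ("silu", F.silu), ("gelu", F.gelu)]:
    (g,) = torch.autograd.grad(act(z).sum(), z)
    grads[name] = g
    print(f"{name:>4} grad: {[round(v, 3) for v in g.tolist()]}")

class GatedFFN(nn.Module):
    """GLU-style FFN with a pluggable gate (F.silu -> SwiGLU, F.gelu -> GeGLU)."""
    def __init__(self, d_model, d_ff, act=F.silu):
        super().__init__()
        self.W1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection
        self.W3 = nn.Linear(d_model, d_ff, bias=False)  # up projection
        self.W2 = nn.Linear(d_ff, d_model, bias=False)  # down projection
        self.act = act

    def forward(self, x):
        return self.W2(self.act(self.W1(x)) * self.W3(x))

x = torch.randn(2, 5, 64)
swiglu = GatedFFN(64, 170, act=F.silu)
geglu = GatedFFN(64, 170, act=F.gelu)
print(swiglu(x).shape, geglu(x).shape)
```

The first three entries of the ReLU gradient are zero: a unit whose pre-activation stays negative gets no learning signal, which is how ReLU units die. SiLU's gradient there is small but nonzero.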

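The parameter count stays roughly the same despite the third matrix. SwiGLU sets d_ff = 2/3 × 4 × d_model, so its three matrices hold 3 × d_model × (8/3 × d_model) = 8 × d_model² weights, the same as ReLU's 2 × d_model × 4 × d_model. In practice d_ff is rounded up to a hardware-friendly multiple (256 in the original LLaMA code, which gives d_ff = 11008 for d_model = 4096). In Shazeer's (2020) experiments, GLU variants including SwiGLU reached lower perplexity than the ReLU FFN at matched parameter and compute budgets, and SwiGLU has since become the default FFN for most new LLMs.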
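Here is the parameter math as code. The rounding in swiglu_hidden_dim mirrors the scheme in the original LLaMA reference code (take 2/3 of 4 × d_model, then round up to a multiple of multiple_of = 256); treat it as a sketch of that convention, since other model families round differently. Biases are omitted because SwiGLU models typically drop them.

```python
def swiglu_hidden_dim(d_model: int, multiple_of: int = 256) -> int:
    """2/3 * 4 * d_model, rounded up to a multiple of `multiple_of`."""
    hidden = int(2 * (4 * d_model) / 3)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

def ffn_params(d_model: int, d_ff: int, n_matrices: int) -> int:
    """Weight count of a bias-free FFN built from n (d_model x d_ff) matrices."""
    return n_matrices * d_model * d_ff

for d_model in (4096, 5120, 8192):
    relu = ffn_params(d_model, 4 * d_model, n_matrices=2)
    d_ff = swiglu_hidden_dim(d_model)
    swiglu = ffn_params(d_model, d_ff, n_matrices=3)
    print(f"d_model={d_model:5d}  d_ff={d_ff:6d}  SwiGLU/ReLU params = {swiglu / relu:.3f}")
```

Rounding up costs under 1% extra parameters at these sizes; the ratio would be exactly 1 without it.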

Visualise FFN activations: run a real LLaMA model and capture its intermediate FFN activations with PyTorch hooks. Collect the FFN hidden activations across a batch of sentences and measure what fraction are near zero. In a ReLU FFN these are exact zeros (and units that are zero on every input are the 'dead' ones); in a SwiGLU FFN they are dimensions the gate has nearly suppressed. Units that are consistently near zero across all inputs contribute nothing to the output, and dead ReLU units receive no gradient at all. This is a useful sanity check for fine-tuned models that may have overfit.
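A minimal version of that measurement, as a sketch: a toy two-layer stack of SwiGLU blocks stands in for a real LLM, random tensors stand in for tokenized sentences, and the eps threshold is a hypothetical choice. A forward pre-hook on the down projection sees its input, which is exactly the gated hidden activation swish(xW1) ⊙ xW3. On a Hugging Face LLaMA the analogous layer is usually each block's mlp.down_proj, but check the names with model.named_modules() before relying on that.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.W1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection
        self.W3 = nn.Linear(d_model, d_ff, bias=False)  # up projection
        self.W2 = nn.Linear(d_ff, d_model, bias=False)  # down projection

    def forward(self, x):
        return self.W2(F.silu(self.W1(x)) * self.W3(x))

torch.manual_seed(0)
model = nn.Sequential(SwiGLU(64, 172), SwiGLU(64, 172))  # toy stand-in for an LLM
captured = {}

def make_hook(name):
    # Pre-hook on W2: inputs[0] is gate * up, the FFN hidden activation
    def hook(module, inputs):
        captured.setdefault(name, []).append(inputs[0].detach())
    return hook

handles = [ffn.W2.register_forward_pre_hook(make_hook(f"layer{i}"))
           for i, ffn in enumerate(model)]

with torch.no_grad():
    for _ in range(3):  # stand-in for batches of sentences
        model(torch.randn(4, 16, 64))

for h in handles:
    h.remove()  # detach hooks so later forward passes are unaffected

eps = 1e-3  # hypothetical "near-zero" threshold; tune per model
for name, acts in captured.items():
    a = torch.cat(acts).flatten(0, -2).abs()  # (tokens, d_ff)
    near_zero = (a < eps).float().mean().item()
    always_off = (a < eps).all(dim=0).float().mean().item()
    print(f"{name}: {near_zero:.1%} near-zero activations, "
          f"{always_off:.1%} units near-zero on every token")
```

The always_off column is the one that matters for the dead-unit check: a high near_zero fraction alone just means the layer is sparse on these inputs.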
