Speculative Decoding: How to Get Multiple Tokens Per Forward Pass
A small draft model proposes k candidate tokens. The large target model verifies all k in one parallel forward pass using rejection sampling. When all k are accepted, you get k+1 tokens for the cost of one target pass plus k cheap draft passes. This article covers the algorithm, the acceptance-rate math, the 2-4x throughput gains, and what makes a good draft model.
Autoregressive LLM generation is sequential by construction: you cannot generate token t+1 until you have generated token t. This means you cannot parallelise generation across the time dimension. Or can you? Speculative decoding breaks this constraint using a clever insight: a small fast 'draft' model generates several candidate tokens, and the large 'target' model verifies them all in a single parallel forward pass. When the draft is right — which it usually is on easy tokens — you get multiple tokens per target model call.
The algorithm
Step 1: Use a small draft model (e.g., LLaMA-3-8B) to generate k tokens autoregressively, sampling each one from the draft distribution and keeping that distribution for verification. Step 2: Run the large target model (e.g., LLaMA-3-70B) on the original context plus all k draft tokens simultaneously. The target model produces k+1 output distributions in a single forward pass, because causal attention lets it score every position in parallel. Step 3: Accept or reject each draft token, left to right, using a rejection sampling criterion. If all k tokens are accepted, you also sample one extra token from the last target distribution. If a draft token is rejected, you sample a replacement from the corrected distribution and discard the subsequent draft tokens.
import numpy as np


def rejection_sampling(p_target, p_draft, draft_token):
    """
    Accept draft token sampled from p_draft if it fits p_target.
    The acceptance rule: accept with probability min(1, p_target / p_draft).
    If rejected, sample from the adjusted distribution p_corrected.
    This guarantees the output distribution matches p_target exactly.
    """
    accept_prob = min(1.0, p_target[draft_token] / (p_draft[draft_token] + 1e-12))
    if np.random.random() < accept_prob:
        return draft_token, True  # accepted
    # Rejection: sample from corrected distribution
    # p_corrected ∝ max(0, p_target - p_draft)
    p_corrected = np.maximum(0, p_target - p_draft)
    if p_corrected.sum() > 0:
        p_corrected /= p_corrected.sum()
        corrected_token = np.random.choice(len(p_target), p=p_corrected)
    else:
        corrected_token = np.argmax(p_target)
    return corrected_token, False  # rejected, use corrected token


def simulate_speculative_decoding(vocab_size=50, k=4, n_steps=10, draft_accuracy=0.7):
    """Simulate the accept/reject loop and count effective tokens per target call."""
    total_target_calls = 0
    total_tokens_generated = 0
    for step in range(n_steps):
        # Target generates k+1 distributions in one call
        total_target_calls += 1
        # Draft produces k tokens (simplified: accept with draft_accuracy probability)
        accepted_count = 0
        for i in range(k):
            if np.random.random() < draft_accuracy:
                accepted_count += 1
            else:
                break  # once a draft token is rejected, stop
        # We get accepted_count + 1 tokens (last one from target distribution)
        tokens_this_step = accepted_count + 1
        total_tokens_generated += tokens_this_step
    tokens_per_call = total_tokens_generated / total_target_calls
    print(f"k={k}, draft_accuracy={draft_accuracy:.0%}")
    print(f"  Tokens per target call: {tokens_per_call:.2f} (vs 1.0 baseline)")
    print(f"  Theoretical max: {k+1} tokens per call if all k accepted")
    print(f"  Speedup (ignoring draft overhead): ~{tokens_per_call:.1f}x")


for acc in [0.5, 0.7, 0.85]:
    simulate_speculative_decoding(k=5, draft_accuracy=acc)
    print()
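The simulation above replaces the real accept/reject test with a coin flip. As a complementary sketch, the block below applies the rejection rule to actual probability vectors for a whole batch of k draft tokens. The function name `speculative_step` and the array layout are assumptions made for illustration, and the toy distributions are random rather than produced by real models.

```python
import numpy as np


def speculative_step(draft_probs, target_probs, draft_tokens, rng):
    """Verify k draft tokens against k+1 target distributions.

    draft_probs:  (k, V) array; row i is the draft distribution that produced draft_tokens[i]
    target_probs: (k+1, V) array; row i is the target distribution at the same position
    Returns the tokens emitted by this step (between 1 and k+1 of them).
    """
    emitted = []
    for i, tok in enumerate(draft_tokens):
        p, q = target_probs[i], draft_probs[i]
        # Accept with probability min(1, p[tok] / q[tok]).
        if rng.random() < min(1.0, p[tok] / max(q[tok], 1e-12)):
            emitted.append(int(tok))
            continue
        # Rejected: resample from the residual max(0, p - q), then stop.
        residual = np.maximum(p - q, 0.0)
        total = residual.sum()
        if total > 0:
            emitted.append(int(rng.choice(len(p), p=residual / total)))
        else:
            emitted.append(int(np.argmax(p)))  # defensive fallback
        return emitted
    # All k accepted: sample one bonus token from the last target distribution.
    emitted.append(int(rng.choice(target_probs.shape[1], p=target_probs[-1])))
    return emitted


# Toy usage with random distributions (assumption: real models would supply these).
rng = np.random.default_rng(0)
k, vocab = 4, 8
target_probs = rng.dirichlet(np.ones(vocab), size=k + 1)
draft_probs = rng.dirichlet(np.ones(vocab), size=k)
draft_tokens = [int(rng.choice(vocab, p=q)) for q in draft_probs]
print(speculative_step(draft_probs, target_probs, draft_tokens, rng))
```

Every step emits at least one token, so a target call is never wasted even when the first draft token is rejected.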
Why acceptance rate matters more than k
If the draft model has a 70% per-token acceptance rate and k=5, the number of accepted draft tokens follows a truncated geometric distribution (assuming each token is accepted independently): P(all 5 accepted) = 0.7^5 ≈ 0.17, P(at least 4 accepted) = 0.7^4 ≈ 0.24, and so on. The expected tokens per target call is Σ_{i=0}^{k} acceptance_rate^i = (1 - acceptance_rate^(k+1)) / (1 - acceptance_rate). For a 70% acceptance rate: ~2.9 tokens per target call vs 1.0 without speculative decoding. At 85%: ~4.2 tokens per call. No value of k can push the 70% case past 1 / (1 - 0.7) ≈ 3.3 tokens per call, so draft model quality is the dominant factor.
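The closed form is easy to tabulate. The sketch below evaluates it for a few acceptance rates and draft lengths; the function name `expected_tokens_per_call` is illustrative, and the formula assumes every draft token is accepted independently with the same probability.

```python
def expected_tokens_per_call(alpha: float, k: int) -> float:
    """Expected tokens per target call: sum of alpha**i for i = 0..k.

    Assumes each draft token is accepted independently with probability alpha.
    """
    if alpha >= 1.0:
        return float(k + 1)  # every draft token accepted, plus the bonus token
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


for alpha in (0.5, 0.7, 0.85):
    row = ", ".join(
        f"k={k}: {expected_tokens_per_call(alpha, k):.2f}" for k in (2, 5, 10)
    )
    print(f"alpha={alpha:.2f} -> {row}")
```

As k grows the sum converges to 1 / (1 - alpha), so a longer draft runs into a ceiling set by the acceptance rate: about 3.3 tokens per call at alpha = 0.7, a ceiling that alpha = 0.85 already exceeds at k = 5.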
What pairs well
Draft and target models must share the same tokenizer and vocabulary, because the draft produces token indices that the target verifies. The most effective pairings are models from the same family: LLaMA-3-8B drafting for LLaMA-3-70B (same vocabulary, similar distribution), or an even smaller model from the family (e.g., Llama-3.2-1B) drafting for a larger one. Self-speculation drafts with the target model itself, for example by skipping layers or exiting early; Medusa adds parallel 'heads' to the model itself that produce draft tokens.
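A quick sanity check before pairing two models is to compare their token-to-id mappings directly. The sketch below works on plain dictionaries; the helper name `vocab_mismatches` is hypothetical, and in practice the mappings would come from your tokenizer library (for example, Hugging Face tokenizers expose one via `get_vocab()`).

```python
def vocab_mismatches(draft_vocab: dict, target_vocab: dict, limit: int = 5) -> list:
    """Return up to `limit` tokens whose ids differ or that exist in only one vocabulary."""
    mismatches = []
    for token in sorted(set(draft_vocab) | set(target_vocab)):
        d_id, t_id = draft_vocab.get(token), target_vocab.get(token)
        if d_id != t_id:
            mismatches.append((token, d_id, t_id))
            if len(mismatches) >= limit:
                break
    return mismatches


# Toy vocabularies (illustrative only).
draft_vocab = {"<s>": 0, "the": 1, "cat": 2}
target_vocab = {"<s>": 0, "the": 1, "cat": 3, "dog": 4}
print(vocab_mismatches(draft_vocab, target_vocab))
```

An empty result means the draft's token ids can be passed to the target unchanged; any mismatch means the pair cannot be used as-is.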
Speculative decoding is most valuable when target model latency is the bottleneck and decoding is memory-bandwidth-bound rather than compute-bound, so that scoring k+1 positions costs about the same as scoring one and the draft model adds little overhead. In practice, it provides 2-4× throughput improvements on text generation workloads with high draft accuracy, with no change to the output distribution (rejection sampling guarantees an exact match to the target model).
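To see how draft overhead eats into the gain, here is a rough cost model. It assumes that verifying k+1 positions costs the same as one ordinary target decoding step, and that one draft pass costs a fraction `cost_ratio` of a target pass. Both `cost_ratio` and the function names are hypothetical, introduced only for this sketch.

```python
def expected_speedup(alpha: float, k: int, cost_ratio: float) -> float:
    """Expected tokens per step divided by step cost, in units of one target pass."""
    if alpha >= 1.0:
        tokens = float(k + 1)
    else:
        tokens = (1.0 - alpha ** (k + 1)) / (1.0 - alpha)
    step_cost = 1.0 + k * cost_ratio  # one target pass plus k draft passes
    return tokens / step_cost


def best_k(alpha: float, cost_ratio: float, max_k: int = 16) -> int:
    """Draft length in 1..max_k that maximises the modelled speedup."""
    return max(range(1, max_k + 1), key=lambda k: expected_speedup(alpha, k, cost_ratio))


for alpha in (0.5, 0.7, 0.85):
    k = best_k(alpha, cost_ratio=0.1)
    print(f"alpha={alpha:.2f}: best k={k}, speedup ~{expected_speedup(alpha, k, 0.1):.2f}x")
```

Under this model a slow draft or a low acceptance rate can push the ratio below 1, meaning speculative decoding would be slower than plain decoding.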
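Measure the draft acceptance rate on your production query distribution. Take 100 representative queries. For each query, generate greedily with your target model, then at every position feed the same prefix to the draft model and check whether its greedy token matches the target's. Comparing two independently generated sequences would be misleading, because they stop being comparable after the first disagreement. The fraction of positions where the models agree is a proxy for your expected acceptance rate. If it is below 60%, speculative decoding may not be worth the infrastructure complexity.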
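One way to compute this agreement rate is to condition both models on the target's own continuation at every position. The sketch below treats each model as a function from a token prefix to a greedy next-token id. That interface, the name `greedy_agreement`, and the toy models are assumptions for illustration, and end-of-sequence handling is omitted.

```python
from typing import Callable, Sequence

NextToken = Callable[[Sequence[int]], int]  # prefix -> greedy next-token id


def greedy_agreement(
    prompt: Sequence[int],
    target_next: NextToken,
    draft_next: NextToken,
    max_new_tokens: int = 32,
) -> float:
    """Fraction of positions where the draft's greedy token matches the target's,
    with both models conditioned on the target's output so far."""
    context = list(prompt)
    matches = 0
    for _ in range(max_new_tokens):
        target_tok = target_next(context)
        matches += int(draft_next(context) == target_tok)
        context.append(target_tok)  # always follow the target's continuation
    return matches / max_new_tokens


# Toy stand-ins for real models: the draft is wrong whenever the last token is 4.
def toy_target(ctx):
    return (ctx[-1] + 1) % 5


def toy_draft(ctx):
    return 1 if ctx[-1] == 4 else (ctx[-1] + 1) % 5


print(greedy_agreement([0], toy_target, toy_draft, max_new_tokens=10))
```

Averaging this value over the representative queries gives the per-token figure to compare against the 60% threshold.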
- Fast Inference from Transformers via Speculative Decoding, Leviathan et al. (2023)
- Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads, Cai et al. (2024)