GenAI Systems Lab
AI Engineering

Two-Tower Training From Scratch: The Architecture Behind Every Large-Scale Retrieval System

One tower for queries, one for items. Joint training with in-batch negatives. Item embeddings pre-computed and indexed in FAISS. The false negative problem, hard negative mining, and feature engineering rules for each tower. PyTorch implementation.

The two-tower model is the dominant retrieval architecture at scale. Large recommendation systems, including those at Netflix, Google, Spotify, and LinkedIn, use it for the candidate retrieval stage. One tower encodes queries (or users); the other encodes documents (or items). The inner product between their outputs is the relevance score, and the two towers are trained jointly with in-batch negatives. Once you understand the training loop, you can fine-tune retrieval for your own domain, diagnose retrieval failures, and design the per-tower features that make the model work.

Architecture and the serving trick

The key serving property: item (document) embeddings are computed once, at indexing time, and stored in a FAISS index. At query time you only encode the query (one forward pass through the query tower) and run an approximate nearest neighbour (ANN) search over the stored item embeddings. A cross-encoder, by contrast, must run a forward pass on every (query, document) pair at query time: O(N) forward passes for N candidates, versus O(1) forward passes plus one ANN lookup for the two-tower model.
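A minimal sketch of that offline/online split, assuming small randomly initialised linear layers as stand-ins for trained towers and a made-up catalogue of 10,000 items. It uses FAISS's `IndexFlatIP` (exact inner-product search, which equals cosine similarity on unit-norm vectors) when the `faiss` package is installed, and falls back to a brute-force NumPy search otherwise. At real catalogue sizes you would swap in an approximate index such as HNSW or IVF; the tower/index split stays the same.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import faiss  # optional dependency (e.g. the faiss-cpu package)
except ImportError:
    faiss = None

torch.manual_seed(0)
EMBED_DIM = 32

# Stand-in towers with hypothetical input sizes; in practice these are the
# trained query_tower / item_tower of a two-tower model.
query_tower = nn.Linear(16, EMBED_DIM)
item_tower = nn.Linear(24, EMBED_DIM)

# ── Offline: encode the whole catalogue once and index it ──
item_features = torch.randn(10_000, 24)          # synthetic item features
with torch.no_grad():
    item_emb = F.normalize(item_tower(item_features), dim=-1)
item_matrix = np.ascontiguousarray(item_emb.numpy(), dtype=np.float32)

if faiss is not None:
    index = faiss.IndexFlatIP(EMBED_DIM)           # exact inner-product search
    index.add(item_matrix)
else:
    index = None

def retrieve(query_features: torch.Tensor, k: int = 10):
    """Online: one query-tower forward pass, then top-k inner-product search."""
    with torch.no_grad():
        q = F.normalize(query_tower(query_features), dim=-1)
    q_np = np.ascontiguousarray(q.numpy(), dtype=np.float32)
    if index is not None:
        scores, ids = index.search(q_np, k)         # both (num_queries, k)
    else:
        # Fallback: brute-force inner product, the same result IndexFlatIP returns
        sims = q_np @ item_matrix.T
        ids = np.argsort(-sims, axis=1)[:, :k]
        scores = np.take_along_axis(sims, ids, axis=1)
    return scores, ids

scores, ids = retrieve(torch.randn(2, 16), k=5)
print(ids.shape)  # (2, 5)
```

Only `retrieve` runs per request; the item tower never runs at query time.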

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Tower(nn.Module):
    """Query or item tower. In practice: replace with a pre-trained transformer."""
    def __init__(self, input_dim, hidden_dim=128, output_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return F.normalize(self.net(x), dim=-1)   # unit-norm: inner product = cosine

class TwoTowerModel(nn.Module):
    def __init__(self, query_dim, item_dim, embed_dim=64):
        super().__init__()
        self.query_tower = Tower(query_dim, output_dim=embed_dim)
        self.item_tower  = Tower(item_dim,  output_dim=embed_dim)

    def forward(self, query_features, item_features):
        q_emb = self.query_tower(query_features)   # (B, embed_dim)
        i_emb = self.item_tower(item_features)     # (B, embed_dim)
        return q_emb, i_emb


def two_tower_loss(q_emb, i_emb, temperature=0.05):
    """
    In-batch negatives: for each (query_i, item_i) pair,
    all other items in the batch are negatives.
    Same InfoNCE-style softmax loss as in contrastive learning.
    """
    sim    = q_emb @ i_emb.T / temperature     # (B, B): diagonal = positives
    labels = torch.arange(sim.shape[0], device=sim.device)   # row i's target is column i
    return F.cross_entropy(sim, labels)


# ── Training ──────────────────────────────────────────────────────────────────
torch.manual_seed(42)
B = 64             # batch size: larger batches = more negatives per query
query_dim, item_dim = 32, 48
model     = TwoTowerModel(query_dim, item_dim, embed_dim=32)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def make_batch(B, query_dim, item_dim, noise=0.1):
    """Simulate paired (query, item) features that share a common latent component."""
    shared  = torch.randn(B, min(query_dim, item_dim) // 2)
    queries = torch.cat([shared + noise * torch.randn_like(shared),
                         torch.randn(B, query_dim - shared.shape[1])], dim=1)
    items   = torch.cat([shared + noise * torch.randn_like(shared),
                         torch.randn(B, item_dim - shared.shape[1])], dim=1)
    return queries, items

losses = []
for step in range(500):
    q_feat, i_feat = make_batch(B, query_dim, item_dim)
    q_emb, i_emb   = model(q_feat, i_feat)
    loss = two_tower_loss(q_emb, i_emb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    if step % 100 == 0:
        with torch.no_grad():
            # In-batch Recall@1: share of queries whose own item beats the other B-1 items
            sim   = q_emb @ i_emb.T
            ranks = (sim > sim.diagonal().unsqueeze(1)).sum(dim=1)
            r1    = (ranks == 0).float().mean().item()
        print(f"Step {step:4d}  loss={loss.item():.4f}  in-batch Recall@1={r1:.2%}")

print(f"\nFinal loss: {losses[-1]:.4f}")
```
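Written out, `two_tower_loss` is a softmax cross-entropy over each row of the B×B similarity matrix, where q_i and d_j are the unit-norm query and item embeddings and τ is `temperature`:

$$
\mathcal{L} = -\frac{1}{B} \sum_{i=1}^{B} \log \frac{\exp\left(q_i^{\top} d_i / \tau\right)}{\sum_{j=1}^{B} \exp\left(q_i^{\top} d_j / \tau\right)}
$$

Each query's own item (the diagonal term) is the target class, and the other B − 1 items in the batch act as its negatives, so the negatives cost no extra forward passes. A small τ such as 0.05 sharpens the softmax, magnifying differences between cosine similarities that otherwise lie in [−1, 1].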

In-batch false negatives

If item_j is genuinely relevant to query_i but sits in the batch as another query's positive, it is treated as a negative for query_i, and the model gets a noisy gradient that pushes query_i away from a relevant item. This false-negative problem degrades training quality. Three common mitigations: mask known positive pairs out of the negatives before computing the loss; use larger batches (with more negatives, any single false negative carries less weight in the softmax); or use hard negative mining, training against items explicitly known to be non-relevant rather than relying only on random in-batch negatives.
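A sketch of the first and third mitigations, assuming each positive in the batch carries a catalogue id (`pos_ids`) and that a possibly stale embedding bank of the whole catalogue is available for mining. `in_batch_loss` and `mine_hard_negatives` are hypothetical helpers, not a library API: the first sets off-diagonal known positives (here, the same item appearing twice in one batch) to -inf so they drop out of the softmax, and appends per-query hard negatives as extra logits; the second picks each query's highest-scoring items other than its labelled positive. Mined items are re-encoded through the item tower so gradients reach both towers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def in_batch_loss(q_emb, i_emb, known_pos=None, hard_neg_emb=None, temperature=0.05):
    """
    q_emb, i_emb: (B, D) unit-norm embeddings; row i of i_emb is query i's positive.
    known_pos:    optional (B, B) bool, True where item j is known relevant to query i.
                  Off-diagonal True entries are masked so they are not used as negatives.
    hard_neg_emb: optional (B, H, D) explicitly non-relevant items per query.
    """
    B = q_emb.shape[0]
    logits = q_emb @ i_emb.T / temperature                                 # (B, B)
    if known_pos is not None:
        off_diag = ~torch.eye(B, dtype=torch.bool, device=logits.device)
        logits = logits.masked_fill(known_pos & off_diag, float("-inf"))
    if hard_neg_emb is not None:
        hard = torch.einsum("bd,bhd->bh", q_emb, hard_neg_emb) / temperature  # (B, H)
        logits = torch.cat([logits, hard], dim=1)                          # (B, B + H)
    labels = torch.arange(B, device=logits.device)                         # positives on diagonal
    return F.cross_entropy(logits, labels)

def mine_hard_negatives(q_emb, bank_emb, pos_ids, num_neg=4):
    """Indices of the highest-scoring catalogue items, excluding each query's positive."""
    with torch.no_grad():
        scores = q_emb @ bank_emb.T                                        # (B, N)
        scores[torch.arange(len(pos_ids)), pos_ids] = float("-inf")
        return scores.topk(num_neg, dim=1).indices                         # (B, num_neg)

# ── Usage on toy data (hypothetical catalogue of 1,000 items) ──
torch.manual_seed(0)
B, N, D = 8, 1000, 16
query_tower, item_tower = nn.Linear(12, D), nn.Linear(20, D)
catalogue = torch.randn(N, 20)                          # item features
pos_ids = torch.randint(0, N, (B,))
pos_ids[1] = pos_ids[0]                                 # force a duplicate item in the batch

q_emb = F.normalize(query_tower(torch.randn(B, 12)), dim=-1)
i_emb = F.normalize(item_tower(catalogue[pos_ids]), dim=-1)

# An item that appears twice is a positive for both rows, never a negative.
known_pos = pos_ids.unsqueeze(0) == pos_ids.unsqueeze(1)

with torch.no_grad():                                   # stale bank is fine for mining
    bank_emb = F.normalize(item_tower(catalogue), dim=-1)
hard_idx = mine_hard_negatives(q_emb, bank_emb, pos_ids, num_neg=4)
hard_emb = F.normalize(item_tower(catalogue[hard_idx]), dim=-1)   # re-encode: gradients flow

loss = in_batch_loss(q_emb, i_emb, known_pos=known_pos, hard_neg_emb=hard_emb)
loss.backward()
print(f"loss={loss.item():.4f}")
```

Mined negatives are only as clean as your labels: a top-scoring item that is relevant but unlabelled becomes a false negative itself, so the masking step still matters.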

Feature engineering for the towers

Query tower features: query text (tokenised), session context, user history (past clicks, history embeddings), and user demographics. Item tower features: item text (title, description), category, popularity, freshness, and item embeddings from a separate embedding model. The critical rule: any feature that requires interaction between query and item (e.g., "has this user clicked this item before?") cannot go into either tower, because the towers must be computed independently; item embeddings are pre-computed before any query arrives. Interaction features belong in the reranker.
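One way to make the independence rule structural is to give each tower a `forward()` that only accepts its own side's features. The sketch below assumes hypothetical vocabulary sizes, bag-of-tokens text encoders (`nn.EmbeddingBag`) in place of a transformer, a fixed-length click history, and already-normalised popularity and freshness scalars; none of these choices come from the article.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical vocabulary sizes and embedding width, for illustration only.
N_TOKENS, N_ITEMS, N_CATEGORIES, D = 5000, 20000, 50, 32

class QueryTower(nn.Module):
    """Sees only query/user-side features: query tokens, click history, context."""
    def __init__(self, context_dim=8):
        super().__init__()
        self.tokens  = nn.EmbeddingBag(N_TOKENS, D, mode="mean")  # bag-of-tokens query text
        self.history = nn.EmbeddingBag(N_ITEMS, D, mode="mean")   # mean of past-click item ids
        self.mlp = nn.Sequential(nn.Linear(2 * D + context_dim, D), nn.ReLU(), nn.Linear(D, D))

    def forward(self, query_tokens, clicked_ids, context):
        x = torch.cat([self.tokens(query_tokens), self.history(clicked_ids), context], dim=-1)
        return F.normalize(self.mlp(x), dim=-1)

class ItemTower(nn.Module):
    """Sees only item-side features: title tokens, category, popularity, freshness."""
    def __init__(self):
        super().__init__()
        self.tokens   = nn.EmbeddingBag(N_TOKENS, D, mode="mean")
        self.category = nn.Embedding(N_CATEGORIES, D)
        self.mlp = nn.Sequential(nn.Linear(2 * D + 2, D), nn.ReLU(), nn.Linear(D, D))

    def forward(self, title_tokens, category, popularity, freshness):
        dense = torch.stack([popularity, freshness], dim=-1)      # (B, 2), assumed in [0, 1]
        x = torch.cat([self.tokens(title_tokens), self.category(category), dense], dim=-1)
        return F.normalize(self.mlp(x), dim=-1)

# Neither forward() takes the other side's inputs, so a feature like
# "has this user clicked this item?" has nowhere to go: it belongs in the reranker.
torch.manual_seed(0)
B = 4
q_emb = QueryTower()(torch.randint(0, N_TOKENS, (B, 6)),      # query token ids
                     torch.randint(0, N_ITEMS, (B, 10)),      # last 10 clicked item ids
                     torch.randn(B, 8))                       # session context vector
i_emb = ItemTower()(torch.randint(0, N_TOKENS, (B, 12)),      # title token ids
                    torch.randint(0, N_CATEGORIES, (B,)),     # category id
                    torch.rand(B), torch.rand(B))             # popularity, freshness
scores = (q_emb * i_emb).sum(-1)    # per-pair relevance = inner product
print(q_emb.shape, i_emb.shape)     # torch.Size([4, 32]) torch.Size([4, 32])
```

Because `ItemTower` never sees user inputs, its outputs can be pre-computed for the whole catalogue, which is exactly what the serving trick relies on.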

Fine-tune a two-tower retrieval model on your domain: take a pre-trained bi-encoder (e.g., sentence-transformers/all-MiniLM-L6-v2), freeze its weights, attach LoRA adapters (for example to its attention projection layers), and train the adapters with in-batch negatives on your (query, relevant document) pairs. This is domain adaptation for retrieval: the same concept as contrastive-learning-from-scratch, applied to a pre-trained model that already knows language.
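The mechanics can be sketched without any fine-tuning library. This is a from-scratch LoRA layer, not the `peft` or `sentence-transformers` API. A small randomly initialised MLP stands in for the pre-trained encoder (an assumption, so the example runs offline); each `nn.Linear` is wrapped so its weight stays frozen and only a rank-r update (alpha/r)·B·A is trained; and one shared encoder embeds both queries and documents, as a sentence-transformers bi-encoder does. The (query, document) pairs are synthetic.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LoRALinear(nn.Module):
    """Frozen base Linear plus a trainable low-rank update: W x + (alpha / r) * B (A x)."""
    def __init__(self, base: nn.Linear, r: int = 4, alpha: float = 8.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)                                  # pre-trained weights stay fixed
        self.A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, r))    # zero init: starts as base model
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)

def add_lora(module: nn.Module, r: int = 4, alpha: float = 8.0) -> nn.Module:
    """Recursively wrap every nn.Linear inside `module` with LoRALinear."""
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            setattr(module, name, LoRALinear(child, r, alpha))
        else:
            add_lora(child, r, alpha)
    return module

torch.manual_seed(0)
# Hypothetical stand-in for a pre-trained bi-encoder.
encoder = add_lora(nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32)), r=4)
base_before = encoder[0].base.weight.detach().clone()

trainable = [p for p in encoder.parameters() if p.requires_grad]
opt = torch.optim.Adam(trainable, lr=1e-2)

def encode(x):
    return F.normalize(encoder(x), dim=-1)   # same encoder for queries and documents

for step in range(200):
    shared = torch.randn(64, 32)                                    # synthetic paired data
    queries = shared + 0.1 * torch.randn_like(shared)
    docs = shared + 0.1 * torch.randn_like(shared)
    q, d = encode(queries), encode(docs)
    loss = F.cross_entropy(q @ d.T / 0.05, torch.arange(64))       # in-batch negatives
    opt.zero_grad()
    loss.backward()
    opt.step()

n_train = sum(p.numel() for p in trainable)
n_total = sum(p.numel() for p in encoder.parameters())
print(f"trainable params: {n_train} / {n_total}")  # trainable params: 768 / 4960
```

Initialising B to zero means training starts from exactly the pre-trained model's embeddings; here only 768 of the 4,960 parameters are updated.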
