AI Engineering 12 min read

Contrastive Learning From Scratch: Train Your Own Embedding Model

In-batch negatives, InfoNCE loss, temperature scaling: the full contrastive training loop in PyTorch. Understand why temperature=0.05 matters, what hard negatives are, and why a model you fine-tune on your own domain can beat a generic API.

The open embedding models that power semantic search (all-MiniLM-L6-v2, E5, BGE) were trained with contrastive learning, and proprietary models such as text-embedding-3-small are believed to use similar objectives, although their training details are not fully public. You feed the model (query, relevant document) pairs, and the training signal is: pull the query embedding toward its relevant document embedding and push it away from every other document in the batch. The only supervision is the relevance pairs themselves. The result is an embedding space where vector similarity tracks semantic similarity.

In-batch negatives: the key ingredient

For each (query, positive) pair in a batch, every other document in the batch serves as a negative. A batch of 32 pairs provides 31 negatives per query without any additional data collection. You compute all embeddings once, then compute a similarity matrix in a single matmul. The training signal comes from the ratio: similarity to the positive vs. similarities to all negatives.
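The ratio described above is usually written as the InfoNCE loss. The notation here is an assumption added for illustration: B is the batch size, q_i and d_j are the L2-normalized query and document embeddings, and τ is the temperature. It is intended to match what the code that follows computes with a cross-entropy over the similarity matrix, where the diagonal entries are the positives.

```latex
\mathcal{L}
  = -\frac{1}{B} \sum_{i=1}^{B}
    \log \frac{\exp\left(s_{ii} / \tau\right)}
              {\sum_{j=1}^{B} \exp\left(s_{ij} / \tau\right)},
\qquad s_{ij} = q_i^{\top} d_j
```

Each row i of the similarity matrix is treated as a B-way classification problem whose correct class is column i.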

import torch
import torch.nn as nn
import torch.nn.functional as F

class BiEncoder(nn.Module):
    def __init__(self, vocab_size=1000, d_model=64, d_emb=32):
        super().__init__()
        self.embed   = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.project = nn.Linear(d_model, d_emb)

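    def encode(self, token_ids):
        # Masked mean pooling: average token vectors, ignoring padding (id 0)
        mask = (token_ids != 0).float()
        x    = self.embed(token_ids)
        # clamp guards against division by zero for an all-padding sequence
        x    = (x * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True).clamp(min=1.0)
        # L2-normalize so the dot product equals cosine similarity
        return F.normalize(self.project(x), dim=-1)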

    def forward(self, query_ids, doc_ids):
        return self.encode(query_ids), self.encode(doc_ids)

def contrastive_loss(q_emb, d_emb, temperature=0.05):
    # (B, B) similarity matrix — diagonal is the positive pairs
    sim    = q_emb @ d_emb.T / temperature
    labels = torch.arange(sim.shape[0], device=sim.device)
    return F.cross_entropy(sim, labels)

torch.manual_seed(42)
model     = BiEncoder(vocab_size=200, d_model=32, d_emb=16)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

def make_batch(B=16, L=10, V=200):
    q = torch.randint(1, V, (B, L))
    d = torch.randint(1, V, (B, L))
    d[:, :L//2] = q[:, :L//2]    # docs share half tokens with their query
    return q, d

for step in range(500):
    q_ids, d_ids = make_batch(B=32)
    q_emb, d_emb = model(q_ids, d_ids)
    loss = contrastive_loss(q_emb, d_emb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        # Recall@1 on the current training batch, measured after the update
        with torch.no_grad():
            q_e, d_e = model(q_ids, d_ids)
            sim   = q_e @ d_e.T
            # rank = number of documents scoring above the true positive
            ranks = (sim > sim.diagonal().unsqueeze(1)).sum(dim=1)
            r1    = (ranks == 0).float().mean().item()
        print(f"Step {step:4d}  loss={loss.item():.4f}  Recall@1={r1:.2%}")

Temperature: the most important hyperparameter

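Temperature controls how sharply the model distinguishes positives from negatives. Because the embeddings are normalized, similarities lie in [-1, 1], and dividing by the temperature stretches that range before the softmax. Low temperature (0.05): the softmax is very peaked, so small similarity differences produce a strong signal. High temperature: the distribution is flatter, and only large differences matter. SimCSE uses 0.05, and the default scale of 20 in the sentence-transformers MultipleNegativesRankingLoss corresponds to the same value; 0.07 is common in other contrastive setups. Too low, and the gradient is dominated by the few hardest negatives, which makes training unstable and sensitive to mislabeled pairs. Too high, and every negative is weighted almost equally, so the signal is too weak.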
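To make the effect concrete, here is a small sketch that applies the same softmax at three temperatures to one fixed set of similarities. The similarity values and the helper name `positive_prob` are made up for illustration; only the scaling step mirrors the loss used earlier.

```python
import torch
import torch.nn.functional as F

# Hypothetical cosine similarities for one query:
# index 0 is the positive, the other three are in-batch negatives.
sims = torch.tensor([0.60, 0.50, 0.30, 0.10])

def positive_prob(sims, temperature):
    """Softmax probability assigned to index 0 (the positive) at a given temperature."""
    return F.softmax(sims / temperature, dim=0)[0].item()

# Same similarities, three temperatures
probs = {t: positive_prob(sims, t) for t in (0.05, 0.5, 1.0)}
for t, p in probs.items():
    print(f"temperature={t:<4}  p(positive)={p:.3f}")
```

With these numbers, a 0.1 gap between the positive and the best negative is enough to give the positive most of the probability mass at 0.05, while at 1.0 the four candidates are close to uniform and the loss barely separates them.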

Hard negatives: beyond random in-batch negatives

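Random in-batch negatives are easy: 'What is BERT?' vs 'What is the capital of France?' is an obvious mismatch. Hard negatives are topically similar but not relevant: 'What is BERT?' vs 'What is GPT?' To separate them, the model must learn fine-grained distinctions. Mining hard negatives means using an initial retriever to find documents that score highly for a query but are not relevant to it. Take care to exclude false negatives (documents that are actually relevant but unlabeled), because training on them pushes correct answers away. Good hard-negative mining is often the difference between a mediocre embedding model and a strong one.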
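The mining step can be sketched in a few lines, assuming you already have embeddings from some initial retriever. Everything here is illustrative: `mine_hard_negatives` is a hypothetical helper, the corpus is random vectors, and the queries are noisy copies of their positives standing in for real encoded text.

```python
import torch
import torch.nn.functional as F

def mine_hard_negatives(q_emb, corpus_emb, positive_idx, k=2):
    """For each query, return the indices of the k highest-scoring corpus
    documents that are not its labeled positive."""
    sim  = q_emb @ corpus_emb.T                    # (Q, N) similarity scores
    rows = torch.arange(sim.shape[0])
    sim[rows, positive_idx] = float("-inf")        # never select the labeled positive
    return sim.topk(k, dim=1).indices              # (Q, k) hard-negative indices

torch.manual_seed(0)
corpus_emb   = F.normalize(torch.randn(50, 16), dim=-1)   # stand-in for encoded documents
positive_idx = torch.tensor([3, 17, 42])                   # labeled positive for each query
# Stand-in queries: noisy copies of their positives
q_emb = F.normalize(corpus_emb[positive_idx] + 0.1 * torch.randn(3, 16), dim=-1)

hard_neg_idx = mine_hard_negatives(q_emb, corpus_emb, positive_idx, k=2)
print(hard_neg_idx)
```

In a real pipeline the top-ranked candidates may include unlabeled relevant documents, so a common precaution is to skip the very highest ranks or filter candidates with a stronger model before using them as negatives.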

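To apply this in practice, fine-tune all-MiniLM-L6-v2 on your domain using the sentence-transformers (SBERT) library. Provide (query, positive) pairs from your corpus, and measure Recall@10 before and after on a held-out evaluation set. On many domain-specific corpora, improvements on the order of 5-20% are plausible from as few as 1-2k training pairs, though the gain depends on how far your domain is from the model's original training data.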
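For the before-and-after measurement, a Recall@k function over precomputed embeddings is enough. This is a minimal sketch: `recall_at_k` is a hypothetical helper, it assumes one labeled positive per query, and the toy embeddings below exist only to show the call.

```python
import torch

def recall_at_k(q_emb, d_emb, positive_idx, k=10):
    """Fraction of queries whose labeled positive appears among the top-k documents."""
    sim  = q_emb @ d_emb.T                                    # (Q, N) similarity scores
    topk = sim.topk(min(k, d_emb.shape[0]), dim=1).indices    # (Q, k) retrieved doc ids
    hits = (topk == positive_idx.unsqueeze(1)).any(dim=1)     # True where the positive was retrieved
    return hits.float().mean().item()

# Toy check: four one-hot documents; the last query points at the wrong document.
d_emb        = torch.eye(4)
q_emb        = torch.eye(4)[[0, 1, 2, 0]]
positive_idx = torch.tensor([0, 1, 2, 3])

r1 = recall_at_k(q_emb, d_emb, positive_idx, k=1)
r4 = recall_at_k(q_emb, d_emb, positive_idx, k=4)
print(f"Recall@1={r1:.2f}  Recall@4={r4:.2f}")
```

Run the same function on embeddings from the base model and from the fine-tuned model, using the same held-out queries and corpus, so the two numbers are directly comparable.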
