GenAI Systems Lab
Production & LLMOps

Drift Detection in Production ML: PSI, KS Tests, and MMD Explained

Data drift, concept drift, and label drift are three different problems requiring different fixes. PSI for feature monitoring, KS for continuous distributions, MMD for embedding spaces. When to trigger retraining and what the score thresholds actually mean.

Why Your Model Breaks Without Warning

A model that achieves 92% accuracy in offline evaluation can silently degrade to 78% in production over six months — without a single error being thrown. The cause is drift: the statistical properties of the real world changed while your model stayed fixed.

There are three distinct drift phenomena, each requiring different detection methods and different remediation strategies. Confusing them leads to the wrong fix applied to the wrong problem.

Three Kinds of Drift

Data drift (covariate shift): P(X) changes but P(Y|X) is unchanged. Your inputs shifted — different user demographics, new device types, a language model upstream changed its output format. The relationship between features and labels is still valid; the distribution of inputs changed.

Concept drift: P(Y|X) changes. The meaning of your labels changed. 'Spam' in 2020 doesn't look like spam in 2024. A classifier trained on 2020 emails keeps applying patterns that no longer separate spam from legitimate mail. The inputs might look similar; what they signify changed.
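A toy simulation makes the contrast between the two cases concrete. It rests on made-up assumptions: one Gaussian feature, a hard-threshold labeling rule, and a deployed model that learned the training-time rule exactly. Under covariate shift the KS test on inputs should fire while accuracy holds; under concept drift the inputs look the same while accuracy falls.

```python
import numpy as np
from scipy import stats

def simulate(n: int = 5000, seed: int = 0) -> dict:
    """Toy comparison of covariate shift vs. concept drift (hypothetical setup)."""
    rng = np.random.default_rng(seed)

    # Deployed model: learned the training-time labeling rule (x > 0.0) exactly
    def model(x: np.ndarray) -> np.ndarray:
        return (x > 0.0).astype(int)

    x_train = rng.normal(0.0, 1.0, n)
    scenarios = {
        # Covariate shift: P(X) moves, the labeling rule (threshold 0.0) does not
        "covariate_shift": (rng.normal(1.0, 1.0, n), 0.0),
        # Concept drift: P(X) is unchanged, the labeling rule moves to 0.5
        "concept_drift": (rng.normal(0.0, 1.0, n), 0.5),
    }
    results = {}
    for name, (x_prod, threshold) in scenarios.items():
        y_prod = (x_prod > threshold).astype(int)
        results[name] = {
            # Input-only test: compares training and production feature values
            "ks_p_value": float(stats.ks_2samp(x_train, x_prod).pvalue),
            # Requires production labels, which input monitoring never sees
            "accuracy": float((model(x_prod) == y_prod).mean()),
        }
    return results

for name, r in simulate().items():
    print(name, r)
```

The point of the design is that input tests alone cannot see concept drift: catching it needs labels, or at least the label-free output proxies discussed later.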

Label drift (prior probability shift): P(Y) changes. The base rate of the thing you're predicting shifted. If fraud goes from 0.1% to 0.5% of transactions, a model calibrated for 0.1% will be systematically miscalibrated even if the underlying fraud patterns are unchanged.
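When only P(Y) moved and P(X|Y) is unchanged, Bayes' rule says the posterior odds scale by the ratio of the new prior odds to the old ones, so a probabilistic model can be recalibrated without retraining. A minimal sketch, assuming the new base rate is already known (estimating it without labels is a separate problem):

```python
import numpy as np

def adjust_for_prior_shift(p_train: np.ndarray, train_prior: float, new_prior: float) -> np.ndarray:
    """Re-weight P(y=1|x) from a model calibrated at train_prior to new_prior.
    Assumes pure label shift: P(X|Y) unchanged, only the base rate P(Y) moved."""
    p = np.clip(p_train, 1e-12, 1 - 1e-12)  # keep odds finite
    odds = p / (1 - p)
    # Posterior odds = likelihood ratio * prior odds, so swap in the new prior odds
    prior_ratio = (new_prior / (1 - new_prior)) / (train_prior / (1 - train_prior))
    new_odds = odds * prior_ratio
    return new_odds / (1 + new_odds)

# The fraud example: model calibrated at a 0.1% fraud rate, traffic now at 0.5%
print(adjust_for_prior_shift(np.array([0.1, 0.5, 0.9]), train_prior=0.001, new_prior=0.005))
```

With these numbers a score of 0.1 becomes roughly 0.36, which shows how far off the uncorrected model's probabilities are when the base rate rises fivefold.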

Statistical Tests for Input Drift
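PSI is the sum of the two KL divergences between the binned baseline proportions $e_i$ (`expected_perc` in the code) and the binned production proportions $a_i$ (`actual_perc`). A two-bin worked example, with numbers chosen purely for illustration, shows how much probability mass has to move before the 0.1 and 0.2 thresholds in the psi() docstring are crossed:

$$\mathrm{PSI} = \sum_{i=1}^{B} (a_i - e_i)\ln\frac{a_i}{e_i} = D_{\mathrm{KL}}(a \,\|\, e) + D_{\mathrm{KL}}(e \,\|\, a)$$

$$e = (0.5,\ 0.5),\quad a = (0.7,\ 0.3):\qquad \mathrm{PSI} = 0.2\ln 1.4 + (-0.2)\ln 0.6 \approx 0.067 + 0.102 = 0.169$$

With the same baseline, a 60/40 split gives PSI ≈ 0.041 and an 80/20 split gives PSI ≈ 0.416. Under the conventional thresholds, moving 10 points of mass between two bins reads as no drift, 20 points as moderate drift, and 30 points as significant drift.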

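```python
import numpy as np
from scipy import stats

# Population Stability Index: industry standard for feature drift
def psi(expected: np.ndarray, actual: np.ndarray, bins: int = 10) -> float:
    """
    PSI < 0.1: no drift
    PSI 0.1–0.2: moderate drift, investigate
    PSI > 0.2: significant drift, act
    """
    # Both histograms must share the same bin edges, taken from the baseline
    edges = np.histogram_bin_edges(expected, bins=bins)
    # Clip production values into the baseline range so outliers land in the edge bins
    actual = np.clip(actual, edges[0], edges[-1])
    # The 1e-8 keeps empty bins from producing log(0)
    expected_perc = np.histogram(expected, bins=edges)[0] / len(expected) + 1e-8
    actual_perc   = np.histogram(actual,   bins=edges)[0] / len(actual)   + 1e-8
    return float(np.sum((actual_perc - expected_perc) * np.log(actual_perc / expected_perc)))

# Kolmogorov-Smirnov test: for continuous features
def ks_drift(baseline: np.ndarray, production: np.ndarray) -> dict:
    stat, p_value = stats.ks_2samp(baseline, production)
    # At production volumes even negligible shifts give p < 0.05; check ks_stat's size too
    return {"ks_stat": round(float(stat), 4), "p_value": round(float(p_value), 4),
            "drifted": bool(p_value < 0.05)}

# Chi-squared test: for categorical features
def chi2_drift(baseline_counts: np.ndarray, production_counts: np.ndarray) -> dict:
    # Rescale baseline proportions to the production total. Categories absent from
    # the baseline give expected = 0 and must be merged or dropped beforehand.
    expected = baseline_counts / baseline_counts.sum() * production_counts.sum()
    stat, p_value = stats.chisquare(production_counts, expected)
    return {"chi2": round(float(stat), 4), "p_value": round(float(p_value), 4),
            "drifted": bool(p_value < 0.05)}

# Maximum Mean Discrepancy: kernel-based, handles high-dimensional embeddings
def mmd(x: np.ndarray, y: np.ndarray, sigma: float = 1.0) -> float:
    """Squared MMD with a Gaussian (RBF) kernel, biased estimator.
    x: (n, d) baseline vectors, y: (m, d) production vectors. Use for embedding drift."""
    x = x.reshape(len(x), -1)  # treat 1-D input as (n, 1)
    y = y.reshape(len(y), -1)
    def rbf(a, b):
        # Pairwise squared Euclidean distances, shape (len(a), len(b)); memory O(n*m*d)
        pairwise = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
        return np.exp(-pairwise / (2 * sigma ** 2))
    return float(rbf(x, x).mean() - 2 * rbf(x, y).mean() + rbf(y, y).mean())
```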
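The mmd() function above returns a raw score with no threshold, and its scale depends on sigma and the embedding dimension. One common way to turn it into a drift decision is a permutation test: if baseline and production come from the same distribution, shuffling the sample labels should produce MMD values as large as the observed one. The sketch below assumes the median heuristic for the kernel bandwidth instead of a fixed sigma; both the function name and that choice are illustrative.

```python
import numpy as np

def mmd_permutation_test(x: np.ndarray, y: np.ndarray, n_permutations: int = 200,
                         seed: int = 0) -> tuple[float, float]:
    """Return (squared MMD, permutation p-value) for an RBF kernel (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    x = x.reshape(len(x), -1)
    y = y.reshape(len(y), -1)
    z = np.vstack([x, y])
    # Pairwise squared distances on the pooled sample, computed once and reused
    sq = np.sum((z[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    # Median heuristic (assumption): scale by the median nonzero squared distance
    scale = np.median(sq[sq > 0])
    K = np.exp(-sq / scale)
    n = len(x)

    def mmd2(idx: np.ndarray) -> float:
        a, b = idx[:n], idx[n:]
        # Biased (V-statistic) estimate, the same form as mmd() above
        return K[np.ix_(a, a)].mean() - 2 * K[np.ix_(a, b)].mean() + K[np.ix_(b, b)].mean()

    observed = mmd2(np.arange(len(z)))
    # Under the null the labels are exchangeable, so shuffles sample the null distribution
    null = np.array([mmd2(rng.permutation(len(z))) for _ in range(n_permutations)])
    p_value = (1 + np.sum(null >= observed)) / (1 + n_permutations)
    return float(observed), float(p_value)

rng = np.random.default_rng(1)
baseline = rng.normal(0.0, 1.0, (200, 8))
production = rng.normal(1.0, 1.0, (200, 8))
print(mmd_permutation_test(baseline, production))
```

The pooled kernel matrix costs O((n+m)² · d) memory, so in practice both windows are subsampled to a few thousand vectors before testing.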

Monitoring Output Distributions

When ground truth labels are delayed or unavailable, monitor proxy signals instead. For a ranking model: track CTR per position bucket, average rank of purchased items, and zero-result rate. For a classification model: track the score distribution, how often each class is the top prediction (top-class rate), and the entropy of the softmax output.
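For the classification proxies, a minimal sketch of the per-window computation is below; it assumes predictions arrive as an (n, k) matrix of softmax probabilities, and the function names are illustrative. Each value would be compared against the same statistic computed on a baseline window.

```python
import numpy as np

def softmax_entropy(probs: np.ndarray) -> np.ndarray:
    """Shannon entropy in nats for each row of an (n, k) matrix of class probabilities."""
    p = np.clip(probs, 1e-12, 1.0)  # avoid log(0) for one-hot rows
    return -np.sum(p * np.log(p), axis=1)

def classification_proxies(probs: np.ndarray) -> dict:
    """Label-free output signals for one window of predictions (hypothetical helper)."""
    top = probs.argmax(axis=1)
    return {
        # Higher mean entropy means the model is less certain on this window
        "mean_entropy": float(softmax_entropy(probs).mean()),
        # Fraction of predictions where each class wins the argmax
        "top_class_rate": np.bincount(top, minlength=probs.shape[1]) / len(probs),
        # Average confidence of the winning class
        "mean_top_score": float(probs.max(axis=1).mean()),
    }

window = np.array([[0.7, 0.2, 0.1],
                   [0.4, 0.35, 0.25],
                   [0.1, 0.8, 0.1]])
print(classification_proxies(window))
```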

```python
# Score distribution monitoring with control charts
import numpy as np

class ScoreMonitor:
    """Checks each window of production scores against the baseline.
    Relies on psi() from the previous block."""
    def __init__(self, baseline_scores: np.ndarray, window: int = 1000):
        self.baseline = np.asarray(baseline_scores, dtype=float)
        self.baseline_mean = self.baseline.mean()
        self.baseline_std  = self.baseline.std()
        self.window = window
        self.buffer = []

    def add(self, score: float):
        self.buffer.append(score)
        if len(self.buffer) >= self.window:
            self._check()
            # Keep the newest half: consecutive windows overlap by 50%
            self.buffer = self.buffer[self.window // 2:]

    def _check(self):
        recent = np.array(self.buffer)
        # Z-test on the window mean: standard error of the mean is std / sqrt(n)
        z = (recent.mean() - self.baseline_mean) / (self.baseline_std / np.sqrt(len(recent)))
        # Compare against the real baseline scores, not a synthetic normal sample:
        # score distributions (e.g. probabilities in [0, 1]) are rarely Gaussian
        psi_val = psi(self.baseline, recent)
        print(f"Z-score: {z:.2f}, PSI: {psi_val:.3f}")
        if abs(z) > 3.0:
            print("ALERT: score mean shifted > 3 standard errors")
        if psi_val > 0.2:
            print("ALERT: PSI > 0.2: significant drift")
```


Retraining Triggers

Rule of thumb: drift in inputs warrants investigation; drift in outputs warrants retraining; drift in labels after retraining suggests a concept drift that data collection alone cannot fix — you need to rethink the label schema.

Triggered retraining beats scheduled retraining because the model is retrained when it's needed, not on a fixed calendar. Common triggers: PSI > 0.2 on a key feature for 3 consecutive days; an accuracy drop of more than 5 points absolute vs. baseline, measured on a held-out set of recently labeled production data (a fixed validation set from training time cannot reveal drift); mean score entropy crossing a threshold (the model is more uncertain on recent data than on training data).
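The three triggers can be combined into a single daily check. This is a sketch under stated assumptions: the DailyMetrics record and retraining_reasons() are hypothetical names, accuracy is None on days when labels have not yet arrived, and entropy_threshold would be set from the entropy observed on training data.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class DailyMetrics:
    key_feature_psi: float       # PSI of one key feature vs. the training baseline
    accuracy: Optional[float]    # accuracy on freshly labeled data; None while labels lag
    mean_entropy: float          # mean softmax entropy of the day's predictions

def retraining_reasons(history: List[DailyMetrics], baseline_accuracy: float,
                       entropy_threshold: float) -> List[str]:
    """Triggers that fired for the latest day; an empty list means no retrain.
    Assumes history is non-empty and ordered oldest to newest."""
    reasons = []
    last3 = history[-3:]
    if len(last3) == 3 and all(d.key_feature_psi > 0.2 for d in last3):
        reasons.append("PSI > 0.2 on key feature for 3 consecutive days")
    latest = history[-1]
    if latest.accuracy is not None and baseline_accuracy - latest.accuracy > 0.05:
        reasons.append("accuracy dropped > 5 points absolute vs. baseline")
    if latest.mean_entropy > entropy_threshold:
        reasons.append("mean entropy above threshold")
    return reasons

history = [DailyMetrics(0.25, None, 0.4), DailyMetrics(0.30, None, 0.4),
           DailyMetrics(0.22, 0.85, 0.4)]
print(retraining_reasons(history, baseline_accuracy=0.92, entropy_threshold=0.6))
```

Returning the list of reasons rather than a bare boolean keeps the retraining job's logs explainable: a PSI-only trigger points to input drift worth investigating, while an accuracy trigger points to the output-side degradation the rule of thumb above treats as grounds for retraining.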

Try it interactively

GenAI Systems Lab is a free platform for AI engineers — configure real failure modes, break things, and build the judgment that gets you hired.
