Drift Detection in Production ML: PSI, KS Tests, and MMD Explained
Data drift, concept drift, and label drift are three different problems requiring different fixes. PSI for feature monitoring, KS for continuous distributions, MMD for embedding spaces. When to trigger retraining and what the score thresholds actually mean.
Why Your Model Breaks Without Warning
A model that achieves 92% accuracy in offline evaluation can silently degrade to 78% in production over six months — without a single error being thrown. The cause is drift: the statistical properties of the real world changed while your model stayed fixed.
There are three distinct drift phenomena, each requiring different detection methods and different remediation strategies. Confusing them leads to the wrong fix applied to the wrong problem.
Three Kinds of Drift
Data drift (covariate shift): P(X) changes but P(Y|X) is unchanged. Your inputs shifted — different user demographics, new device types, a language model upstream changed its output format. The relationship between features and labels is still valid; the distribution of inputs changed.
Concept drift: P(Y|X) changes. The meaning of your labels changed. 'Spam' in 2020 doesn't look like spam in 2024. A classifier trained on one year of emails learns outdated patterns. The inputs might be similar; what they signify changed.
Label drift (prior probability shift): P(Y) changes. The base rate of the thing you're predicting shifted. If fraud goes from 0.1% to 0.5% of transactions, a model calibrated for 0.1% will be systematically miscalibrated even if the underlying fraud patterns are unchanged.
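When the drift is purely in the base rate, calibrated scores can sometimes be corrected without retraining. The sketch below applies the standard prior-shift adjustment to binary classifier probabilities. It assumes P(X|Y) is unchanged and that the new base rate is known or can be estimated; the function name `adjust_for_prior_shift` and the example scores are illustrative, not from a specific library.

```python
import numpy as np

def adjust_for_prior_shift(p: np.ndarray, train_rate: float, prod_rate: float) -> np.ndarray:
    """Re-weight calibrated positive-class probabilities for a new base rate.

    Assumes pure label drift (P(X|Y) unchanged) and that `p` was calibrated
    for `train_rate`. Both rates must lie strictly between 0 and 1.
    """
    p = np.asarray(p, dtype=float)
    # Scale the evidence for each class by the ratio of new to old prior
    pos = p * (prod_rate / train_rate)
    neg = (1.0 - p) * ((1.0 - prod_rate) / (1.0 - train_rate))
    return pos / (pos + neg)

# Fraud base rate moves from 0.1% to 0.5%, as in the example above
scores = np.array([0.001, 0.05, 0.5])
adjusted = adjust_for_prior_shift(scores, train_rate=0.001, prod_rate=0.005)
```

A score equal to the old base rate maps to the new base rate, and every score moves upward when the positive class becomes more common. If the fraud patterns themselves changed, this correction is not enough, because that is concept drift rather than label drift.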
Statistical Tests for Input Drift
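import numpy as np
from scipy import stats

# Population Stability Index: industry standard for feature drift
def psi(expected: np.ndarray, actual: np.ndarray, bins: int = 10) -> float:
    """
    PSI < 0.1: no drift
    PSI 0.1–0.2: moderate drift, investigate
    PSI > 0.2: significant drift, act
    """
    # Bin edges come from the baseline only, so both samples share the same bins
    edges = np.histogram_bin_edges(expected, bins=bins)
    # Clip production values into the baseline range so outliers land in the end bins
    actual = np.clip(actual, edges[0], edges[-1])
    # The small constant avoids log(0) and division by zero for empty bins
    expected_perc = np.histogram(expected, bins=edges)[0] / len(expected) + 1e-8
    actual_perc = np.histogram(actual, bins=edges)[0] / len(actual) + 1e-8
    return float(np.sum((actual_perc - expected_perc) * np.log(actual_perc / expected_perc)))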

# Kolmogorov-Smirnov test: for continuous features
def ks_drift(baseline: np.ndarray, production: np.ndarray) -> dict:
    stat, p_value = stats.ks_2samp(baseline, production)
    return {"ks_stat": round(stat, 4), "p_value": round(p_value, 4), "drifted": p_value < 0.05}

# Chi-squared test: for categorical features
def chi2_drift(baseline_counts: np.ndarray, production_counts: np.ndarray) -> dict:
    # Both arrays must list the same categories in the same order,
    # and every baseline count must be nonzero.
    # Scale baseline proportions to the production total
    expected = baseline_counts / baseline_counts.sum() * production_counts.sum()
    stat, p_value = stats.chisquare(production_counts, expected)
    return {"chi2": round(stat, 4), "p_value": round(p_value, 4), "drifted": p_value < 0.05}

# Maximum Mean Discrepancy: kernel-based, handles high-dimensional embeddings
def mmd(x: np.ndarray, y: np.ndarray, sigma: float = 1.0) -> float:
    """Gaussian kernel MMD (biased estimate of the squared MMD).
    Use for embedding/vector drift detection."""
    def rbf(a, b):
        # Pairwise squared Euclidean distances between rows of a and rows of b
        pairwise = np.sum((a[:, None] - b[None, :]) ** 2, axis=-1)
        return np.exp(-pairwise / (2 * sigma ** 2))
    return float(rbf(x, x).mean() - 2 * rbf(x, y).mean() + rbf(y, y).mean())
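PSI comes with conventional thresholds, but an MMD value has no universal cutoff: its scale depends on the kernel width and the sample sizes. One common way to turn it into a decision is a permutation test, sketched below. The block re-implements the kernel so it runs on its own; the names `rbf_kernel` and `mmd_permutation_test`, the `n_perm` default, and the synthetic embeddings are assumptions for illustration.

```python
import numpy as np

def rbf_kernel(a: np.ndarray, b: np.ndarray, sigma: float) -> np.ndarray:
    # Squared distances via the expansion |a|^2 + |b|^2 - 2ab, clipped at 0 for float error
    sq = np.sum(a**2, axis=1)[:, None] + np.sum(b**2, axis=1)[None, :] - 2.0 * a @ b.T
    return np.exp(-np.maximum(sq, 0.0) / (2.0 * sigma**2))

def mmd_permutation_test(x, y, sigma=1.0, n_perm=200, seed=0):
    """Return (biased squared MMD, permutation p-value)."""
    rng = np.random.default_rng(seed)
    pooled = np.vstack([x, y])
    n = len(x)
    k = rbf_kernel(pooled, pooled, sigma)  # computed once, re-indexed per permutation

    def stat(idx):
        a, b = idx[:n], idx[n:]
        return k[np.ix_(a, a)].mean() - 2.0 * k[np.ix_(a, b)].mean() + k[np.ix_(b, b)].mean()

    observed = stat(np.arange(len(pooled)))
    # Null distribution: the same statistic with sample membership shuffled
    null = np.array([stat(rng.permutation(len(pooled))) for _ in range(n_perm)])
    p_value = (1 + np.sum(null >= observed)) / (1 + n_perm)
    return float(observed), float(p_value)

# Synthetic 8-dimensional "embeddings": production is shifted by 1.0 in every dimension
rng = np.random.default_rng(42)
baseline = rng.normal(0.0, 1.0, size=(100, 8))
shifted = rng.normal(1.0, 1.0, size=(100, 8))
mmd2, p = mmd_permutation_test(baseline, shifted, sigma=2.0)
```

The kernel matrix grows quadratically with the pooled sample size, so in practice this is usually run on a subsample of each window rather than on all logged embeddings.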
Monitoring Output Distributions
When ground truth labels are delayed or unavailable, monitor proxy signals instead. For a ranking model: track CTR per position bucket, average rank of purchased items, zero-result rate. For a classification model: track score distribution, top class rate, entropy of the softmax output.
# Score distribution monitoring with control charts
class ScoreMonitor:
    def __init__(self, baseline_scores: np.ndarray, window: int = 1000):
        self.baseline_scores = np.asarray(baseline_scores)
        self.baseline_mean = self.baseline_scores.mean()
        self.baseline_std = self.baseline_scores.std()
        self.window = window
        self.buffer = []

    def add(self, score: float):
        self.buffer.append(score)
        if len(self.buffer) >= self.window:
            self._check()
            self.buffer = self.buffer[self.window // 2:]  # keep the newer half: 50% overlapping windows

    def _check(self):
        recent = np.array(self.buffer)
        # Z-score of the window mean, using the standard error of the mean
        z = (recent.mean() - self.baseline_mean) / (self.baseline_std / np.sqrt(len(recent)))
        # Compare against the real baseline scores rather than a Gaussian
        # approximation, since model scores are often skewed or bounded
        psi_val = psi(self.baseline_scores, recent)
        print(f"Z-score: {z:.2f}, PSI: {psi_val:.3f}")
        if abs(z) > 3.0:
            print("ALERT: Score mean shifted > 3 sigma")
        if psi_val > 0.2:
            print("ALERT: PSI > 0.2, significant drift")
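The classifier proxy signals mentioned above (score distribution, top class rate, softmax entropy) can be computed per batch without any labels. The following is a minimal sketch, assuming the model's outputs are logged as an array of shape (n_samples, n_classes); the function name `output_proxy_signals` and the example batches are hypothetical.

```python
import numpy as np

def output_proxy_signals(probs: np.ndarray) -> dict:
    """Summarize a batch of softmax outputs, shape (n_samples, n_classes)."""
    probs = np.asarray(probs, dtype=float)
    eps = 1e-12  # guards against log(0) for hard 0/1 outputs
    entropy = -np.sum(probs * np.log(probs + eps), axis=1)
    top_class = probs.argmax(axis=1)
    # Fraction of the batch for which each class is the top prediction
    class_rate = np.bincount(top_class, minlength=probs.shape[1]) / len(probs)
    return {
        "mean_entropy": float(entropy.mean()),
        "mean_top_score": float(probs.max(axis=1).mean()),
        "top_class_rate": class_rate.tolist(),
    }

confident = np.array([[0.98, 0.01, 0.01], [0.01, 0.98, 0.01]])
uncertain = np.array([[0.40, 0.30, 0.30], [0.34, 0.33, 0.33]])
base = output_proxy_signals(confident)
recent = output_proxy_signals(uncertain)
```

Rising mean entropy or a shifting top class rate relative to the baseline window is a leading indicator worth alerting on, though neither proves that accuracy has dropped.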
The Monitoring Stack
- Log predictions with timestamps, feature snapshots, and any available feedback signal. Store in a time-series-aware store (ClickHouse, Redshift, BigQuery).
- Run statistical tests on sliding windows — e.g. PSI on daily feature distributions vs. baseline week. Alert on thresholds, not just on errors.
- Track label distribution when feedback is available. Even soft signals — thumbs up/down, dwell time, return visits — are better than nothing.
- Segment by cohort. Drift in the whole population can mask severe drift in a slice (a user segment, a geography, a device type). Run per-segment monitoring for high-value populations.
- Ground truth delay is unavoidable in many settings. Design your monitoring to give leading indicators (feature drift, score drift) before the lagging indicator (accuracy drop) arrives.
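The cohort point in the list above is easy to demonstrate on synthetic data. In this sketch a small segment drifts heavily while the population-level PSI stays under the usual 0.1 threshold. The compact `psi` helper is repeated so the block runs on its own; `psi_by_segment`, the segment names, and the 5% segment share are assumptions made for the example.

```python
import numpy as np

def psi(expected, actual, bins=10):
    # Shared bin edges from the baseline; production values clipped into range
    edges = np.histogram_bin_edges(expected, bins=bins)
    actual = np.clip(actual, edges[0], edges[-1])
    e = np.histogram(expected, bins=edges)[0] / len(expected) + 1e-8
    a = np.histogram(actual, bins=edges)[0] / len(actual) + 1e-8
    return float(np.sum((a - e) * np.log(a / e)))

def psi_by_segment(base_vals, base_seg, prod_vals, prod_seg):
    """PSI for the whole population and for each segment label."""
    report = {"__all__": psi(base_vals, prod_vals)}
    for seg in np.unique(base_seg):
        b = base_vals[base_seg == seg]
        p = prod_vals[prod_seg == seg]
        if len(p) > 0:  # skip segments with no production traffic
            report[str(seg)] = psi(b, p)
    return report

rng = np.random.default_rng(0)
n = 20000
# Roughly 5% of traffic is "tablet"; only that segment drifts in production
base_seg = np.where(rng.random(n) < 0.05, "tablet", "desktop")
prod_seg = np.where(rng.random(n) < 0.05, "tablet", "desktop")
base_vals = rng.normal(0.0, 1.0, n)
prod_vals = rng.normal(0.0, 1.0, n)
prod_vals[prod_seg == "tablet"] += 1.5

report = psi_by_segment(base_vals, base_seg, prod_vals, prod_seg)
```

Here the aggregate PSI stays in the "no drift" band while the tablet segment lands well above 0.2, which is the masking effect described in the list. Small segments give noisy PSI values, so a minimum sample size per segment is a sensible guard.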
Retraining Triggers
Rule of thumb: drift in inputs warrants investigation; drift in outputs warrants retraining; label drift that persists after retraining suggests concept drift that data collection alone cannot fix, and you need to rethink the label schema.
Triggered retraining beats scheduled retraining because the model is retrained when the data shows it is needed, not on a fixed calendar. Common triggers: PSI > 0.2 on a key feature for 3 consecutive days; an accuracy drop of more than 5 percentage points vs. baseline on a held-out validation set of recently labeled production data; score entropy crossing a threshold (the model is more uncertain on recent data than on training data).
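The "3 consecutive days" condition matters because a single noisy day should not kick off a retraining job. A minimal sketch of that debouncing logic follows; the class name `ConsecutiveBreachTrigger` and the daily PSI values are made up for illustration.

```python
from collections import deque

class ConsecutiveBreachTrigger:
    """Fire only after `patience` consecutive readings above `threshold`."""

    def __init__(self, threshold: float = 0.2, patience: int = 3):
        self.threshold = threshold
        # Fixed-length history of breach flags; old entries fall off automatically
        self.recent = deque(maxlen=patience)

    def update(self, value: float) -> bool:
        self.recent.append(value > self.threshold)
        return len(self.recent) == self.recent.maxlen and all(self.recent)

trigger = ConsecutiveBreachTrigger(threshold=0.2, patience=3)
daily_psi = [0.05, 0.25, 0.31, 0.12, 0.22, 0.27, 0.24]
fired = [trigger.update(v) for v in daily_psi]
# fired is False for the first six days and True on the seventh
```

The two-day breach early in the series is ignored because day four falls back under the threshold; only the final three-day run fires the trigger.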