Classifier — alibi-detect 0.12.0 documentation

Overview

The classifier-based drift detector (Lopez-Paz and Oquab, 2017) simply tries to correctly distinguish instances from the reference set vs. the test set. The classifier is trained to output the probability that a given instance belongs to the test set. If the probabilities it assigns to unseen test instances are significantly higher (as determined by a Kolmogorov-Smirnov test) than those it assigns to unseen reference instances, then the test set must differ from the reference set and drift is flagged. Alternatively, the detector can also binarize the classifier predictions (0 or 1) and apply a binomial test to the binarized predictions on the reference vs. the test data. To leverage all the available reference and test data, stratified cross-validation can be applied and the out-of-fold predictions are used for the significance test. Note that a new classifier is trained for each test set, or even for each fold within the test set.
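To make the procedure concrete, here is a minimal sketch of the underlying two-sample test using scikit-learn and SciPy rather than alibi-detect itself. The choice of classifier, the fold count and the two-sided KS test are illustrative assumptions, not the library's exact implementation (which, per the description above, compares the probabilities one-sidedly):

import numpy as np
from scipy.stats import ks_2samp
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold

def classifier_drift_sketch(x_ref, x_test, n_folds=5, p_val=.05):
    # Label reference instances 0 and test instances 1.
    x = np.concatenate([x_ref, x_test])
    y = np.concatenate([np.zeros(len(x_ref)), np.ones(len(x_test))])
    probs = np.zeros(len(x))
    # Stratified cross-validation so that every instance gets an out-of-fold prediction.
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=0)
    for train_idx, oof_idx in skf.split(x, y):
        clf = LogisticRegression(max_iter=1000).fit(x[train_idx], y[train_idx])
        probs[oof_idx] = clf.predict_proba(x[oof_idx])[:, 1]  # P(instance comes from the test set)
    # Compare out-of-fold probabilities assigned to reference vs. test instances.
    dist, p = ks_2samp(probs[y == 0], probs[y == 1])
    return {'is_drift': int(p < p_val), 'p_val': p, 'distance': dist}

# Illustrative usage on synthetic tabular data with a mean shift in the test set.
x_ref = np.random.randn(200, 5)
x_test = np.random.randn(200, 5) + 1.
print(classifier_drift_sketch(x_ref, x_test))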

Usage

Initialize

Arguments:

Keyword arguments:

Additional PyTorch keyword arguments:

Additional Sklearn keyword arguments:

Initialized TensorFlow drift detector example:

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Input
from alibi_detect.cd import ClassifierDrift

model = tf.keras.Sequential(
    [
        Input(shape=(32, 32, 3)),
        Conv2D(8, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(16, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(32, 4, strides=2, padding='same', activation=tf.nn.relu),
        Flatten(),
        Dense(2, activation='softmax')
    ]
)

cd = ClassifierDrift(x_ref, model, p_val=.05, preds_type='probs', n_folds=5, epochs=2)

A similar detector using PyTorch:

import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 8, 4, stride=2, padding=0),
    nn.ReLU(),
    nn.Conv2d(8, 16, 4, stride=2, padding=0),
    nn.ReLU(),
    nn.Conv2d(16, 32, 4, stride=2, padding=0),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(128, 2)
)

cd = ClassifierDrift(x_ref, model, backend='pytorch', p_val=.05, preds_type='logits')

Since this model's final layer returns raw scores rather than softmax outputs, preds_type is set to 'logits' instead of 'probs'.

Detect Drift

We detect data drift by simply calling predict on a batch of instances x. Setting return_p_val to True will also return the p-value of the test, return_distance equal to True will return a notion of the strength of the drift, and return_probs equal to True additionally returns the out-of-fold classifier prediction probabilities on the reference and test data (0 = reference data, 1 = test data) as well as the associated out-of-fold reference and test instances.
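For example, with the detector cd initialized as above and x a batch of new instances (a placeholder name for your own data), the call might look as follows:

# Run the classifier two-sample test on a new batch x and request the optional outputs.
preds = cd.predict(x, return_p_val=True, return_distance=True, return_probs=True)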

The prediction takes the form of a dictionary with meta and data keys. meta contains the detector’s metadata while data is itself a dictionary containing the actual predictions, stored under the following keys:
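As a rough illustration of this structure (assuming the preds returned above; is_drift and p_val are typical entries, see the API reference for the complete set of keys):

print(preds['meta'])              # detector metadata
print(preds['data']['is_drift'])  # 1 if drift was flagged, 0 otherwise
print(preds['data']['p_val'])     # p-value returned when return_p_val=True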

Examples

Drift detection on CIFAR10

Drift detection on Adult Census

Drift detection on Amazon reviews