Context-Aware Maximum Mean Discrepancy — alibi-detect 0.12.0 documentation


Overview

The context-aware maximum mean discrepancy drift detector (Cobb and Van Looveren, 2022) is a kernel-based method for detecting drift in a manner that can take relevant context into account. A standard drift detector detects when the distributions underlying two sets of samples \(\{x^0_i\}_{i=1}^{n_0}\) and \(\{x^1_i\}_{i=1}^{n_1}\) differ. A context-aware drift detector only detects differences that cannot be attributed to a corresponding difference between the sets of associated context variables \(\{c^0_i\}_{i=1}^{n_0}\) and \(\{c^1_i\}_{i=1}^{n_1}\).

Context-aware drift detectors afford practitioners the flexibility to specify their desired context variable. It could be a transformation of the data, such as a subset of features, or an unrelated indexing quantity, such as the time or weather. Everything that the practitioner wishes to allow to change between the reference window and test window should be captured within the context variable.

On a technical level, the method operates in a manner similar to the maximum mean discrepancy detector. However, instead of using an estimate of the squared difference between kernel mean embeddings of \(X_{\text{ref}}\) and \(X_{\text{test}}\) as the test statistic, we now use an estimate of the expected squared difference between the kernel conditional mean embeddings of \(X_{\text{ref}}|C\) and \(X_{\text{test}}|C\). In addition to the kernel defined on the data space \(X\), which is required to define the test statistic, estimating the statistic requires a second kernel defined on the space of the context variable \(C\). For any given realisation of the test statistic, an associated p-value is then computed using a conditional permutation test.
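To give a feel for the permutation-test machinery, the sketch below computes a plain (unconditional) squared-MMD statistic with an RBF kernel and an ordinary permutation p-value using only NumPy. This is a deliberate simplification: the detector described here uses conditional mean embeddings and a *conditional* permutation scheme, neither of which is reproduced below.

```python
import numpy as np

def rbf_kernel(a, b, sigma=1.0):
    # Pairwise RBF kernel matrix between the rows of a and b.
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / (2 * sigma ** 2))

def mmd2(x, y, sigma=1.0):
    # Biased estimate of the squared MMD between samples x and y.
    return (rbf_kernel(x, x, sigma).mean()
            - 2 * rbf_kernel(x, y, sigma).mean()
            + rbf_kernel(y, y, sigma).mean())

def permutation_p_value(x, y, n_perm=200, sigma=1.0, seed=0):
    # p-value: fraction of permuted statistics at least as large
    # as the observed one (with a +1 correction).
    rng = np.random.default_rng(seed)
    observed = mmd2(x, y, sigma)
    pooled = np.concatenate([x, y])
    n = len(x)
    count = 0
    for _ in range(n_perm):
        perm = rng.permutation(len(pooled))
        if mmd2(pooled[perm[:n]], pooled[perm[n:]], sigma) >= observed:
            count += 1
    return (count + 1) / (n_perm + 1)

rng = np.random.default_rng(0)
x_ref = rng.normal(0.0, 1.0, size=(50, 2))
x_test = rng.normal(3.0, 1.0, size=(50, 2))  # clear mean shift
p = permutation_p_value(x_ref, x_test)
```

With a three-standard-deviation mean shift the observed statistic exceeds essentially all permuted statistics, so the p-value is small and drift would be flagged.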

The detector is designed for cases where the training data contains a rich variety of contexts and individual test windows may cover a much more limited subset. It is assumed that the test contexts remain within the support of those observed in the reference set.
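The scenario above can be illustrated with a toy example (all names and values below are illustrative, not part of the library). The reference window spans two contexts and the data mean depends on the context; the test window covers only one of those contexts. Marginally the two windows look different, but conditioned on the shared context they match, which is exactly the situation in which a context-aware detector should not flag drift.

```python
import numpy as np

rng = np.random.default_rng(42)

# Reference window spans two contexts (say, day = 0 and night = 1);
# the data mean depends on the context.
c_ref = np.repeat([0, 1], 500)
x_ref = 2.0 * c_ref + rng.normal(0.0, 0.5, size=1000)

# Test window covers only the night context.
x_test = 2.0 + rng.normal(0.0, 0.5, size=500)

# Marginally the windows look different ...
marginal_gap = abs(x_ref.mean() - x_test.mean())
# ... but conditioned on the shared context they match.
conditional_gap = abs(x_ref[c_ref == 1].mean() - x_test.mean())
```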

Usage

Initialize

Arguments:

Keyword arguments:

Additional PyTorch keyword arguments:

An example of initializing the drift detector with the PyTorch backend:

from alibi_detect.cd import ContextMMDDrift

cd = ContextMMDDrift(x_ref, c_ref, p_val=.05, backend='pytorch')

The same detector in TensorFlow:

from alibi_detect.cd import ContextMMDDrift

cd = ContextMMDDrift(x_ref, c_ref, p_val=.05, backend='tensorflow')
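In the snippets above, `x_ref` and `c_ref` are assumed to be NumPy arrays aligned on the first (instance) axis, with one row per reference instance; the shapes below are purely illustrative.

```python
import numpy as np

n_ref = 100
x_ref = np.random.default_rng(0).normal(size=(n_ref, 5))  # data: one row per instance
c_ref = np.random.default_rng(1).normal(size=(n_ref, 2))  # context: one row per instance
# The two arrays must agree on the number of instances.
```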

Detect Drift

We detect data drift by simply calling predict on a batch of test or deployment instances x and contexts c. We can return the p-value and the threshold of the permutation test by setting return_p_val to True, and the context-aware maximum mean discrepancy metric and threshold by setting return_distance to True. We can also set return_coupling to True, which additionally returns the coupling matrices \(W_\text{ref,test}\), \(W_\text{ref,ref}\) and \(W_\text{test,test}\). As illustrated in the examples (text, ECGs) this can provide deep insights into where the reference and test distributions are similar and where they differ.

The prediction takes the form of a dictionary with meta and data keys. meta contains the detector’s metadata, while data is itself a dictionary holding the actual predictions (the drift prediction plus any requested p-values, distances and coupling matrices):

preds = cd.predict(x, c, return_p_val=True, return_distance=True, return_coupling=True)

Examples

Text

Context-aware drift detection on news articles

Time series

Context-aware drift detection on ECGs