Counterfactuals Guided by Prototypes

Note

To enable support for Counterfactuals guided by Prototypes, you may need to run

pip install alibi[tensorflow]

Overview

This method is based on the Interpretable Counterfactual Explanations Guided by Prototypes paper which proposes a fast, model agnostic method to find interpretable counterfactual explanations for classifier predictions by using class prototypes.

Humans often think about how they can alter the outcome of a situation. What do I need to change for the bank to approve my loan? is a common example. This form of counterfactual reasoning comes naturally to us and explains how to arrive at a desired outcome in an interpretable manner. Moreover, examples of counterfactual instances resulting in a different outcome can give powerful insights into what is important to the underlying decision process. This makes it a compelling method to explain predictions of machine learning models. In the context of predictive models, a counterfactual instance describes the necessary change in input features of a test instance that alters the prediction to a predefined output (e.g. a prediction class). The counterfactual is found by iteratively perturbing the input features of the test instance during an optimization process until the desired output is achieved.

A high quality counterfactual instance \(x_{cf}\) should have the following desirable properties:

- the model prediction on \(x_{cf}\) needs to be close to the predefined output (e.g. the desired counterfactual class);
- the perturbation changing the original instance into \(x_{cf}\) should be sparse;
- the counterfactual should be interpretable, meaning it lies close to the training data distribution and, in particular, to the distribution of the counterfactual class;
- the counterfactual should be found fast enough to be usable in practice.

We can obtain those properties by incorporating additional loss terms in the objective function that is optimized using gradient descent. A basic loss function for a counterfactual can look like this:

\[Loss = cL_{pred} + \beta L_{1} + L_{2}\]

The first loss term, \(cL_{pred}\), encourages the perturbed instance to be predicted as a different class than the original instance. The \(\beta L_{1} + L_{2}\) terms act as an elastic net regularizer and introduce sparsity by penalizing the size of the difference between the counterfactual and the original instance. While we can obtain sparse counterfactuals using this objective function, they are often not very interpretable because the training data distribution is not taken into account, and the perturbations are not necessarily meaningful.

The Contrastive Explanation Method (CEM) uses an autoencoder which is trained to reconstruct instances of the training set. We can then add the \(L_{2}\) reconstruction error of the perturbed instance as a loss term to keep the counterfactual close to the training data distribution. The loss function becomes:

\[Loss = cL_{pred} + \beta L_{1} + L_{2} + L_{AE}\]

The \(L_{AE}\) loss term does, however, not necessarily lead to interpretable solutions or speed up the counterfactual search. The lack of interpretability occurs because the overall data distribution is followed, but not the class specific one. That’s where the prototype loss term \(L_{proto}\) comes in. To define the prototype for each prediction class, we can use the encoder part of the previously mentioned autoencoder. We also need the training data or at least a representative sample. We use the model to make predictions on this data set. For each predicted class, we encode the instances belonging to that class. The class prototype is simply the average of the k closest encodings in that class to the encoding of the instance that we want to explain. When we want to generate a counterfactual, we first find the nearest prototype other than the one for the class predicted on the original instance. The \(L_{proto}\) loss term tries to minimize the \(L_{2}\) distance between the counterfactual and the nearest prototype. As a result, the perturbations are guided to the closest prototype, speeding up the counterfactual search and making the perturbations more meaningful as they move towards a typical in-distribution instance.

If we do not have a trained encoder available, we can build class representations using k-d trees for each class. The prototype is then the k-th nearest instance in a k-d tree other than the tree which represents the class predicted on the original instance. The loss function now looks as follows:

\[Loss = cL_{pred} + \beta L_{1} + L_{2} + L_{AE} + L_{proto}\]
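To make the encoder-based prototype definition concrete, the following is a minimal NumPy sketch of how the class prototypes could be computed. It is an illustration only, not the Alibi implementation; the names model, enc, x_train and x are assumptions for this example.

import numpy as np

def class_prototypes(model, enc, x_train, x, k=20):
    # Sketch of encoder-based class prototypes (illustration only, not the Alibi implementation).
    # `model` is the classifier, `enc` a trained encoder, `x_train` a representative sample
    # and `x` a single instance to explain -- all assumed names for this example.
    preds = model.predict(x_train).argmax(axis=1)      # predicted class per training instance
    encodings = enc.predict(x_train)                   # encodings of the training sample
    x_enc = enc.predict(x[np.newaxis])                 # encoding of the instance to explain

    prototypes = {}
    for c in np.unique(preds):
        enc_c = encodings[preds == c]                  # encodings of instances predicted as class c
        dists = np.linalg.norm(enc_c - x_enc, axis=1)  # distances to the instance's encoding
        nearest = enc_c[np.argsort(dists)[:k]]         # k closest encodings within class c
        prototypes[c] = nearest.mean(axis=0)           # class prototype = average of those encodings
    return prototypes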

The method allows us to select specific prototype classes to guide the counterfactual to. For example, in MNIST the closest prototype to a 9 could be a 4. However, we can specify that we want to move towards the 7 prototype and avoid 4.
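Using the explain call shown further down in the Usage section, this could look as follows. This is a usage sketch, assuming target_class accepts a list of admissible counterfactual classes:

# guide the counterfactual for a 9 towards the 7 prototype instead of the nearest (e.g. 4) prototype
explanation = cf.explain(X, target_class=[7])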

In order to help interpretability, we can also add a trust score constraint on the proposed counterfactual. The trust score is defined as the ratio of the distance between the encoded counterfactual and the prototype of the class predicted on the original instance, and the distance between the encoded counterfactual and the prototype of the class predicted for the counterfactual instance. Intuitively, a high trust score implies that the counterfactual is far from the originally predicted class compared to the counterfactual class. For more info on trust scores, please check out the documentation.
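Written out, with \(ENC\) the encoder, \(proto_{i}\) the prototype of the class \(i\) predicted on the original instance and \(proto_{j}\) the prototype of the class \(j\) predicted on the counterfactual (and assuming the \(L_{2}\) distance used elsewhere on this page), the trust score reads:

\[trust(x_{cf}) = \frac{\lVert ENC(x_{cf}) - proto_{i} \rVert_{2}}{\lVert ENC(x_{cf}) - proto_{j} \rVert_{2}}\]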

Because of the \(L_{proto}\) term, we can actually remove the prediction loss term and still obtain an interpretable counterfactual. This is especially relevant for fully black box models. When we provide the counterfactual search method with a Keras or TensorFlow model, it is incorporated into the TensorFlow graph and evaluated using automatic differentiation. However, if we only have access to the model’s prediction function, the gradient updates are numerical and typically require a large number of prediction calls because of \(L_{pred}\). These prediction calls can slow the search down significantly and become a bottleneck. We can represent the gradient of the prediction loss term as follows:

\[\frac{\partial L_{pred}}{\partial x} = \frac{\partial L_{pred}}{\partial p} \frac{\partial p}{\partial x}\]

where \(p\) is the prediction function and \(x\) the input features to optimize. For a 28 by 28 MNIST image, the \(\partial p / \partial x\) term alone would require a prediction call with batch size 28x28x2 = 1568. By using the prototypes to guide the search however, we can remove the prediction loss term and only make a single prediction at the end of each gradient update to check whether the predicted class on the proposed counterfactual is different from the original class.
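To illustrate where that batch size comes from, here is a rough central-difference sketch of \(\partial p / \partial x\) for a single flattened instance. The predict_fn signature and the flattened input are assumptions for this example, not the Alibi implementation.

import numpy as np

def numerical_gradient_p(predict_fn, x, eps=1e-2):
    # Central-difference estimate of dp/dx for a single flattened instance x.
    # `predict_fn` maps a batch of flattened instances to class probabilities (assumed signature).
    n = x.size                                     # 28 * 28 = 784 for an MNIST image
    perturb = eps * np.eye(n)
    batch = np.vstack([x + perturb, x - perturb])  # 2 * 784 = 1568 perturbed instances in one call
    preds = predict_fn(batch)                      # a single, but very large, prediction call
    return (preds[:n] - preds[n:]) / (2 * eps)     # shape (n, nb of classes): dp_j / dx_i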

Categorical Variables

It is crucial for many machine learning applications to deal with both continuous numerical and categorical data. Explanation methods which rely on perturbations or sampling of the input features need to make sure those perturbations are meaningful and capture the underlying structure of the data. If not done properly, the perturbed or sampled instances can be out of distribution compared to the training data and result in misleading explanations. The perturbation or sampling process becomes tricky for categorical features. For instance, random perturbations across possible categories, or enforcing a ranking between categories based on frequency of occurrence in the training data, do not capture this structure.

Our method first computes the pairwise distances between categories of a categorical variable based on either the model predictions (MVDM) or the context provided by the other variables in the dataset (ABDM). For MVDM, we use the difference between the conditional model prediction probabilities of each category. This method is based on the Modified Value Difference Metric (MVDM) by Cost et al (1993). ABDM stands for Association-Based Distance Metric, a categorical distance measure introduced by Le et al (2005). ABDM infers context from the presence of other variables in the data and computes a dissimilarity measure based on the Kullback-Leibler divergence. Both methods can also be combined as ABDM-MVDM. We can then apply multidimensional scaling to project the pairwise distances into Euclidean space. More details will be provided in a forthcoming paper.
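As a rough sketch of that last projection step (using scikit-learn's MDS rather than Alibi's internal routine, and with a placeholder distance matrix):

import numpy as np
from sklearn.manifold import MDS

# placeholder pairwise distances between the 3 categories of one variable,
# e.g. obtained via ABDM or MVDM (values are made up for illustration)
d = np.array([[0.0, 0.4, 1.0],
              [0.4, 0.0, 0.7],
              [1.0, 0.7, 0.0]])

# project the categories onto a 1D Euclidean scale via multidimensional scaling
mds = MDS(n_components=1, dissimilarity='precomputed', random_state=0)
numerical_values = mds.fit_transform(d)  # one numerical value per category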

The different use cases are highlighted in the example notebooks linked at the bottom of the page.

Usage

Initialization

The counterfactuals guided by prototypes method works on fully black-box models. This means that it can work with arbitrary functions that take arrays and return arrays. However, if the user has access to a full TensorFlow (TF) or Keras model, this can be passed in as well to take advantage of the automatic differentiation in TF to speed up the search. This section describes the initialization for a TF/Keras model. Please see the numerical gradients section for black box models.

We first load our MNIST classifier and the (optional) autoencoder and encoder:

from tensorflow.keras.models import load_model

cnn = load_model('mnist_cnn.h5')
ae = load_model('mnist_ae.h5')
enc = load_model('mnist_enc.h5')

We can now initialize the counterfactual:

from alibi.explainers import CounterfactualProto

shape = (1,) + x_train.shape[1:]
cf = CounterfactualProto(cnn, shape, kappa=0., beta=.1, gamma=100., theta=100.,
                         ae_model=ae, enc_model=enc, max_iterations=500,
                         feature_range=(-.5, .5), c_init=1., c_steps=5,
                         learning_rate_init=1e-2, clip=(-1000., 1000.), write_dir='./cf')

Besides passing the predictive model and the (optional) autoencoder and encoder models, we set a number of hyperparameters

general:


… related to the optimizer:

… related to the objective function:

When the dataset contains categorical variables, we need to additionally pass the following arguments:

It is also important to remember that the perturbations are applied in the numerical feature space, after the categorical variables have been transformed into numerical features. This has to be reflected by the dimension and values of feature_range. Imagine for example that we have a dataset with 10 columns. Two of the features are categorical and one-hot encoded, and each can take 3 values. When we transform those categorical features into numerical features, the number of columns in the dataset is reduced to 6, so feature_range needs to contain the lower and upper ranges for 6 features.
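A minimal sketch for that hypothetical 10-column dataset (4 numerical columns plus 2 categorical variables, each mapped to a single numerical column), with placeholder bounds of -1 and 1:

import numpy as np

# 4 numerical features + 2 transformed categorical features = 6 columns;
# lower and upper bounds are arrays of shape (1, number of features)
feature_range = (-1. * np.ones((1, 6)),  # lower bounds (placeholder values)
                 np.ones((1, 6)))        # upper bounds (placeholder values)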

While the default values for the loss term coefficients worked well for the simple examples provided in the notebooks, it is recommended to test their robustness for your own applications.

Warning

Once a CounterfactualProto instance is initialized, its parameters are frozen, even when creating a new instance. This is due to TensorFlow behaviour which holds on to some global state. In order to change the parameters of the explainer in the same session (e.g. for explaining different models), you will need to reset the TensorFlow graph manually:

import tensorflow as tf
tf.keras.backend.clear_session()

You may need to reload your model after this. Then you can create a new CounterfactualProto instance with new parameters.

Fit

If we use an encoder to find the class prototypes, we need an additional fit step on the training data:
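cf.fit(x_train)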

We also need the fit step if the data contains categorical features so the numerical transformations can be computed. In practice, most of these optional arguments do not need to be changed:

cf.fit(x_train, d_type='abdm', w=None, disc_perc=[25, 50, 75], standardize_cat_vars=False, smooth=1., center=True, update_feature_range=True)

Explanation

We can now explain the instance:

explanation = cf.explain(X, Y=None, target_class=None, k=20, k_type='mean', threshold=0., verbose=True, print_every=100, log_every=100)

The explain method returns an Explanation object with the following attributes:
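As an illustration, assuming the attribute names used in the Alibi example notebooks, the counterfactual and the original prediction can be accessed like this:

# attribute names assumed from the Alibi example notebooks
cf_instance = explanation.cf['X']      # the counterfactual instance
cf_class = explanation.cf['class']     # class predicted for the counterfactual
orig_class = explanation.orig_class    # class predicted for the original instance
orig_proba = explanation.orig_proba    # prediction probabilities for the original instance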

Numerical Gradients

So far, the whole optimization problem could be defined within the TF graph, making automatic differentiation possible. It is however possible that we do not have access to the model architecture and weights, and are only provided with a predict function returning probabilities for each class. The counterfactual can then be initialized in the same way:

# define model
cnn = load_model('mnist_cnn.h5')
predict_fn = lambda x: cnn.predict(x)
ae = load_model('mnist_ae.h5')
enc = load_model('mnist_enc.h5')

# initialize explainer
shape = (1,) + x_train.shape[1:]
cf = CounterfactualProto(predict_fn, shape, gamma=100., theta=100.,
                         ae_model=ae, enc_model=enc, max_iterations=500,
                         feature_range=(-.5, .5), c_init=1., c_steps=4,
                         eps=(1e-2, 1e-2), update_num_grad=100)

In this case, we need to evaluate the gradients of the loss function with respect to the input features numerically:

\[\frac{\partial L_{pred}}{\partial x} = \frac{\partial L_{pred}}{\partial p} \frac{\partial p}{\partial x}\]

where \(L_{pred}\) is the loss term related to the prediction function, \(p\) is the prediction function and \(x\) are the input features to optimize. There are now 2 additional hyperparameters to consider:

update_num_grad: the frequency, in number of iterations, at which the numerical gradients are updated. Less frequent updates reduce the number of prediction calls and speed up the search, at the cost of less accurate gradients.

eps: a tuple defining the perturbation sizes used for the numerical gradients. The first element is used for the numerical gradient of the loss with respect to the prediction function output, the second for the numerical gradient of the prediction function with respect to the input features. Each element can be a float or an array with one value per prediction category (respectively per input feature):

import numpy as np

eps0 = np.array([[1e-2, 1e-2, 1e-2]])  # 3 prediction categories, equivalent to 1e-2
eps1 = np.array([[1e-2, 1e-2, 1e-2, 1e-2]])  # 4 features, also equivalent to 1e-2
eps = (eps0, eps1)

We can also remove the prediction loss term by setting c_init to 0 and running only 1 c_steps iteration, and still obtain an interpretable counterfactual. This dramatically speeds up the counterfactual search (e.g. by 100x in the MNIST example notebook):

cf = CounterfactualProto(predict_fn, shape, gamma=100., theta=100.,
                         ae_model=ae, enc_model=enc, max_iterations=500,
                         feature_range=(-.5, .5), c_init=0., c_steps=1)

k-d trees

So far, we assumed that we have a trained encoder available to find the nearest class prototype. This is, however, not a hard requirement. As mentioned in the Overview section, we can use k-d trees to build the class representations, find prototypes by querying the trees for each class, and use the k-th nearest instance of a class as the class prototype. We can run the counterfactual search as follows:

cf = CounterfactualProto(cnn, shape, use_kdtree=True, theta=10., feature_range=(-.5, .5))
cf.fit(x_train, trustscore_kwargs=None)
explanation = cf.explain(X, k=2)

Examples

Counterfactuals guided by prototypes on MNIST

Counterfactuals guided by prototypes on California housing dataset

Counterfactual explanations with one-hot encoded categorical variables

Counterfactual explanations with ordinally encoded categorical variables