ProtoSelect — Alibi 0.9.7.dev0 documentation

Overview

Bien and Tibshirani (2012) proposed ProtoSelect, a prototype selection method with the goal of constructing not only a condensed view of a dataset but also an interpretable model (applicable to classification only). Prototypes can be defined as instances that are representative of the entire training data distribution. Formally, consider a dataset of training points \(\mathcal{X} = \{x_1, ..., x_n \} \subset \mathbf{R}^p\) and their corresponding labels \(\mathcal{Y} = \{y_1, ..., y_n\}\), where \(y_i \in \{1, 2, ..., L\}\). ProtoSelect finds a set \(\mathcal{P}_{l} \subseteq \mathcal{X}\) for each class \(l\) such that the union of \(\mathcal{P}_1, \mathcal{P}_2, ..., \mathcal{P}_L\) provides a distilled view of the training dataset \((\mathcal{X}, \mathcal{Y})\).

Given the sets of prototypes, one can construct a simple interpretable classifier defined by:

\[\hat{c}(x) = \underset{l}{\operatorname{argmin}} \min_{z \in \mathcal{P}_l} d(x, z)\]

Note that the classifier defined in the equation above is equivalent to a 1-KNN classifier if each set \(\mathcal{P}_l\) consists only of instances belonging to class \(l\).
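
For illustration, a minimal NumPy sketch of this nearest-prototype classifier could look as follows (the prototypes dictionary and the predict function are illustrative names, not part of the Alibi API):

import numpy as np

def predict(x, prototypes):
    # `prototypes` maps each class label l to an array P_l of shape (m_l, p);
    # compute, per class, the distance from x to its closest prototype ...
    dists = {l: np.min(np.linalg.norm(P - x, axis=1)) for l, P in prototypes.items()}
    # ... and return the class whose prototype set is closest to x
    return min(dists, key=dists.get)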

ProtoSelect method

ProtoSelect is designed such that each prototype satisfies a set of desired properties. For a set \(\mathcal{P}_l \subseteq \mathcal{X}\), the neighborhood of a point \(x_i \in \mathcal{P}_l\) is given by the points contained in an \(\epsilon\)-ball centered at \(x_i\), denoted as \(B(x_i, \epsilon)\). Thus, given a radius \(\epsilon\) and a point \(x_i\), we say that another point \(x_j\) is covered by \(x_i\) if \(x_j\) is contained in the \(\epsilon\)-ball centered at \(x_i\). A visualization of the prototype sets for various \(\epsilon\) radius values is depicted in the following figure:

Figure: prototype sets for various \(\epsilon\) radii (Bien and Tibshirani, Prototype Selection for Interpretable Classification, 2012).

A desirable prototype set for a class \(l\) would satisfy the following properties:

- it covers as many training points of class \(l\) as possible with \(\epsilon\)-balls centered at the selected prototypes;
- it covers as few training points as possible belonging to classes other than \(l\);
- it is sparse, i.e. it contains as few prototypes as possible.

Formally, let us first define \(\alpha_{j}^{(l)} \in \{0, 1\}\) to indicate whether we select \(x_j\) to be in \(\mathcal{P}_l\). Then we can write the three properties as an integer program as follows:

\begin{equation} \min_{\alpha_{j}^{(l)}, \xi_{i}, \nu_{i}} \sum_{i}{\xi_i} + \sum_{i}{\nu_i} + \lambda \sum_{j, l}\alpha_{j}^{(l)} \text{ such that} \end{equation}

\begin{equation} \sum_{j: x_i \in B(x_j, \epsilon)} \alpha_{j}^{(y_i)} \ge 1 - \xi_i, \forall x_i \in \mathcal{X}, \tag{a} \end{equation}

\begin{equation} \sum_{j: x_i \in B(x_j, \epsilon), l \neq y_i} \alpha_{j}^{(l)} \le 0 + \nu_i, \forall x_i \in \mathcal{X}, \tag{b} \end{equation}

\begin{equation} \alpha_{j}^{(l)} \in \{0, 1\}, \forall j, l, \end{equation}

\begin{equation} \xi_i, \nu_i \ge 0. \end{equation}

For each training point \(x_i\), we introduce the slack variables \(\xi_i\) and \(\nu_i\). Before explaining the two constraints, note that \(\sum_{j: x_i \in B(x_j, \epsilon)} \alpha_{j}^{(l)}\) counts the number of balls \(B(x_j, \epsilon)\) with \(x_j \in \mathcal{P}_l\) that cover the point \(x_i\). Constraint (a) encourages each training point \((x_i, y_i)\) to be covered by at least one \(\epsilon\)-ball of a prototype for the class \(y_i\). Constraint (b), on the other hand, encourages \(x_i\) not to belong to any \(\epsilon\)-ball centered at a prototype for the other classes \(l \ne y_i\).

Because the integer program defined above cannot be solved in polynomial time, the authors propose two alternative solutions. The first consists of a relaxation of the objective and a transformation of the integer program into a linear program, for which post-processing is required to ensure feasibility of the solution. We refer the reader to the paper for more details. The second, recommended and implemented in Alibi, follows a greedy approach. Given the current choice of prototype subsets \((\mathcal{P}_1, ..., \mathcal{P}_L)\), in the next iteration we update it to \((\mathcal{P}_1, ..., \mathcal{P}_{l} \cup \{x_j\}, ..., \mathcal{P}_L)\), where \(x_j\) is selected such that it maximizes the objective gain \(\Delta Obj(x_j, l) = \Delta \xi(x_j,l) - \Delta\nu(x_j, l) - \lambda\), where:

\begin{equation} \Delta \xi(x_j, l) = \lvert \mathcal{X}_l \cap (B(x_j, \epsilon) \setminus \cup_{x_{j^\prime} \in \mathcal{P}_l}B(x_{j^\prime}, \epsilon)) \rvert \tag{a} \end{equation}

\begin{equation} \Delta \nu(x_j, l) = \lvert B(x_j, \epsilon) \cap (\mathcal{X} \setminus \mathcal{X}_l) \rvert \tag{b}. \end{equation}

Note that \(\Delta \xi(x_j, l)\) counts the number of new instances (i.e. not already covered by the existing prototypes) belonging to class \(l\) that the \(\epsilon\)-ball of \(x_j\) covers. On the other hand, \(\Delta \nu(x_j, l)\) counts how many instances belonging to classes other than \(l\) are covered by \(x_j\). Finally, \(\lambda\) is the penalty/cost of adding a new prototype, encouraging sparsity (a lower number of prototypes). Intuitively, a good prototype for a class \(l\) covers as many new instances of class \(l\) as possible (i.e. maximizes \(\Delta \xi(x_j, l)\)) while avoiding covering elements outside class \(l\) (i.e. minimizes \(\Delta \nu(x_j, l)\)). The prototype selection algorithm stops when all \(\Delta Obj(x_j, l)\) values drop below 0.
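
As a rough sketch of a single greedy iteration (illustrative only, not the Alibi implementation), assume dist is a precomputed matrix of pairwise distances between candidate prototypes and training points, and covered is a dictionary mapping each class label \(l\) to a boolean mask of training points already covered by an \(\epsilon\)-ball of a prototype in \(\mathcal{P}_l\):

import numpy as np

def greedy_step(dist, y, covered, eps, lam):
    """Return the (candidate index, class) pair maximizing Delta Obj, or None to stop."""
    best, best_gain = None, 0.
    for j in range(dist.shape[0]):           # candidate prototype x_j
        in_ball = dist[j] <= eps             # training points inside B(x_j, eps)
        for l in np.unique(y):
            # Delta xi: newly covered points of class l (not yet covered by P_l)
            delta_xi = np.sum(in_ball & (y == l) & ~covered[l])
            # Delta nu: covered points belonging to classes other than l
            delta_nu = np.sum(in_ball & (y != l))
            gain = delta_xi - delta_nu - lam
            if gain > best_gain:
                best, best_gain = (j, l), gain
    return best                              # None means all gains are non-positive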

Usage

from alibi.prototypes import ProtoSelect
from alibi.utils.kernel import EuclideanDistance

summariser = ProtoSelect(kernel_distance=EuclideanDistance(), eps=eps, preprocess_fn=preprocess_fn)

Following the initialization, we need to fit the summariser.

summariser = summariser.fit(X=X_train, y=y_train)

In a more general case, we can specify an optional dataset \(Z\) to choose the prototypes from (see the documentation of the fit method). In this scenario, the dataset to be summarised is still \(X\), but it is summarised by prototypes belonging to the dataset \(Z\). Furthermore, note that we only need to specify the labels for the \(X\) set through \(y\), but not for \(Z\). In case the labels \(y\) are missing, the method implicitly assumes that all the instances belong to the same class. This means that the second term in the objective, \(\Delta\nu(x_j, l)\), will be 0. Thus, the algorithm will try to find prototypes that cover as many data instances as possible, with minimum overlap between their corresponding \(\epsilon\)-balls.
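
For instance, fitting on a separate candidate set could look like this (a sketch; Z is an unlabeled array of candidate prototypes, and the exact keyword names should be checked against the fit documentation):

summariser = summariser.fit(X=X_train, y=y_train, Z=Z)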

Finally, we can obtain a summary by requesting the maximum number of prototypes to be returned:

summary = summariser.summarise(num_prototypes=num_prototypes)

As previously mentioned, the algorithm stops when the objective gain is negative for all the remaining instances in the set of potential prototypes. This means that the algorithm can return fewer prototypes than requested.

Another important observation is that the summary returns the prototypes with their corresponding labels although no labels were provided for \(Z\). This is possible since each prototype \(z\) will belong to a prototype set \(\mathcal{P}_l\), and thus we can assign a label \(l\) to \(z\). Following the summarisation step, one can train an interpretable 1-KNN classifier on the returned prototypes even for an unlabeled dataset \(Z\).
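
For example, assuming the summary exposes the prototypes and their inferred labels under summary.data['prototypes'] and summary.data['prototype_labels'] (key names worth double-checking against the documentation), and given a hypothetical test set X_test, a sketch of such a classifier using scikit-learn could be:

from sklearn.neighbors import KNeighborsClassifier

# 1-KNN classifier fitted on the selected prototypes, in the same feature
# space used by the kernel distance (hence the preprocess_fn calls)
knn = KNeighborsClassifier(n_neighbors=1, metric='euclidean')
knn = knn.fit(preprocess_fn(summary.data['prototypes']),
              summary.data['prototype_labels'])
preds = knn.predict(preprocess_fn(X_test))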

Warning

If the optional argument \(Z\) is not provided, it is automatically set to \(X\). Although the labels of the data instances belonging to \(Z\) are available in this case, the dataset \(Z\) is still viewed as an unlabeled dataset. This means that a prototype \(z_i \in Z\) belonging to the class \(l\) according to the labels \(y\), can be a prototype for a class \(k \ne l\).

Hyperparameter selection

Alibi exposes a cross-validation hyperparameter selection method for the radius \(\epsilon\) when the Euclidean distance is used. The method returns the \(\epsilon\) radius value that achieves the best accuracy score on a 1-KNN classification task.

from alibi.prototypes.protoselect import cv_protoselect_euclidean

cv = cv_protoselect_euclidean(trainset=(X_train, y_train),
                              valset=(X_val, y_val),
                              num_prototypes=num_prototypes,
                              quantiles=(0., 0.4),
                              preprocess_fn=preprocess_fn)

The method API is flexible and allows for various arguments to be passed such as a predefined \(\epsilon\)-grid, the number of equidistant bins, keyword arguments to the KFold split when the validation set is not provided, etc. We refer the reader to the documentation page for a full parameter description.

The best \(\epsilon\)-radius can be accessed through cv['best_eps']. The returned object also contains other metadata gathered throughout the hyperparameter search.
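
The selected radius can then be plugged back into the summariser constructor, e.g.:

summariser = ProtoSelect(kernel_distance=EuclideanDistance(), eps=cv['best_eps'], preprocess_fn=preprocess_fn)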

Data modalities

The method can be applied to any data modality by passing a preprocess_fn: Callable[[Union[list, np.ndarray]], np.ndarray], which is expected to return a numpy array feature representation compatible with the kernel provided.
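
As a minimal illustration for the image modality, a preprocess_fn could simply scale and flatten the images so that the Euclidean kernel distance operates on 1D feature vectors (in practice a pretrained feature extractor is usually a better choice):

import numpy as np

def preprocess_fn(X: np.ndarray) -> np.ndarray:
    # scale pixel values to [0, 1] and flatten each image into a feature vector
    X = np.asarray(X, dtype=np.float32) / 255.
    return X.reshape(len(X), -1)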

Prototypes visualization for image modality

As proposed by Bien and Tibshirani (2012), one can visualize and understand the importance of a prototype in a 2D image scatter plot. To obtain the image size of each prototype, we fit a 1-KNN classifier on the prototypes using the feature representation provided by preprocess_fn and the Euclidean distance metric, which is consistent with our choice of kernel dissimilarity. The size of each prototype is proportional to the logarithm of the number of assigned training instances correctly classified according to the 1-KNN classifier. Thus, the larger the image, the more important the prototype is.
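
A rough sketch of how such sizes could be computed (illustrative only; visualize_image_prototypes performs this internally, and the feats_protos, labels_protos and feats_train names below are assumptions denoting the preprocessed prototypes, their labels and the preprocessed training instances):

import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# assign each training instance to its closest prototype in feature space
knn = KNeighborsClassifier(n_neighbors=1, metric='euclidean')
knn = knn.fit(feats_protos, labels_protos)
assigned = knn.kneighbors(feats_train, return_distance=False).ravel()
correct = knn.predict(feats_train) == y_train

# the display size of prototype i grows with the log of the number of
# correctly classified instances assigned to it (+1 avoids log(0))
sizes = np.array([np.log(1 + np.sum(correct & (assigned == i)))
                  for i in range(len(feats_protos))])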

Figure: prototypes of a subsampled ImageNet dataset containing 10 classes, using a ResNet50 pretrained feature extractor.

import umap
from alibi.prototypes import visualize_image_prototypes

# define 2D reducer

reducer = umap.UMAP(random_state=26)
reducer = reducer.fit(preprocess_fn(X_train))

# display prototypes in 2D

visualize_image_prototypes(summary=summary,
                           trainset=(X_train, y_train),
                           reducer=reducer.transform,
                           preprocess_fn=preprocess_fn)

Here we used a UMAP 2D reducer, but any other dimensionality reduction method will do. The visualize_image_prototypes method exposes other arguments to control how the images will be displayed. We refer the reader to the documentation page for further details.

Examples

Tabular and image datasets