LIME Tabular Explainer via XAI — Contextual AI documentation
This tutorial demonstrates how to generate explanations using LIME’s tabular explainer implemented by the Contextual AI library.
At a high level, explanations can be obtained from any Contextual AI explanation algorithm in 3 steps:

- Create an explainer via the `ExplainerFactory` class, which serves as the primary interface between the user and all Contextual AI-supported explanation algorithms
- Build the explainer by calling the `build_explainer` method (which is implemented by every Contextual AI explanation algorithm), providing arguments that are specific to that algorithm
- Get explanations for some data instance by calling the `explain_instance` method (which is also common among all algorithms), providing arguments that are specific to that algorithm
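Concretely, the whole workflow looks like the sketch below, assuming a trained scikit-learn classifier `clf` and data splits `X_train`/`X_test` (both are prepared step by step in the rest of this tutorial):

```python
import xai
from xai.explainer import ExplainerFactory

# 1. Create an explainer for tabular data via the factory
explainer = ExplainerFactory.get_explainer(domain=xai.DOMAIN.TABULAR)

# 2. Build it with algorithm-specific arguments
explainer.build_explainer(
    training_data=X_train,
    mode=xai.MODE.CLASSIFICATION,
    predict_fn=clf.predict_proba
)

# 3. Explain a single data instance
exp = explainer.explain_instance(instance=X_test[0], num_features=5)
```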
Step 1: Import libraries

`xai.explainer.ExplainerFactory` is the main class that users of Contextual AI interact with. The `xai` module contains some constants that are used to instantiate an `AbstractExplainer` object.
```python
# Some auxiliary imports for the tutorial
import sys
import random
import numpy as np
from pprint import pprint
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

# Set seed for reproducibility
np.random.seed(123456)

# Set the path so that we can import ExplainerFactory
sys.path.append('../../')

# Main Contextual AI imports
import xai
from xai.explainer import ExplainerFactory
```
Step 2: Train a model on a sample dataset

We train a sample `RandomForestClassifier` model on the Wisconsin breast cancer dataset, a binary classification problem that is provided by scikit-learn (details can be found in the scikit-learn documentation).
```python
# Load the dataset and prepare training and test sets
raw_data = datasets.load_breast_cancer()
X, y = raw_data['data'], raw_data['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Instantiate a classifier, train, and evaluate on test set
clf = RandomForestClassifier()
clf.fit(X_train, y_train)
clf.score(X_test, y_test)
```
Step 3: Instantiate the explainer

This is where we instantiate the Contextual AI explainer. The `ExplainerFactory` class is in charge of loading a particular explanation algorithm. The user is required to provide one argument, `domain`, which indicates the domain of the training data (e.g. `tabular` or `text`). The available domains can be found in `xai.DOMAIN`. Users can also select a particular explanation algorithm by providing the algorithm's name (registered in `xai.ALG`) to the `algorithm` parameter. If this argument is not provided, the `ExplainerFactory.get_explainer` method defaults to a pre-set algorithm for that domain, which can be found in `xai/explainer/config.py`.
We want to load the `LimeTabularExplainer`, so we provide `xai.DOMAIN.TABULAR` as the argument to `domain` and `xai.ALG.LIME` as the argument to `algorithm`. Note that `xai.ALG.LIME` is the default tabular explanation algorithm; hence this also works:

```python
explainer = ExplainerFactory.get_explainer(domain=xai.DOMAIN.TABULAR)
```
```python
# Instantiate LimeTabularExplainer via the Explainer interface
explainer = ExplainerFactory.get_explainer(domain=xai.DOMAIN.TABULAR, algorithm=xai.ALG.LIME)
```
Step 4: Build the explainer

`build_explainer` calls the explanation algorithm's initialization routine, which can include things like setting parameters or a pre-training loop. The `LimeTabularExplainer` requires the following parameters:
- training_data (np.ndarray): 2D NumPy array representing the training data (or some representative subset) (required)
- mode (str): Whether the problem is 'classification' or 'regression' (required)
- predict_fn (function): A function that wraps the target model's prediction function - it takes in a 1D NumPy array and outputs a vector of probabilities which should sum to 1 (required; see the sketch after the parameter list below)
Here are some other optional parameters:

- training_labels (list): Training labels, which can be used by the continuous feature discretizer
- feature_names (list): The names of the columns of the training data
- categorical_features (list): Integer list indicating the indices of categorical features
- dict_categorical_mapping (dict): Mapping of the integer index of a categorical feature (same as in categorical_features) to a list of values for that column, so that dict_categorical_mapping[x][y] is the yth value of column x
- kernel_width (float): Width of the exponential kernel used in the LIME loss function
- verbose (bool): Control verbosity. If True, local prediction values of the LIME model are printed
- class_names (list): Class names (positional index corresponding to class index)
- feature_selection (str): Feature selection method. See the original LIME docs for more details
- discretize_continuous (bool): Whether to discretize non-categorical features
- discretizer (str): Type of discretization. See the original LIME docs for more details
- sample_around_instance (bool): If True, continuous features in perturbed samples are sampled from a normal distribution centered on the instance being explained; otherwise, the normal is centered on the mean of the feature data
- random_state (int): The random seed used to generate random numbers during training
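The only constraint on `predict_fn` is that it returns well-formed probabilities. If the target model does not expose probabilities directly (for example, an SVM trained without probability estimates), any wrapper that converts its outputs into rows summing to 1 will do. A minimal, hypothetical sketch (the softmax-of-margins trick here is only an illustration, not calibrated probabilities):

```python
import numpy as np
from sklearn.svm import SVC

svm = SVC()  # exposes decision_function but, by default, no predict_proba
svm.fit(X_train, y_train)

def svm_predict_fn(instances):
    # Turn decision-function margins into pseudo-probabilities
    # with a softmax so that each row sums to 1
    margins = svm.decision_function(instances)
    if margins.ndim == 1:  # binary case: one margin per instance
        margins = np.stack([-margins, margins], axis=1)
    exp_scores = np.exp(margins - margins.max(axis=1, keepdims=True))
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)
```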
In this particular example, we pass the `RandomForestClassifier`'s `predict_proba` function to `predict_fn` and get explanations for the two classes.

```python
explainer.build_explainer(
    training_data=X_train,
    training_labels=y_train,
    mode=xai.MODE.CLASSIFICATION,
    predict_fn=clf.predict_proba,
    feature_names=raw_data['feature_names'],
    class_names=list(raw_data['target_names'])
)
```
Step 5: Generate some explanations

Once we build the explainer, we can start generating explanations via the `explain_instance` method. The `LimeTabularExplainer` expects the following required argument:

- instance (np.ndarray): A 1D NumPy array corresponding to a row/single example (required)
You can also pass the following:
- labels (list): The list of class indexes to produce explanations for
- top_labels (int): If not None, this overrides labels and the explainer instead produces explanations for the top k classes
- num_features (int): Number of features to include in an explanation
- num_samples (int): The number of perturbed samples to train the LIME model with
- distance_metric (str): The distance metric to use for weighting the loss function
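For instance (a variation not in the original tutorial), one could request an explanation for a single class only by passing labels explicitly instead of top_labels, and increase num_samples for a more stable local model:

```python
# Hypothetical variation: explain only class 1 ('benign'),
# training the local LIME model on more perturbed samples
exp_class1 = explainer.explain_instance(
    instance=X_test[0],
    labels=[1],
    num_features=5,
    num_samples=10000
)
```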
We restrict explanations to 5 features (meaning only 5 features will have scores attached to them). The output of `explain_instance` is a dictionary that maps each class to two things: the confidence of the model and a list of explanations.
```python
exp = explainer.explain_instance(
    instance=X_test[0],
    top_labels=2,
    num_features=5
)

pprint(exp)
```
```
{0: {'explanation': [{'feature': 'worst perimeter <= 83.79',
                      'score': -0.10193695487658752},
                     {'feature': 'worst area <= 509.25',
                      'score': -0.09601666088375639},
                     {'feature': 'worst radius <= 12.93',
                      'score': -0.06025582708518221},
                     {'feature': 'mean area <= 419.25',
                      'score': -0.056302517885391166},
                     {'feature': 'worst texture <= 21.41',
                      'score': -0.05509499962470648}],
     'prediction': 0.0},
 1: {'explanation': [{'feature': 'worst perimeter <= 83.79',
                      'score': 0.10193695487658752},
                     {'feature': 'worst area <= 509.25',
                      'score': 0.0960166608837564},
                     {'feature': 'worst radius <= 12.93',
                      'score': 0.06025582708518222},
                     {'feature': 'mean area <= 419.25',
                      'score': 0.05630251788539119},
                     {'feature': 'worst texture <= 21.41',
                      'score': 0.05509499962470641}],
     'prediction': 1.0}}
```
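Since the result is a plain dictionary, it can be post-processed like any other Python object. For example, a small sketch (not part of the original tutorial) that prints the top features for each class:

```python
# Walk the explanation dictionary: one entry per explained class index
for class_idx, result in exp.items():
    print(f"Class {class_idx} (confidence: {result['prediction']}):")
    for item in result['explanation']:
        print(f"  {item['feature']}: {item['score']:+.4f}")
```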
Step 6: Save and load the explainer
Finally, every Contextual AI explainer supports saving and loading functions.
```python
# Save the explainer somewhere
explainer.save_explainer('artefacts/lime_tabular.pkl')
```
```python
# Load the saved explainer in a new Explainer instance
new_explainer = ExplainerFactory.get_explainer(domain=xai.DOMAIN.TABULAR, algorithm=xai.ALG.LIME)
new_explainer.load_explainer('artefacts/lime_tabular.pkl')
```
```python
exp = new_explainer.explain_instance(
    instance=X_test[0],
    top_labels=2,
    num_features=5
)

pprint(exp)
```
```
{0: {'explanation': [{'feature': 'worst perimeter <= 83.79',
                      'score': -0.09985606175737251},
                     {'feature': 'worst area <= 509.25',
                      'score': -0.08623375147255567},
                     {'feature': 'mean area <= 419.25',
                      'score': -0.07671371631709668},
                     {'feature': 'worst radius <= 12.93',
                      'score': -0.06861610584095608},
                     {'feature': 'worst texture <= 21.41',
                      'score': -0.05078617133441289}],
     'prediction': 0.0},
 1: {'explanation': [{'feature': 'worst perimeter <= 83.79',
                      'score': 0.09985606175737251},
                     {'feature': 'worst area <= 509.25',
                      'score': 0.08623375147255567},
                     {'feature': 'mean area <= 419.25',
                      'score': 0.0767137163170967},
                     {'feature': 'worst radius <= 12.93',
                      'score': 0.0686161058409561},
                     {'feature': 'worst texture <= 21.41',
                      'score': 0.05078617133441288}],
     'prediction': 1.0}}
```
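Note that the reloaded explainer's scores differ slightly from the earlier run: LIME draws a fresh set of perturbed samples on every call to explain_instance. If reproducible explanations are needed, one option (a sketch using the random_state parameter listed in Step 4) is to fix the seed when building the explainer:

```python
# Rebuild with a fixed seed so that the perturbed samples drawn
# for an explanation are deterministic across runs
explainer = ExplainerFactory.get_explainer(domain=xai.DOMAIN.TABULAR, algorithm=xai.ALG.LIME)
explainer.build_explainer(
    training_data=X_train,
    training_labels=y_train,
    mode=xai.MODE.CLASSIFICATION,
    predict_fn=clf.predict_proba,
    feature_names=raw_data['feature_names'],
    class_names=list(raw_data['target_names']),
    random_state=123456
)
```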