TensorFlow Model Remediation

TensorFlow Model Remediation is a library that provides solutions for machine learning practitioners working to create and train models in a way that reduces or eliminates user harm resulting from underlying performance biases.

Installation

You can install the package from pip:

$ pip install tensorflow-model-remediation

Note: Make sure you are using TensorFlow 2.x.
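One way to confirm the installation and the TensorFlow 2.x requirement is to query the installed package metadata. The sketch below uses only the Python standard library; the helper names `installed_version` and `is_tensorflow_2` are illustrative and not part of this library. It also assumes TensorFlow was installed under the distribution name `tensorflow`, so a result of `None` is not conclusive if a variant distribution is in use.

```python
from importlib import metadata


def installed_version(package_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return metadata.version(package_name)
    except metadata.PackageNotFoundError:
        return None


def is_tensorflow_2(version_string):
    """Return True when a version string such as '2.15.0' has major version 2."""
    return version_string.split(".")[0] == "2"


if __name__ == "__main__":
    # Report what is installed; None means the distribution was not found.
    for name in ("tensorflow", "tensorflow-model-remediation"):
        print(name, installed_version(name))

    tf_version = installed_version("tensorflow")
    if tf_version is not None:
        print("TensorFlow 2.x:", is_tensorflow_2(tf_version))
```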

Documentation

This library contains a collection of machine learning remediation techniques for addressing potential bias in a model.

Currently, TensorFlow Model Remediation contains the MinDiff and Counterfactual techniques.

We recommend starting with the overview guide to get an idea of TensorFlow Model Remediation. Next, try one of our interactive guides, such as the MinDiff tutorial notebook or the Counterfactual tutorial notebook.

import tensorflow_model_remediation as tfmr
import tensorflow as tf

# Start by defining a Keras model.
original_model = ...

# Next pick the remediation technique you'd like to use. For example, a
# MinDiff implementation might look like the below:

# Set the MinDiff weight and choose a loss.
min_diff_loss = tfmr.min_diff.losses.MMDLoss()
min_diff_weight = 1.0  # Hyperparameter to be tuned.

# Create a MinDiff model.
min_diff_model = tfmr.min_diff.keras.MinDiffModel(
    original_model, min_diff_loss, min_diff_weight)

# Compile the MinDiff model as you normally would do with the original model.
min_diff_model.compile(...)

# Create a MinDiff Dataset and train the min_diff_model on it.
min_diff_model.fit(min_diff_dataset, ...)
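To make the roles of the loss and `min_diff_weight` in the snippet above more concrete, here is a minimal pure-Python sketch of the general idea: a penalty measuring how differently the model scores two groups is scaled by the weight and added to the original training loss. This is an illustration only, not the library's implementation. The function names, the Gaussian kernel, and the `bandwidth` value are assumptions chosen for the example, and the penalty is a simple biased estimate of the squared maximum mean discrepancy (MMD) between two lists of scalar predictions.

```python
import math


def gaussian_kernel(x, y, bandwidth=0.1):
    """Gaussian kernel between two scalar predictions (bandwidth is assumed)."""
    return math.exp(-((x - y) ** 2) / (bandwidth ** 2))


def mmd_squared(preds_a, preds_b, bandwidth=0.1):
    """Biased estimate of squared MMD between two lists of scalar predictions."""

    def mean_kernel(xs, ys):
        # Average kernel value over every pair drawn from xs and ys.
        total = sum(gaussian_kernel(x, y, bandwidth) for x in xs for y in ys)
        return total / (len(xs) * len(ys))

    return (
        mean_kernel(preds_a, preds_a)
        + mean_kernel(preds_b, preds_b)
        - 2.0 * mean_kernel(preds_a, preds_b)
    )


def combined_loss(original_loss, preds_a, preds_b, min_diff_weight=1.0):
    """Original loss plus the weighted distribution-gap penalty."""
    return original_loss + min_diff_weight * mmd_squared(preds_a, preds_b)


if __name__ == "__main__":
    # Hypothetical prediction scores for two groups of examples.
    group_a = [0.9, 0.8, 0.85]
    group_b = [0.2, 0.3, 0.25]
    print("penalty, different groups:", mmd_squared(group_a, group_b))
    print("penalty, identical groups:", mmd_squared(group_a, group_a))
    print("combined loss:", combined_loss(0.4, group_a, group_b, min_diff_weight=1.0))
```

The penalty is zero when the two groups receive identical score distributions and grows as they diverge, which is why the weight needs tuning: a larger value trades some of the original objective for a smaller gap between groups.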

Disclaimers

If you're interested in learning more about responsible AI practices, including fairness, please see Google AI's Responsible AI Practices.

tensorflow/model-remediation is Apache 2.0 licensed. See the LICENSE file.