csinva/transformation-importance: Using / reproducing TRIM from the paper "Transformation Importance with Applications to Cosmology" 🌌 (ICLR Workshop 2020)

Official code for using and reproducing TRIM (TRansformation IMportance) from the paper "Transformation Importance with Applications to Cosmology" (ICLR 2020 Workshop). The code includes examples and wrappers for computing feature importance in a transformed feature space (for example, Fourier frequencies or NMF components) rather than in the raw input space.
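The core idea can be sketched in a few lines: to attribute a prediction to coefficients `s` in a transformed space, prepend the inverse transform to the model and differentiate with respect to `s`. The sketch below is a hypothetical illustration, not the repo's `TrimModel`. `InverseTransformWrapper` is a made-up name, and the orthonormal basis `Q` is just an arbitrary invertible transform chosen for the demo. It checks that the wrapped model gives the same output on `transform(x)` as the original model gives on `x`.

```python
import torch
import torch.nn as nn


class InverseTransformWrapper(nn.Module):
    """Hypothetical stand-in for the TRIM idea: model(inv_transform(s)).

    Gradients of the output are taken with respect to s, the
    representation in the transformed space.
    """

    def __init__(self, model, inv_transform):
        super().__init__()
        self.model = model
        self.inv_transform = inv_transform

    def forward(self, s):
        # map transformed coefficients back to the input space, then predict
        return self.model(self.inv_transform(s))


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))

# assumed example transform: an orthonormal basis Q from a QR decomposition
Q, _ = torch.linalg.qr(torch.randn(8, 8))


def transform(x):
    # coefficients of x in the basis Q
    return x @ Q


def inv_transform(s):
    # Q is orthogonal, so Q^T undoes the transform
    return s @ Q.T


wrapped = InverseTransformWrapper(model, inv_transform)
x = torch.randn(4, 8)
s = transform(x)
same = torch.allclose(wrapped(s), model(x), atol=1e-5)
print(same)  # True
```

Because the wrapper reproduces the original model exactly on transformed inputs, any gradient-based attribution method applied to it yields scores over the coefficients of `s` instead of over raw input features.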

This repo is actively maintained. For any questions, please file an issue.

trim

examples/documentation

- Attribution to different scales in cosmological images
- Fake news attribution to different topics
- Attribution to different NMF components in MNIST classification
- Attribution to different frequencies in audio classification
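The NMF example attributes a prediction to the components of a decomposition x ≈ s H. In the simplest case, a linear model f(x) = w · x + b, the chain rule gives df/ds = H w, so input x gradient assigns component k the score s_k (H w)_k, and these scores sum exactly to f(x) - b. The toy sketch below assumes a random nonnegative `H` as a stand-in for a fitted NMF basis and a hand-written linear model; it is not the repo's MNIST code.

```python
import numpy as np

rng = np.random.default_rng(0)
n_components, n_features = 3, 6

# hypothetical stand-in for a fitted NMF basis: rows of H are nonnegative components
H = rng.random((n_components, n_features))
s = rng.random(n_components)  # nonnegative coefficients for one sample
x = s @ H                     # inverse transform: back to the input space

# a linear model f(x) = w . x + b
w = rng.normal(size=n_features)
b = 0.5
f = w @ x + b

# chain rule: df/ds = H @ w, so input x gradient per component is s * (H @ w)
attribution = s * (H @ w)
print(attribution.shape)                     # (3,)
print(np.isclose(attribution.sum(), f - b))  # True
```

For nonlinear models the scores no longer sum exactly to the output, which is why the sample usage relies on automatic differentiation through the wrapped model rather than a closed form.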

sample usage

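```python
import torch
import torch.nn as nn
from trim import TrimModel

# torch.rfft / torch.irfft were removed in PyTorch 1.8; these helpers use the
# torch.fft module and keep the old real-valued layout (..., n, 2) = (real, imag)
def transform(x):
    # full (two-sided) FFT of a real signal, like torch.rfft(..., onesided=False)
    f = torch.fft.fft(x)
    return torch.stack([f.real, f.imag], dim=-1)

def inv_transform(s):
    # inverse FFT back to a real signal (inverts transform for real inputs)
    return torch.fft.ifft(torch.complex(s[..., 0], s[..., 1])).real

# setup a trim model
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))  # orig model
model_trim = TrimModel(model=model, inv_transform=inv_transform)  # trim model

# get a data point
x = torch.randn(1, 10)
s = transform(x)  # shape (1, 10, 2)

# can now use any attribution method on the trim model
# get (input_x_gradient) attribution in the fft space
s.requires_grad = True
model_trim(s).backward()
input_x_gradient = s.grad * s
```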
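The attribution above has one score for the real part and one for the imaginary part of each frequency. Summing over that last dimension is one simple way to get a single score per frequency; this aggregation is an assumption for illustration, not necessarily what the paper does. As a sanity check, for a purely linear model input x gradient is exact, so the scores add up to the model output minus its bias. To run without the repo installed, the sketch calls `model(inv_transform(s))` directly in place of `TrimModel`, and uses an `nn.Linear` model chosen only for this check.

```python
import torch
import torch.nn as nn


def transform(x):
    # full FFT of a real signal, stored as real-valued (..., n, 2) = (real, imag)
    f = torch.fft.fft(x)
    return torch.stack([f.real, f.imag], dim=-1)


def inv_transform(s):
    # inverse FFT back to a real signal
    return torch.fft.ifft(torch.complex(s[..., 0], s[..., 1])).real


torch.manual_seed(0)
model = nn.Linear(10, 1)  # linear model, so input x gradient is exact
x = torch.randn(1, 10)

s = transform(x).requires_grad_(True)
out = model(inv_transform(s))  # same computation a TRIM-style wrapper performs
out.backward()

input_x_gradient = (s.grad * s).detach()      # shape (1, 10, 2)
per_frequency = input_x_gradient.sum(dim=-1)  # one score per frequency, shape (1, 10)

# completeness check for the linear case: scores sum to output minus bias
total = per_frequency.sum()
print(torch.allclose(total, (out - model.bias).squeeze().detach(), atol=1e-4))  # True
```

Since `x` is real, its spectrum is conjugate-symmetric, so frequencies k and 10 - k carry mirrored information; one might merge those pairs before reporting importances.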

reference

```bibtex
@article{singh2020transformation,
  title={Transformation Importance with Applications to Cosmology},
  author={Singh, Chandan and Ha, Wooseok and Lanusse, Francois and Boehm, Vanessa and Liu, Jia and Yu, Bin},
  journal={arXiv preprint arXiv:2003.01926},
  year={2020},
  url={https://arxiv.org/abs/2003.01926},
}
```