GitHub - atuannguyen/DIRT (original) (raw)
Domain Invariant Representation Learning with Domain Density Transformations
This repository is the official implementation for the NeurIPS 2021 paper Domain Domain Invariant Representation Learning with Domain Density Transformations.
Please consider citing our paper as
@article{nguyen2021domain,
title={Domain invariant representation learning with domain density transformations},
author={Nguyen, A. Tuan and Tran, Toan and Gal, Yarin and Baydin, Atilim Gunes},
journal={Advances in Neural Information Processing Systems},
volume={34},
year={2021}
}
Credits:
Code for StarGan is modified from https://github.com/yunjey/stargan
Code for RotatedMnist DataLoader is modified from https://github.com/AMLab-Amsterdam/DIVA
Code for other Dataset DataLoader is modified from https://github.com/facebookresearch/DomainBed
Requirements:
python3, pytorch 1.7.0 or higher, torchvision 0.8.0 or higher
How to run:
- To run the experiment for Rotated MNIST: For example, target domain 0 and seed 0
cd domain_gen_rotatedmnist
CUDA_VISIBLE_DEVICES=0 python train_stargan.py --target_domain 0 # To run the StarGAN model, although we already provide the checkpoint so you might skip this
CUDA_VISIBLE_DEVICES=0 python -u train.py --model=dirt --seed=0 --epochs=500 --target_domain=0
- To run the experiment for PACS: For example, for PACS with ResNet, target domain 0 and seed 0
# Change the --data_dir flag to your data directory
cd domain_gen
CUDA_VISIBLE_DEVICES=0 python train_stargan.py --dataset PACS --data_dir ../data/ --target_domain 0 # To run the StarGAN model, we provided checkpoint for PACS
CUDA_VISIBLE_DEVICES=0 python -u main.py --dataset PACS --data_dir ../data --model=dirt --base resnet18 --seed=0 --target_domain=0