Lite Transformer
paper | website | slides
@inproceedings{Wu2020LiteTransformer,
title={Lite Transformer with Long-Short Range Attention},
author={Zhanghao Wu* and Zhijian Liu* and Ji Lin and Yujun Lin and Song Han},
booktitle={International Conference on Learning Representations (ICLR)},
year={2020}
}
Overview
How to Use
Prerequisite
- Python version >= 3.6
- PyTorch version >= 1.0.0
- configargparse >= 0.14
- For training new models, you'll also need an NVIDIA GPU and NCCL
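The prerequisites above can be sanity-checked quickly; a minimal sketch (not part of the repository), assuming a standard pip environment:

```bash
# Print the Python and PyTorch versions and check that a CUDA device is visible;
# configargparse should import without error.
python -c "import sys, torch, configargparse; print(sys.version.split()[0], torch.__version__, torch.cuda.is_available())"
```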
Installation
- Codebase
To install fairseq from source and develop locally:
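A minimal sketch of the usual fairseq-style editable install, assuming the standard setup.py at the repository root:

```bash
# Install the codebase in editable mode so local changes take effect immediately
pip install --editable .
```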
- Customized Modules
We also need to build the lightconv and dynamicconv modules for GPU support.
Lightconv_layer
cd fairseq/modules/lightconv_layer
python cuda_function_gen.py
python setup.py install
Dynamicconv_layer
cd fairseq/modules/dynamicconv_layer
python cuda_function_gen.py
python setup.py install
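To confirm that both kernels built correctly, a quick import check can be run from outside the module directories (assuming the compiled extensions keep their upstream fairseq names, lightconv_cuda and dynamicconv_cuda):

```bash
# Both imports should succeed on a machine with a working CUDA setup
python -c "import lightconv_cuda, dynamicconv_cuda; print('CUDA kernels built successfully')"
```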
Data Preparation
IWSLT'14 De-En
We follow the data preparation in fairseq. To download and preprocess the data, one can run
bash configs/iwslt14.de-en/prepare.sh
WMT'14 En-Fr
We follow the data pre-processing in fairseq. To download and preprocess the data, one can run
bash configs/wmt14.en-fr/prepare.sh
WMT'16 En-De
We follow the data pre-processing in fairseq. One should first download the preprocessed data provided by Google (via Google Drive). To binarize the data, one can run
bash configs/wmt16.en-de/prepare.sh [path to the downloaded zip file]
WIKITEXT-103
As the language modeling task requires a fair amount of additional code, we place it in a separate branch: language-model. We follow the data pre-processing in fairseq. To download and preprocess the data, one can run
git checkout language-model
bash configs/wikitext-103/prepare.sh
Testing
For example, to test the models on WMT'14 En-Fr, one can run
configs/wmt14.en-fr/test.sh [path to the model checkpoints] [gpu-id] [test|valid]
For instance, to evaluate Lite Transformer on GPU 0 (reporting the BLEU score on the WMT'14 En-Fr test set), one can run
configs/wmt14.en-fr/test.sh embed496/ 0 test
We provide several pretrained models at the bottom of this page. You can download a model and extract the checkpoint as shown below.
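For example, assuming the checkpoint is packaged as a gzipped tarball:

```bash
# Extract the downloaded archive into the current directory
tar -xzvf [path to the downloaded archive]
```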
Training
We provide several examples for training Lite Transformer with this repo:
To train Lite Transformer on WMT'14 En-Fr (with 8 GPUs), one can run
python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml
To train Lite Transformer with fewer GPUs, e.g. 4 GPUs, one can run
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --update-freq 32
In general, to train a model, one can run
python train.py [path to the data binary] --configs [path to config file] [override options]
Note that --update-freq should be adjusted according to the number of GPUs so that the effective batch size stays constant (16 for 8 GPUs, 32 for 4 GPUs).
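For instance, assuming the same scaling rule (GPUs × update-freq kept constant), training on 2 GPUs would use --update-freq 64:

```bash
# 2 GPUs x update-freq 64 matches the effective batch size of 8 GPUs x update-freq 16
CUDA_VISIBLE_DEVICES=0,1 python train.py data/binary/wmt14_en_fr \
    --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \
    --update-freq 64
```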
Distributed Training (optional)
To train Lite Transformer in a distributed manner, for example on two GPU nodes with 16 GPUs in total:
On host1
python -m torch.distributed.launch \
    --nproc_per_node=8 \
    --nnodes=2 --node_rank=0 \
    --master_addr=host1 --master_port=8080 \
    train.py data/binary/wmt14_en_fr \
    --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \
    --distributed-no-spawn \
    --update-freq 8
On host2
python -m torch.distributed.launch \
    --nproc_per_node=8 \
    --nnodes=2 --node_rank=1 \
    --master_addr=host1 --master_port=8080 \
    train.py data/binary/wmt14_en_fr \
    --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \
    --distributed-no-spawn \
    --update-freq 8
Models
We provide the checkpoints for our Lite Transformer reported in the paper:
| Dataset | #Mult-Adds | Test Score | Model and Test Set |
| --- | --- | --- | --- |
| WMT'14 En-Fr | 90M | 35.3 | download |
| WMT'14 En-Fr | 360M | 39.1 | download |
| WMT'14 En-Fr | 527M | 39.6 | download |
| WMT'16 En-De | 90M | 22.5 | download |
| WMT'16 En-De | 360M | 25.6 | download |
| WMT'16 En-De | 527M | 26.5 | download |
| CNN / DailyMail | 800M | 38.3 (R-L) | download |
| WIKITEXT-103 | 1147M | 22.2 (PPL) | download |