
Lite Transformer

paper | website | slides

@inproceedings{Wu2020LiteTransformer,
  title={Lite Transformer with Long-Short Range Attention},
  author={Zhanghao Wu* and Zhijian Liu* and Ji Lin and Yujun Lin and Song Han},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2020}
}

Overview

(Figure: overview of the Lite Transformer architecture with Long-Short Range Attention)

How to Use

Prerequisite

Installation

  1. Codebase
    To install fairseq from source and develop locally (see the consolidated sketch after this list):
  2. Customized Modules
    We also need to build the lightconv and dynamicconv CUDA kernels for GPU support.
    Lightconv_layer
    cd fairseq/modules/lightconv_layer
    python cuda_function_gen.py
    python setup.py install
    Dynamicconv_layer
    cd fairseq/modules/dynamicconv_layer
    python cuda_function_gen.py
    python setup.py install
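Putting the steps together, a minimal installation sketch is shown below. It assumes the commands are run from the repository root; the editable pip install follows the standard fairseq setup and is an assumption, since the exact command is not spelled out above.

# 1. install the codebase in editable (develop) mode -- assumed standard fairseq-style install
pip install --editable .

# 2. build the customized CUDA kernels for GPU support
cd fairseq/modules/lightconv_layer
python cuda_function_gen.py
python setup.py install

cd ../dynamicconv_layer
python cuda_function_gen.py
python setup.py install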

Data Preparation

IWSLT'14 De-En

We follow the data preparation in fairseq. To download and preprocess the data, one can run

bash configs/iwslt14.de-en/prepare.sh

WMT'14 En-Fr

We follow the data pre-processing in fairseq. To download and preprocess the data, one can run

bash configs/wmt14.en-fr/prepare.sh

WMT'16 En-De

We follow the data pre-processing in fairseq. One should first download the preprocessed data from the Google Drive folder provided by Google. To binarize the data, one can run

bash configs/wmt16.en-de/prepare.sh [path to the downloaded zip file]

WIKITEXT-103

As the language modeling task requires a significant amount of additional code, we place it in a separate branch: language-model. We follow the data pre-processing in fairseq. To download and preprocess the data, one can run

git checkout language-model
bash configs/wikitext-103/prepare.sh

Testing

For example, to test the models on WMT'14 En-Fr, one can run

configs/wmt14.en-fr/test.sh [path to the model checkpoints] [gpu-id] [test|valid]

For instance, to evaluate Lite Transformer on GPU 0 (reporting the BLEU score on the WMT'14 En-Fr test set), one can run

configs/wmt14.en-fr/test.sh embed496/ 0 test

We provide several pretrained models at the bottom of this README. You can download a model and extract the archive as shown below.
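A minimal extraction sketch, assuming the checkpoints are distributed as gzipped tarballs; the bracketed path is a placeholder for whichever archive you downloaded:

tar -xzvf [path to the downloaded archive]   # extracts the model checkpoint and test set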

Training

We provide several examples of training Lite Transformer with this repo:

To train Lite Transformer on WMT'14 En-Fr (with 8 GPUs), one can run

python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml

To train Lite Transformer with fewer GPUs, e.g. 4 GPUs, one can run

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --update-freq 32

In general, to train a model, one can run

python train.py [path to the data binary] --configs [path to config file] [override options]

Note that --update-freq should be adjusted according to the number of GPUs (16 for 8 GPUs, 32 for 4 GPUs), so that the effective batch size (number of GPUs × --update-freq) stays constant.
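For illustration, following that rule, a 2-GPU run would double the accumulation steps again to keep the same effective batch size; the exact invocation below is a sketch extrapolated from the 4-GPU example above, not a configuration given in this README:

CUDA_VISIBLE_DEVICES=0,1 python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --update-freq 64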

Distributed Training (optional)

To train Lite Transformer in a distributed manner, for example on two GPU nodes with 16 GPUs in total:

On host1

python -m torch.distributed.launch \
    --nproc_per_node=8 \
    --nnodes=2 --node_rank=0 \
    --master_addr=host1 --master_port=8080 \
    train.py data/binary/wmt14_en_fr \
    --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \
    --distributed-no-spawn \
    --update-freq 8

On host2

python -m torch.distributed.launch \
    --nproc_per_node=8 \
    --nnodes=2 --node_rank=1 \
    --master_addr=host1 --master_port=8080 \
    train.py data/binary/wmt14_en_fr \
    --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \
    --distributed-no-spawn \
    --update-freq 8

Models

We provide the checkpoints for our Lite Transformer reported in the paper:

| Dataset | #Mult-Adds | Test Score | Model and Test Set |
|---|---|---|---|
| WMT'14 En-Fr | 90M | 35.3 (BLEU) | download |
| WMT'14 En-Fr | 360M | 39.1 (BLEU) | download |
| WMT'14 En-Fr | 527M | 39.6 (BLEU) | download |
| WMT'16 En-De | 90M | 22.5 (BLEU) | download |
| WMT'16 En-De | 360M | 25.6 (BLEU) | download |
| WMT'16 En-De | 527M | 26.5 (BLEU) | download |
| CNN / DailyMail | 800M | 38.3 (R-L) | download |
| WIKITEXT-103 | 1147M | 22.2 (PPL) | download |