salesforce/ALBEF: Code for ALBEF, a new vision-language pre-training method
Align before Fuse: Vision and Language Representation Learning with Momentum Distillation, NeurIPS 2021 Spotlight (Salesforce Research).
Announcement: ALBEF is now officially integrated into LAVIS - a one-stop library for language-and-vision research and applications!
This is the official PyTorch implementation of the ALBEF paper [Blog]. This repository supports pre-training on custom datasets, as well as finetuning on VQA, SNLI-VE, NLVR2, Image-Text Retrieval on MSCOCO and Flickr30k, and visual grounding on RefCOCO+. Pre-trained and finetuned checkpoints are released.
Requirements:
- pytorch 1.8.0
- transformers 4.8.1
- timm 0.4.9
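One way to install matching versions with pip (a sketch, assuming a working Python environment; pick the torch build that matches your CUDA setup):
pip install torch==1.8.0 transformers==4.8.1 timm==0.4.9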
Download:
- Pre-trained checkpoint [14M] / [4M]
- Dataset json files for downstream tasks
- Dataset json files for pre-training (the image paths in each json file need to be changed to your own directory)
- Finetuned checkpoint for retrieval on MSCOCO
- Finetuned checkpoint for retrieval on Flickr30k
- Finetuned checkpoint for VQA
- Finetuned checkpoint for visual grounding on RefCOCO+
Visualization:
We provide code in visualize.ipynb to visualize the important areas in an image for each word in a text. Here is an example visualization using the visual grounding checkpoint.
Pre-training on custom datasets:
- Prepare training json files where each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image}; see the example sketch after the command below.
- In configs/Pretrain.yaml, set the paths for the json files.
- Pre-train the model using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env Pretrain.py --config ./configs/Pretrain.yaml --output_dir output/Pretrain
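As a concrete example of the json format described above, here is a minimal sketch that writes such a training file (the file name, image paths, and captions are placeholders, not files shipped with this repo):

```python
import json

# Each entry pairs an image path with its caption, following the
# {'image': path_of_image, 'caption': text_of_image} format described above.
train_data = [
    {'image': '/path/to/images/0001.jpg', 'caption': 'a dog running on the beach'},
    {'image': '/path/to/images/0002.jpg', 'caption': 'two people riding bicycles'},
]

# Write the list to a json file, then point configs/Pretrain.yaml at it.
with open('pretrain_data.json', 'w') as f:
    json.dump(train_data, f)
```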
Image-Text Retrieval:
- Download MSCOCO or Flickr30k datasets from the original websites.
- Download and extract the provided dataset json files.
- In configs/Retrieval_coco.yaml or configs/Retrieval_flickr.yaml, set the paths for the json files and the image path.
- Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env Retrieval.py \
--config ./configs/Retrieval_flickr.yaml \
--output_dir output/Retrieval_flickr \
--checkpoint [Pretrained checkpoint]
VQA:
- Download VQA v2 dataset and Visual Genome dataset from the original websites.
- Download and extract the provided dataset json files.
- In configs/VQA.yaml, set the paths for the json files and the image paths.
- Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env VQA.py \
--config ./configs/VQA.yaml \
--output_dir output/vqa \
--checkpoint [Pretrained checkpoint]
- Evaluate the result using the official evaluation server.
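The official VQA v2 evaluation server expects a single json file containing a list of {"question_id": ..., "answer": ...} entries. A minimal sketch of assembling such a submission file (the predictions mapping and output file name are illustrative, not produced verbatim by this repo):

```python
import json

# Hypothetical mapping from question id to the model's predicted answer string.
predictions = {262148000: 'yes', 262148001: '2'}

# The evaluation server expects a list of {"question_id": int, "answer": str} entries.
results = [{'question_id': qid, 'answer': ans} for qid, ans in predictions.items()]

with open('vqa_results.json', 'w') as f:
    json.dump(results, f)
```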
Visual Entailment:
- Download SNLI-VE dataset from the original website.
- Download and extract the provided dataset json files.
- In configs/VE.yaml, set the paths for the json files and the image path.
- Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env VE.py \
--config ./configs/VE.yaml \
--output_dir output/VE \
--checkpoint [Pretrained checkpoint]
Visual Grounding on RefCOCO+:
- Download MSCOCO dataset from the original website.
- Download and extract the provided dataset json files.
- In configs/Grounding.yaml, set the paths for the json files and the image path.
- Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env Grounding.py \
--config ./configs/Grounding.yaml \
--output_dir output/RefCOCO \
--gradcam_mode itm \
--block_num 8 \
--checkpoint [Pretrained checkpoint]
NLVR2:
NLVR2 requires an additional pre-training step with text-assignment (TA) to adapt the model for image-pair inputs. In order to perform TA, first set the paths for the json training files in configs/NLVR_pretrain.yaml, then run:
python -m torch.distributed.launch --nproc_per_node=8 --use_env Pretrain_nlvr.py \
--config ./configs/NLVR_pretrain.yaml \
--output_dir output/NLVR_pretrain \
--checkpoint [Pretrained checkpoint]
We provide the checkpoint after TA pre-training, which can be fine-tuned with the following steps.
- Download NLVR2 dataset from the original website.
- Download and extract the provided dataset json files.
- In configs/NLVR.yaml, set the paths for the json files and the image path.
- Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env NLVR.py \
--config ./configs/NLVR.yaml \
--output_dir output/NLVR \
--checkpoint [TA pretrained checkpoint]
Citation
If you find this code to be useful for your research, please consider citing:
@inproceedings{ALBEF,
  title={Align before Fuse: Vision and Language Representation Learning with Momentum Distillation},
  author={Junnan Li and Ramprasaath R. Selvaraju and Akhilesh Deepak Gotmare and Shafiq Joty and Caiming Xiong and Steven Hoi},
  year={2021},
  booktitle={NeurIPS},
}