GitHub - LTH14/mage: A PyTorch implementation of MAGE: MAsked Generative Encoder to Unify Representation Learning and Image Synthesis
This is a PyTorch/GPU re-implementation of the paper MAGE: MAsked Generative Encoder to Unify Representation Learning and Image Synthesis (to appear in CVPR 2023):
@article{li2022mage,
title={MAGE: MAsked Generative Encoder to Unify Representation Learning and Image Synthesis},
author={Li, Tianhong and Chang, Huiwen and Mishra, Shlok Kumar and Zhang, Han and Katabi, Dina and Krishnan, Dilip},
journal={arXiv preprint arXiv:2211.09117},
year={2022}
}
MAGE is a unified framework for both generative modeling and representation learning, achieving SOTA results in both class-unconditional image generation and linear probing on ImageNet-1K.
A large portion of the code in this repo is based on MAE and VQGAN. The original implementation was in JAX/TPU.
Preparation
Dataset
Download the ImageNet dataset and place it in your IMAGENET_DIR.
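A quick sanity check on the dataset directory can save a failed multi-node launch. The sketch below assumes the torchvision `ImageFolder` layout used by MAE-style training scripts (`train/` and `val/` subdirectories, one folder per class). This README does not spell out the layout, so treat the subdirectory names and the `check_imagenet_dir` helper as assumptions.

```python
import os


def check_imagenet_dir(root, splits=("train", "val")):
    """Return {split: number of class folders} for an ImageFolder-style tree.

    Assumed layout (not stated in this README): root/<split>/<class>/<image>.
    A split directory that does not exist maps to None.
    """
    counts = {}
    for split in splits:
        split_dir = os.path.join(root, split)
        if not os.path.isdir(split_dir):
            counts[split] = None
            continue
        counts[split] = sum(
            1
            for name in os.listdir(split_dir)
            if os.path.isdir(os.path.join(split_dir, name))
        )
    return counts


if __name__ == "__main__":
    root = os.environ.get("IMAGENET_DIR", ".")
    for split, n in check_imagenet_dir(root).items():
        # ImageNet-1K has 1000 classes, so each split should report 1000.
        status = "missing" if n is None else str(n) + " class folders"
        print(split + ": " + status)
```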
Installation
A suitable conda environment named mage can be created and activated with:
conda env create -f environment.yaml
conda activate mage
Download the code
git clone https://github.com/LTH14/mage.git
cd mage
Use this link to download the pre-trained VQGAN tokenizer and put it in the mage directory.
Usage
Pre-training
To pre-train a MAGE ViT-B model with a batch size of 4096 using 8 servers with 8 V100 GPUs per server:
python -m torch.distributed.launch --node_rank=0 --nproc_per_node=8 --nnodes=8 \
--master_addr="${MASTER_SERVER_ADDRESS}" --master_port=12344 \
main_pretrain.py \
--batch_size 64 \
--model mage_vit_base_patch16 \
--mask_ratio_min 0.5 --mask_ratio_max 1.0 \
--mask_ratio_mu 0.55 --mask_ratio_std 0.25 \
--epochs 1600 \
--warmup_epochs 40 \
--blr 1.5e-4 --weight_decay 0.05 \
--output_dir ${OUTPUT_DIR} \
--data_path ${IMAGENET_DIR} \
--dist_url tcp://${MASTER_SERVER_ADDRESS}:2214
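The four `--mask_ratio_*` flags describe a variable masking ratio: per the MAGE paper, the ratio is drawn from a Gaussian with mean 0.55 and standard deviation 0.25, truncated to [0.5, 1.0]. The sketch below illustrates that distribution with rejection sampling from the standard library. `sample_mask_ratio` is a hypothetical helper, and the repository's own sampler may be implemented differently.

```python
import random


def sample_mask_ratio(mu=0.55, std=0.25, lo=0.5, hi=1.0, rng=random):
    """Draw one masking ratio from N(mu, std^2) truncated to [lo, hi]."""
    while True:
        r = rng.gauss(mu, std)
        if lo <= r <= hi:
            return r


if __name__ == "__main__":
    rng = random.Random(0)
    ratios = [sample_mask_ratio(rng=rng) for _ in range(10_000)]
    mean = sum(ratios) / len(ratios)
    print(f"min={min(ratios):.3f} max={max(ratios):.3f} mean={mean:.3f}")
```

Because the left tail is cut at 0.5, the sample mean lands noticeably above the nominal 0.55, so most training steps mask well over half of the tokens.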
The following table provides the performance and weights of the pre-trained checkpoints used in the paper, converted from JAX/TPU to PT/GPU:
| | ViT-Base | ViT-Large |
|---|---|---|
| Checkpoint | Google Drive | Google Drive |
| Class-unconditional Generation FID | 11.1 | 9.10 |
| Class-unconditional Generation IS | 81.2 | 105.1 |
| Linear Probing Top-1 Accuracy | 74.7% | 78.9% |
| Fine-tuning Top-1 Accuracy | 82.5% Checkpoint | 83.9% Checkpoint |
Linear Probing
To perform linear probing on a pre-trained MAGE model using 4 servers with 8 V100 GPUs per server:
python -m torch.distributed.launch --node_rank=0 --nproc_per_node=8 --nnodes=4 \
--master_addr="${MASTER_SERVER_ADDRESS}" --master_port=12344 \
main_linprobe.py \
--batch_size 128 \
--model vit_base_patch16 \
--global_pool \
--finetune ${PRETRAIN_CHKPT} \
--epochs 90 \
--blr 0.1 \
--weight_decay 0.0 \
--output_dir ${OUTPUT_DIR} \
--data_path ${IMAGENET_DIR} \
--dist_eval --dist_url tcp://${MASTER_SERVER_ADDRESS}:6311
For ViT-L, set --blr 0.05.
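The multi-server commands in this README show `--node_rank=0`, which is only correct on the first server. With `torch.distributed.launch`, the same command has to be started on every server with that server's own rank (0 to `nnodes - 1`) and an identical `--master_addr` and `--master_port`. A minimal sketch that prints the launcher prefix for each node follows; `launch_prefix` is a hypothetical helper and the address is a placeholder.

```python
def launch_prefix(node_rank, nnodes, master_addr,
                  nproc_per_node=8, master_port=12344):
    """Build the torch.distributed.launch prefix for one server."""
    if not 0 <= node_rank < nnodes:
        raise ValueError(
            f"node_rank must be in [0, {nnodes - 1}], got {node_rank}"
        )
    return (
        f"python -m torch.distributed.launch --node_rank={node_rank} "
        f"--nproc_per_node={nproc_per_node} --nnodes={nnodes} "
        f'--master_addr="{master_addr}" --master_port={master_port}'
    )


if __name__ == "__main__":
    # Placeholder address; use the reachable address of the rank-0 server.
    for rank in range(4):
        print(launch_prefix(rank, nnodes=4, master_addr="10.0.0.1"))
```

The script name and its arguments (for example `main_linprobe.py --batch_size 128 ...`) are appended unchanged on every server.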
Fine-tuning
To perform fine-tuning with a pre-trained ViT-B model using 4 servers with 8 V100 GPUs per server:
python -m torch.distributed.launch --node_rank=0 --nproc_per_node=8 --nnodes=4 \
--master_addr="${MASTER_SERVER_ADDRESS}" --master_port=12344 \
main_finetune.py \
--batch_size 32 \
--model vit_base_patch16 \
--global_pool \
--finetune ${PRETRAIN_CHKPT} \
--epochs 100 \
--blr 2.5e-4 --layer_decay 0.65 --interpolation bicubic \
--weight_decay 0.05 --drop_path 0.1 --reprob 0 --mixup 0.8 --cutmix 1.0 \
--output_dir ${OUTPUT_DIR} \
--data_path ${IMAGENET_DIR} \
--dist_eval --dist_url tcp://${MASTER_SERVER_ADDRESS}:6311
For ViT-L, set --epochs 50 --layer_decay 0.75 --drop_path 0.2.
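`--batch_size` in these commands is per GPU, so the effective batch size is `batch_size × GPUs per server × servers`. For pre-training that is 64 × 8 × 8 = 4096, matching the figure quoted in the pre-training section. The sketch below works this out for all three commands. It also converts `--blr` to an absolute learning rate under the MAE convention `lr = blr × effective_batch / 256`; this README does not state that rule, so treat it as an assumption carried over from the MAE codebase this repo builds on. Gradient accumulation is assumed to be off.

```python
def effective_batch_size(per_gpu_batch, gpus_per_node, nodes, accum_iter=1):
    """Total samples contributing to one optimizer step."""
    return per_gpu_batch * gpus_per_node * nodes * accum_iter


def absolute_lr(blr, eff_batch, base=256):
    # Assumed MAE-style linear scaling rule; not stated in this README.
    return blr * eff_batch / base


# Per-GPU batch size, server count, and base learning rate from the commands.
configs = {
    "pretrain": dict(per_gpu_batch=64, nodes=8, blr=1.5e-4),
    "linprobe": dict(per_gpu_batch=128, nodes=4, blr=0.1),
    "finetune": dict(per_gpu_batch=32, nodes=4, blr=2.5e-4),
}

for name, cfg in configs.items():
    eff = effective_batch_size(cfg["per_gpu_batch"], 8, cfg["nodes"])
    lr = absolute_lr(cfg["blr"], eff)
    print(f"{name}: effective batch {eff}, absolute lr {lr:g}")
```

When running on fewer GPUs, keeping the effective batch size constant (or rescaling the learning rate under the same rule) is the usual way to stay close to the reported settings.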
Class Unconditional Generation
To perform class-unconditional generation with a pre-trained MAGE model using a single V100 GPU:
python gen_img_uncond.py --temp 6.0 --num_iter 20 \
--ckpt ${PRETRAIN_CHKPT} --batch_size 32 --num_images 50000 \
--model mage_vit_base_patch16 --output_dir ${OUTPUT_DIR}
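`--num_iter 20` sets the number of iterative decoding steps and `--temp 6.0` the sampling temperature. As a rough illustration of what such a schedule looks like, the sketch below assumes a MaskGIT-style cosine schedule over 256 tokens (a 16×16 token grid for a 256x256 image). Both the schedule and the token count are assumptions here, not details given in this README, and `masked_tokens_per_step` is a hypothetical helper.

```python
import math


def masked_tokens_per_step(num_iter=20, num_tokens=256):
    """Number of tokens still masked after each decoding step (cosine schedule)."""
    remaining = []
    for step in range(num_iter):
        progress = (step + 1) / num_iter  # runs from 1/num_iter up to 1.0
        mask_ratio = math.cos(math.pi / 2 * progress)
        remaining.append(max(0, math.floor(num_tokens * mask_ratio)))
    return remaining


if __name__ == "__main__":
    print(masked_tokens_per_step())
```

Under this schedule only a few tokens are committed in the early steps and most are filled in during the late steps, once the model has more context to condition on.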
To quantitatively evaluate FID/IS, please first generate 256x256 ImageNet validation images using
python prepare_imgnet_val.py --data_path ${IMAGENET_DIR} --output_dir ${OUTPUT_DIR}
Then install the torch-fidelity package:
pip install torch-fidelity
Then use the above package to evaluate FID/IS of the images generated by our models against 256x256 ImageNet validation images by
fidelity --gpu 0 --isc --fid --input1 ${GENERATED_IMAGES_DIR} --input2 ${IMAGENET256X256_DIR}
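The same evaluation can be run from Python through `torch_fidelity.calculate_metrics`, which is convenient inside a larger evaluation script. A minimal sketch, assuming `torch-fidelity` is installed and both directories already contain images; `fidelity_kwargs` and `evaluate_generation` are hypothetical wrappers, and the directory paths are whatever you passed to the generation and preparation scripts.

```python
def fidelity_kwargs(generated_dir, reference_dir, use_gpu=True):
    """Arguments mirroring `fidelity --gpu 0 --isc --fid --input1 ... --input2 ...`."""
    return dict(
        input1=generated_dir,
        input2=reference_dir,
        cuda=use_gpu,
        isc=True,
        fid=True,
        verbose=False,
    )


def evaluate_generation(generated_dir, reference_dir, use_gpu=True):
    """Return FID and Inception Score for generated vs. reference images."""
    # Imported lazily so this module loads even without torch-fidelity.
    import torch_fidelity

    metrics = torch_fidelity.calculate_metrics(
        **fidelity_kwargs(generated_dir, reference_dir, use_gpu)
    )
    return {
        "fid": metrics["frechet_inception_distance"],
        "is_mean": metrics["inception_score_mean"],
        "is_std": metrics["inception_score_std"],
    }
```

A call such as `evaluate_generation(generated_dir, imagenet256_dir)` returns a small dictionary that can be logged next to the checkpoint name.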
Here are some examples of our class-unconditional generation:
MAGE-C
Here we provide the pre-trained MAGE-C checkpoints converted from JAX/TPU to PT/GPU: ViT-B, ViT-L. A PyTorch training script is coming soon.
Contact
If you have any questions, feel free to contact me through email (tianhong@mit.edu). Enjoy!


