GuoQiushan/EGC

EGC: Image Generation and Classification via a Single Energy-Based Model (ICCV 2023)


Download pre-trained models

We have released checkpoints for the main models in the paper.

Here are the download links for each model checkpoint:

More checkpoints and training scripts will be released soon.

Sampling from pre-trained models

To sample from the 256x256 ImageNet EGC model, use run_imagenet_egc_latent_sample_cond.sh. Here, we provide the flags for sampling from these models.

In this example, we generate 50000 samples with a batch size of 8 and 100 DDIM steps. Feel free to change these values.

OPT="--batch_size 8 --num_samples 50000 --use_ddim True --timestep_respacing ddim100"
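The flag --timestep_respacing ddim100 keeps only 100 of the diffusion timesteps for sampling. The following is a minimal sketch of how a "ddimN" string can be resolved, modeled on the "ddim" branch of guided-diffusion's space_timesteps (the codebase this repository builds on). The function name ddim_timesteps is hypothetical, and the 1000 training diffusion steps are an assumption, not a value taken from the EGC scripts.

```python
def ddim_timesteps(num_timesteps: int, section: str) -> list[int]:
    """Return the retained timesteps for a "ddimN" respacing string.

    Searches for the smallest integer stride such that
    range(0, num_timesteps, stride) has exactly N elements.
    """
    if not section.startswith("ddim"):
        raise ValueError("expected a string of the form 'ddimN'")
    desired = int(section[len("ddim"):])
    for stride in range(1, num_timesteps):
        steps = range(0, num_timesteps, stride)
        if len(steps) == desired:
            return list(steps)
    raise ValueError(f"cannot create exactly {desired} steps with an integer stride")


# Assumption: the model was trained with 1000 diffusion steps.
steps = ddim_timesteps(1000, "ddim100")
print(len(steps), steps[:3], steps[-1])  # 100 [0, 10, 20] 990
```

With a stride of 10, the sampler visits timesteps 0, 10, ..., 990. Separately, 50000 samples at a batch size of 8 amount to 50000 / 8 = 6250 batches in total. If the sampling script follows guided-diffusion's image_sample.py loop (an assumption), batches are gathered from every GPU at each iteration, so this work is shared across processes.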

We recommend setting the classifier guidance scale --classifier_scale to 6.0 to reproduce the reported FID score.
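Classifier guidance scales the gradient of log p(y | x_t) before it enters each sampling step. The following is a minimal sketch of the DDIM form used in guided-diffusion's condition_score, where the predicted noise becomes eps - sqrt(1 - alpha_bar_t) * classifier_scale * grad. This is an assumption about the convention EGC inherits, not a statement of how EGC's own sampler applies the scale. The toy linear classifier and the helper names grad_log_prob and guided_eps are hypothetical stand-ins for the EGC network, which acts as both generator and classifier.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def grad_log_prob(W, x, y):
    """Gradient of log p(y | x) w.r.t. x for a toy linear classifier, logits = W @ x.

    d/dx log softmax(W x)[y] = W[y] - sum_k p_k * W[k]
    """
    logits = [sum(w * xi for w, xi in zip(row, x)) for row in W]
    p = softmax(logits)
    return [
        W[y][j] - sum(p[k] * W[k][j] for k in range(len(W)))
        for j in range(len(x))
    ]


def guided_eps(eps, grad, alpha_bar, classifier_scale):
    """Shift the predicted noise against the scaled class gradient (DDIM-style guidance)."""
    coef = math.sqrt(1.0 - alpha_bar) * classifier_scale
    return [e - coef * g for e, g in zip(eps, grad)]


# Hypothetical 2-class, 2-dimensional example.
W = [[1.0, 0.0], [0.0, 1.0]]
x_t = [0.0, 0.0]
g = grad_log_prob(W, x_t, y=0)
print(g)                                        # [0.5, -0.5]
print(guided_eps([0.0, 0.0], g, 0.75, 6.0))     # [-1.5, 1.5]
```

A larger classifier_scale pushes samples more strongly toward the target class. In general, this trades sample diversity for class fidelity, which is why the recommended value matters for reproducing FID.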

Image Classification with pre-trained models

To evaluate the image classification performance of the 256x256 ImageNet EGC model, use run_imagenet_egc_eval_cls.sh.

For example, run the following command:

./run_imagenet_egc_eval_cls.sh $LOGDIR 1 1 0 127.0.0.1 $CKPT_PATH --val_data_dir=./data/imagenet256_features_val
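The usual metric for ImageNet classification is top-1 accuracy, which is the fraction of validation samples whose highest-scoring class matches the label. The following is a minimal sketch of that computation on plain Python lists. The helper name top1_accuracy is hypothetical and does not reflect the evaluation script's internals. Judging by the training script's arguments, the positional values 1 1 0 127.0.0.1 in the command above appear to mean one GPU, one node, rank 0, and a local master address. This is an inference, not documented behavior.

```python
def top1_accuracy(logits: list[list[float]], labels: list[int]) -> float:
    """Fraction of samples whose argmax class equals the ground-truth label."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    if not labels:
        raise ValueError("need at least one sample")
    correct = sum(
        max(range(len(row)), key=row.__getitem__) == label
        for row, label in zip(logits, labels)
    )
    return correct / len(labels)


# Hypothetical logits for 3 samples over 3 classes.
logits = [[2.0, 0.5, -1.0], [0.1, 0.3, 0.2], [0.0, -0.2, 1.5]]
labels = [0, 2, 2]
print(top1_accuracy(logits, labels))  # 0.6666666666666666
```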

Training models

To reproduce the ImageNet result, you can run:

./run_imagenet_egc_train.sh $LOGDIR $GPU_NUM $NODE_NUM $RANK $MASTER_ADDR

Make sure that GPU_NUM * NODE_NUM * batch_size = 512 and that batch_size = cls_batch_size in the shell script. You can adjust microbatch to reduce memory cost.
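The batch-size constraint can be checked before launching. Because the product of GPUs per node, number of nodes, and per-GPU batch_size must equal 512, the per-GPU value follows from the cluster size. The following is a minimal sketch with hypothetical helper names. The microbatch semantics shown here (a value of 0 or less disables splitting, otherwise each batch is processed in chunks with gradient accumulation) are assumed from guided-diffusion and are not confirmed for EGC.

```python
import math

GLOBAL_BATCH = 512  # from the README: GPU_NUM * NODE_NUM * batch_size = 512


def per_gpu_batch_size(gpu_num: int, node_num: int, global_batch: int = GLOBAL_BATCH) -> int:
    """Return the per-GPU batch_size that keeps the global batch at 512."""
    world_size = gpu_num * node_num
    if global_batch % world_size != 0:
        raise ValueError(f"{global_batch} is not divisible by {world_size} processes")
    return global_batch // world_size


def num_microbatches(batch_size: int, microbatch: int) -> int:
    """Forward/backward passes per step when a batch is split into microbatches."""
    if microbatch <= 0 or microbatch >= batch_size:
        return 1
    return math.ceil(batch_size / microbatch)


batch_size = per_gpu_batch_size(gpu_num=8, node_num=2)
cls_batch_size = batch_size  # the README requires these two values to match
print(batch_size, num_microbatches(batch_size, microbatch=8))  # 32 4
```

A smaller microbatch lowers peak activation memory at the cost of more passes per step, while the effective batch size stays at 512.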

Prepare ImageNet data

The ImageNet-1k dataset should be organized as follows:

EGC
├── data
│   ├── imagenet
│   │   ├── train/
│   │   ├── val/
│   ├── imagenet256_features/
│   ├── imagenet256_features_val/
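The layout above can be checked with a few lines of pathlib. Note that the two feature folders are only populated by the latent-extraction step. The following is a minimal sketch with hypothetical names, run from the EGC root. It only tests that the directories exist, not that their contents are complete.

```python
from pathlib import Path

# Directories from the layout above, relative to the EGC root.
EXPECTED_DIRS = [
    "data/imagenet/train",
    "data/imagenet/val",
    "data/imagenet256_features",
    "data/imagenet256_features_val",
]


def missing_dirs(root) -> list[str]:
    """Return the expected directories that do not exist under root."""
    root = Path(root)
    return [d for d in EXPECTED_DIRS if not (root / d).is_dir()]


if __name__ == "__main__":
    for d in missing_dirs("."):
        print(f"missing: {d}")
```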

In addition, download autoencoder_kl.pth into the EGC folder.

Convert the raw training images to latent space using python scripts/extract_feat.py ./data/imagenet/train ./autoencoder_kl.pth ./data/imagenet256_features.

Convert the raw validation images to latent space using python scripts/extract_feat.py ./data/imagenet/val ./autoencoder_kl.pth ./data/imagenet256_features_val.
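Both extraction runs can be scripted together. The following is a minimal sketch that builds one scripts/extract_feat.py invocation per split and runs them with subprocess. The output folders follow the directory layout shown earlier (data/imagenet256_features and data/imagenet256_features_val). The helper names are hypothetical, and the script defaults to a dry run that only prints the commands.

```python
import subprocess
import sys

# Split -> (raw image folder, output feature folder), following the layout above.
SPLITS = {
    "train": ("./data/imagenet/train", "./data/imagenet256_features"),
    "val": ("./data/imagenet/val", "./data/imagenet256_features_val"),
}
AUTOENCODER = "./autoencoder_kl.pth"


def extract_commands(python: str = sys.executable) -> list[list[str]]:
    """Build one scripts/extract_feat.py argument list per split."""
    return [
        [python, "scripts/extract_feat.py", src, AUTOENCODER, dst]
        for src, dst in SPLITS.values()
    ]


def run_all(dry_run: bool = True) -> None:
    """Print each command; execute it only when dry_run is False."""
    for cmd in extract_commands():
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop on the first failing split


if __name__ == "__main__":
    run_all(dry_run=True)  # set dry_run=False inside the EGC folder to run extraction
```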

Cite

If you find EGC useful for your work, please cite:

@article{guo2023egc,
  title={EGC: Image Generation and Classification via a Single Energy-Based Model},
  author={Guo, Qiushan and Ma, Chuofan and Jiang, Yi and Yuan, Zehuan and Yu, Yizhou and Luo, Ping},
  journal={arXiv preprint arXiv:2304.02012},
  year={2023}
}

Acknowledgement

This repository is based on openai/guided-diffusion, with modifications for energy-based training and sampling, as well as architecture improvements. Thanks for their wonderful work.