Vchitect/Latte: [TMLR 2025] Latte: Latent Diffusion Transformer for Video Generation


Official PyTorch Implementation



This repo contains the PyTorch model definitions, pre-trained weights, and training/sampling/evaluation code for our paper Latte: Latent Diffusion Transformer for Video Generation.

Latte: Latent Diffusion Transformer for Video Generation
Xin Ma, Yaohui Wang*, Xinyuan Chen, Gengyun Jia, Ziwei Liu, Yuan-Fang Li, Cunjian Chen, Yu Qiao (*Corresponding Author & Project Lead)


News

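Latte-1 can be run through LattePipeline in diffusers:

# Requires diffusers >= 0.30.0
from diffusers import LattePipeline
from diffusers.models import AutoencoderKLTemporalDecoder
from torchvision.utils import save_image
import torch
import imageio

torch.manual_seed(0)

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 is intended for GPU inference; use float32 on CPU
dtype = torch.float16 if device == "cuda" else torch.float32
video_length = 16  # 1 (text-to-image) or 16 (text-to-video)
pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=dtype).to(device)

# Use the temporal decoder of the VAE
vae = AutoencoderKLTemporalDecoder.from_pretrained("maxin-cn/Latte-1", subfolder="vae_temporal_decoder", torch_dtype=dtype).to(device)
pipe.vae = vae

prompt = "a cat wearing sunglasses and working as a lifeguard at pool."
# frames: (batch, frames, channels, height, width), values in [0, 1]
videos = pipe(prompt, video_length=video_length, output_type='pt').frames.cpu()

if video_length > 1:
    # Convert to uint8 and to (frames, height, width, channels) for imageio
    videos = (videos.clamp(0, 1) * 255).to(dtype=torch.uint8)
    imageio.mimwrite("./latte_output.mp4", videos[0].permute(0, 2, 3, 1).numpy(), fps=8)
else:
    save_image(videos[0], "./latte_output.png")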
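Latte runs diffusion in a VAE latent space rather than on pixels, and the Transformer operates on patches of those latents. The helpers below are a back-of-the-envelope sketch of the resulting tensor sizes. The 8x spatial downsampling, 4 latent channels, 2x2 patch size, and the 512x512 example resolution are assumptions typical of Stable-Diffusion-style VAEs, not values stated in this README, and the function names are hypothetical.

```python
# Hypothetical helpers for estimating latent/token sizes in a latent video
# diffusion transformer. ASSUMPTIONS (not taken from the Latte README):
# 8x spatial VAE downsampling, 4 latent channels, 2x2 spatial patches.


def latent_shape(num_frames: int, height: int, width: int,
                 downsample: int = 8, latent_channels: int = 4) -> tuple[int, int, int, int]:
    """Return (frames, channels, latent_height, latent_width) for a pixel-space clip."""
    if height % downsample or width % downsample:
        raise ValueError("height and width must be divisible by the VAE downsampling factor")
    return (num_frames, latent_channels, height // downsample, width // downsample)


def tokens_per_frame(latent_h: int, latent_w: int, patch: int = 2) -> int:
    """Number of spatial tokens per frame after non-overlapping patchification."""
    if latent_h % patch or latent_w % patch:
        raise ValueError("latent size must be divisible by the patch size")
    return (latent_h // patch) * (latent_w // patch)


shape = latent_shape(16, 512, 512)
print(shape)  # (16, 4, 64, 64)
per_frame = tokens_per_frame(shape[2], shape[3])
print(per_frame)  # 1024
print(shape[0] * per_frame)  # 16384 tokens for the whole clip
```

The point of the arithmetic is that the token count grows with both resolution and frame count, which is why working in a compressed latent space matters for video.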

Setup

First, download and set up the repo:

git clone https://github.com/Vchitect/Latte
cd Latte

We provide an environment.yml file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the cudatoolkit and pytorch-cuda requirements from the file.

conda env create -f environment.yml
conda activate latte
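For the CPU-only case above, the CUDA entries can be dropped by hand, or with a small script. The sketch below is a hypothetical helper (not part of the Latte repo) that filters dependency lines naming cudatoolkit or pytorch-cuda, with or without a channel prefix or version pin. The sample file contents are illustrative only and are not the repo's actual environment.yml.

```python
# Hypothetical helper (not part of the Latte repo): remove CUDA-only packages
# from a conda environment.yml so it can be used for CPU-only inference.
import re

CUDA_ONLY = ("cudatoolkit", "pytorch-cuda")


def strip_cuda_deps(env_text: str, packages=CUDA_ONLY) -> str:
    """Drop dependency lines such as '  - cudatoolkit=11.7' or '- nvidia::pytorch-cuda'."""
    names = "|".join(map(re.escape, packages))
    pattern = re.compile(
        r"^\s*-\s*(?:[\w.-]+::)?"           # list marker and optional 'channel::' prefix
        r"(?:" + names + r")"               # one of the CUDA-only package names
        r"(?=\s*(?:[=<>!~]|$))"             # followed by a version spec or end of line
    )
    kept = [line for line in env_text.splitlines() if not pattern.match(line)]
    return "\n".join(kept) + "\n"


# Illustrative sample only; versions are made up.
sample = """name: latte
channels:
  - pytorch
  - nvidia
dependencies:
  - python=3.10
  - pytorch=2.0
  - pytorch-cuda=11.8
  - cudatoolkit=11.7
"""
print(strip_cuda_deps(sample))
```

You could write the result to a file such as environment_cpu.yml (a name chosen here for illustration) and pass that to conda env create -f instead.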

Sampling

You can sample from our pre-trained Latte models with sample.py. Weights for our pre-trained Latte models can be found here. The script has arguments for adjusting the number of sampling steps, the classifier-free guidance scale, and more. For example, to sample from our model on FaceForensics, you can use:

Or, to sample hundreds of videos, you can use the following script with PyTorch DDP:

To generate videos from text, run bash sample/t2v.sh. All required checkpoints are downloaded automatically.
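The classifier-free guidance scale that the sampling scripts expose usually enters each denoising step through the standard combination eps = eps_uncond + s * (eps_cond - eps_uncond). The sketch below shows that generic formulation with NumPy, including the common trick of running the conditional and unconditional passes as one batch. It is not code from sample.py; whether Latte applies guidance in exactly this form is an assumption, and toy_model is a hypothetical stand-in for the denoiser.

```python
import numpy as np


def classifier_free_guidance(eps_uncond: np.ndarray, eps_cond: np.ndarray, scale: float) -> np.ndarray:
    """Standard CFG: scale = 1 gives the conditional prediction, scale > 1 extrapolates past it."""
    return eps_uncond + scale * (eps_cond - eps_uncond)


def guided_noise_prediction(model, x, cond, null_cond, scale):
    """Run conditional and unconditional passes in a single batch, then combine them."""
    x_in = np.concatenate([x, x], axis=0)
    c_in = np.concatenate([cond, null_cond], axis=0)
    eps = model(x_in, c_in)
    eps_cond, eps_uncond = np.split(eps, 2, axis=0)
    return classifier_free_guidance(eps_uncond, eps_cond, scale)


def toy_model(x, c):
    # Hypothetical denoiser stand-in: scales the input by the conditioning value.
    return x * c


x = np.ones((2, 3))
cond = np.full((2, 1), 2.0)   # conditional signal (e.g. a class label or prompt embedding)
null_cond = np.zeros((2, 1))  # null signal used for the unconditional pass
print(guided_noise_prediction(toy_model, x, cond, null_cond, scale=4.0))  # every entry is 8.0
```

Batching both passes trades a doubled batch size for half as many model calls per step; larger scales typically follow the condition more closely at some cost in diversity.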

To measure quantitative metrics on your generated results, please refer to the instructions here.
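Video generation is commonly evaluated with Fréchet-distance metrics such as FVD, which compare the mean and covariance of features that a pretrained network extracts from real and generated videos. The sketch below computes the Fréchet distance between two Gaussian fits using NumPy only. The feature extractor is omitted, and this is a generic illustration, not the repo's evaluation code.

```python
import numpy as np


def _sqrtm_psd(mat: np.ndarray) -> np.ndarray:
    """Matrix square root of a symmetric positive semi-definite matrix."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)
    return (vecs * np.sqrt(vals)) @ vecs.T


def frechet_distance(mu1, sigma1, mu2, sigma2) -> float:
    """||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2}) for two Gaussians."""
    diff = mu1 - mu2
    s1_half = _sqrtm_psd(sigma1)
    # Tr((S1 S2)^{1/2}) equals the sum of sqrt-eigenvalues of S1^{1/2} S2 S1^{1/2}
    middle = s1_half @ sigma2 @ s1_half
    middle = (middle + middle.T) / 2.0  # symmetrize against round-off
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(middle), 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


def frechet_from_features(feats_a: np.ndarray, feats_b: np.ndarray) -> float:
    """Fit a Gaussian to each (num_samples, feature_dim) array and compare them."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    sig_a = np.cov(feats_a, rowvar=False)
    sig_b = np.cov(feats_b, rowvar=False)
    return frechet_distance(mu_a, sig_a, mu_b, sig_b)


rng = np.random.default_rng(0)
real = rng.normal(size=(500, 8))           # stand-in for real-video features
fake = rng.normal(loc=0.5, size=(500, 8))  # stand-in for generated-video features
print(frechet_from_features(real, real))   # close to 0
print(frechet_from_features(real, fake))
```

Lower is better; the value is only comparable across runs that use the same feature extractor and sample counts.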

Training

We provide a training script for Latte in train.py; the expected dataset structure is described here. The script can train both class-conditional and unconditional Latte models. To launch Latte (256x256) training with N GPUs on the FaceForensics dataset:

torchrun --nnodes=1 --nproc_per_node=N train.py --config ./configs/ffs/ffs_train.yaml

Or, if your cluster uses Slurm, you can train Latte with the following script:

sbatch slurm_scripts/ffs.slurm
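Both launch paths start one process per GPU, and each process needs to know its rank. torchrun exports RANK, LOCAL_RANK and WORLD_SIZE, while Slurm exports SLURM_PROCID, SLURM_LOCALID and SLURM_NTASKS. The sketch below is a generic pattern for resolving these (plus the resulting global batch size); DistInfo, resolve_dist_info and global_batch_size are hypothetical names, and how train.py and ffs.slurm actually handle this may differ.

```python
import os
from typing import Mapping, NamedTuple


class DistInfo(NamedTuple):
    rank: int        # global process index
    local_rank: int  # GPU index on this node
    world_size: int  # total number of processes


def resolve_dist_info(env: Mapping[str, str] = os.environ) -> DistInfo:
    """Prefer torchrun's variables, then Slurm's, else assume a single process."""
    if "RANK" in env and "WORLD_SIZE" in env:
        return DistInfo(int(env["RANK"]), int(env.get("LOCAL_RANK", 0)), int(env["WORLD_SIZE"]))
    if "SLURM_PROCID" in env and "SLURM_NTASKS" in env:
        return DistInfo(int(env["SLURM_PROCID"]), int(env.get("SLURM_LOCALID", 0)),
                        int(env["SLURM_NTASKS"]))
    return DistInfo(0, 0, 1)


def global_batch_size(per_gpu_batch: int, info: DistInfo) -> int:
    """With data-parallel training, each process contributes its own local batch."""
    return per_gpu_batch * info.world_size


info = resolve_dist_info({"RANK": "3", "LOCAL_RANK": "1", "WORLD_SIZE": "8"})
print(info)  # DistInfo(rank=3, local_rank=1, world_size=8)
print(global_batch_size(5, info))  # 40
```

In a real script, local_rank would typically select the CUDA device before the process group is initialized.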

We also provide a video-image joint training script, train_with_img.py. Like train.py, it can train both class-conditional and unconditional Latte models. For example, to train Latte on the FaceForensics dataset, you can use:

torchrun --nnodes=1 --nproc_per_node=N train_with_img.py --config ./configs/ffs/ffs_img_train.yaml
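One way to build a video-image joint training sample is to append independently sampled images to a video clip along the frame axis and keep a mask of which positions are real video frames, so that temporal modeling can be restricted to those. The sketch below illustrates that idea with NumPy; the append-along-frames layout and the mask are assumptions about the approach, and make_joint_clip is a hypothetical name, not a function from train_with_img.py.

```python
import numpy as np


def make_joint_clip(video: np.ndarray, image_pool: np.ndarray, num_images: int,
                    rng: np.random.Generator):
    """Append `num_images` random images to a video along the frame axis.

    video:      (F, C, H, W) consecutive frames
    image_pool: (N, C, H, W) images to sample from (e.g. frames of other videos)
    Returns the (F + num_images, C, H, W) clip and a boolean mask that is True
    at positions holding real video frames.
    """
    if video.shape[1:] != image_pool.shape[1:]:
        raise ValueError("video frames and images must share (C, H, W)")
    idx = rng.choice(len(image_pool), size=num_images, replace=False)
    clip = np.concatenate([video, image_pool[idx]], axis=0)
    is_video_frame = np.zeros(len(clip), dtype=bool)
    is_video_frame[: len(video)] = True
    return clip, is_video_frame


rng = np.random.default_rng(0)
video = np.zeros((16, 3, 8, 8))
pool = np.ones((10, 3, 8, 8))
clip, mask = make_joint_clip(video, pool, num_images=4, rng=rng)
print(clip.shape)  # (20, 3, 8, 8)
print(mask.sum())  # 16
```

Under this layout, spatial layers can process all 20 positions alike, while temporal layers would only attend over the positions where the mask is True.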

If you are familiar with PyTorch Lightning, you can also use the training scripts train_pl.py and train_with_img_pl.py, contributed by @zhang.haojie:

python train_pl.py --config ./configs/ffs/ffs_train.yaml

or

python train_with_img_pl.py --config ./configs/ffs/ffs_img_train.yaml

These scripts automatically detect the available GPUs and use distributed training.

Contact Us

Yaohui Wang: wangyaohui@pjlab.org.cn
Xin Ma: xin.ma1@monash.edu

Citation

If you find this work useful for your research, please consider citing it.

@article{ma2025latte,
  title={Latte: Latent Diffusion Transformer for Video Generation},
  author={Ma, Xin and Wang, Yaohui and Chen, Xinyuan and Jia, Gengyun and Liu, Ziwei and Li, Yuan-Fang and Chen, Cunjian and Qiao, Yu},
  journal={Transactions on Machine Learning Research},
  year={2025}
}

Acknowledgments

Latte has been greatly inspired by the following amazing works and teams: DiT and PixArt-α. We thank all the contributors for open-sourcing their work.

License

The code and model weights are licensed under LICENSE.