Releases · huggingface/diffusers
v0.33.1: fix ftfy import
Diffusers 0.33.0: New Image and Video Models, Memory Optimizations, Caching Methods, Remote VAEs, New Training Scripts, and more
New Pipelines for Video Generation
Wan 2.1
Wan2.1 is a comprehensive and open suite of video foundation models that pushes the boundaries of video generation. The release includes four model variants and three pipelines: text-to-video, image-to-video, and video-to-video.
Wan-AI/Wan2.1-T2V-1.3B-Diffusers
Wan-AI/Wan2.1-T2V-14B-Diffusers
Wan-AI/Wan2.1-I2V-14B-480P-Diffusers
Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
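For a quick start, below is a minimal text-to-video sketch using the smallest checkpoint listed above; the frame count, guidance scale, and fps values are illustrative assumptions rather than recommended settings.

```python
import torch
from diffusers import WanPipeline
from diffusers.utils import export_to_video

# Smallest T2V variant from the list above; swap in the 14B checkpoints for higher quality.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat walking along a sunny beach, gentle waves in the background"
# num_frames and guidance_scale here are illustrative assumptions.
video = pipe(prompt=prompt, num_frames=81, guidance_scale=5.0).frames[0]
export_to_video(video, "wan_t2v.mp4", fps=16)
```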
Check out the docs here to learn more.
LTX Video 0.9.5
LTX Video 0.9.5 is the updated version of the super-fast LTX Video model series. The latest model introduces additional conditioning options, such as keyframe-based animation and video extension (both forward and backward).
To support these additional conditioning inputs, we've introduced the `LTXConditionPipeline` and the `LTXVideoCondition` object.
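As a rough sketch of how the new conditioning objects fit together (the checkpoint id, frame index, and generation arguments below are assumptions; see the docs for the exact recipe):

```python
import torch
from diffusers import LTXConditionPipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import export_to_video, load_image

pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Condition the generation on a single keyframe placed at frame 0.
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
condition = LTXVideoCondition(image=image, frame_index=0)

video = pipe(
    conditions=[condition],
    prompt="An astronaut slowly turning toward the camera, cinematic lighting",
    num_frames=97,
    num_inference_steps=30,
).frames[0]
export_to_video(video, "ltx_condition.mp4", fps=24)
```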
To learn more about the usage, check out the docs here.
Hunyuan Image to Video
Hunyuan utilizes a pre-trained Multimodal Large Language Model (MLLM) with a Decoder-Only architecture as the text encoder. The input image is processed by the MLLM to generate semantic image tokens. These tokens are then concatenated with the video latent tokens, enabling comprehensive full-attention computation across the combined data and seamlessly integrating information from both the image and its associated caption.
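A minimal image-to-video sketch is shown below; the repository id and the frame count are assumptions, so treat this as an outline rather than a verified recipe.

```python
import torch
from diffusers import HunyuanVideoImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Repository id is an assumption; check the docs for the exact checkpoint.
pipe = HunyuanVideoImageToVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-I2V", torch_dtype=torch.float16
)
pipe.vae.enable_tiling()          # reduce VAE decode memory
pipe.enable_model_cpu_offload()   # keep VRAM usage manageable

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
video = pipe(image=image, prompt="An astronaut waving at the camera", num_frames=61).frames[0]
export_to_video(video, "hunyuan_i2v.mp4", fps=15)
```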
To learn more, check out the docs here.
Others
- EasyAnimateV5 (thanks to @bubbliiiing for contributing this in this PR)
- ConsisID (thanks to @SHYuanBest for contributing this in this PR)
New Pipelines for Image Generation
Sana-Sprint
SANA-Sprint is an efficient diffusion model for ultra-fast text-to-image generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4, rivaling the quality of models like Flux.
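A minimal sketch of few-step inference follows; the repository id and step count are assumptions.

```python
import torch
from diffusers import SanaSprintPipeline

# Repository id is an assumption; see the pipeline docs for the released checkpoints.
pipe = SanaSprintPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# SANA-Sprint is distilled, so a handful of steps (1-4) is enough.
image = pipe(prompt="a tiny astronaut hatching from an egg on the moon", num_inference_steps=2).images[0]
image.save("sana_sprint.png")
```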
Shoutout to @lawrence-cj for their help and guidance on this PR.
Check out the pipeline docs of SANA-Sprint to learn more.
Lumina2
Lumina-Image-2.0 is a 2B parameter flow-based diffusion transformer for text-to-image generation released under the Apache 2.0 license.
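A minimal text-to-image sketch, assuming the Alpha-VLLM checkpoint id; `DiffusionPipeline` is used here so the concrete pipeline class is resolved from the checkpoint.

```python
import torch
from diffusers import DiffusionPipeline

# Repository id is an assumption; the pipeline class is resolved from the model index.
pipe = DiffusionPipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(prompt="a photo of a red panda reading a book under a tree", num_inference_steps=30).images[0]
image.save("lumina2.png")
```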
Check out the docs to learn more. Thanks to @zhuole1025 for contributing this through this PR.
One can also LoRA fine-tune Lumina2, taking advantage of its Apache 2.0 licensing. Check out the guide for more details.
Omnigen
OmniGen is a unified image generation model that can handle multiple tasks including text-to-image, image editing, subject-driven generation, and various computer vision tasks within a single framework. The model consists of a VAE, and a single transformer based on Phi-3 that handles text and image encoding as well as the diffusion process.
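A minimal text-to-image sketch follows; the checkpoint id and guidance value are assumptions, and the image-editing and subject-driven modes are covered in the docs.

```python
import torch
from diffusers import OmniGenPipeline

# Repository id is an assumption; see the OmniGen docs for the official checkpoint.
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="A photo of a corgi wearing sunglasses, riding a skateboard at sunset",
    height=1024,
    width=1024,
    guidance_scale=2.5,
).images[0]
image.save("omnigen.png")
```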
Check out the docs to learn more about OmniGen. Thanks to @staoxiao for contributing OmniGen in this PR.
Others
- CogView4 (thanks to @zRzRzRzRzRzRzR for contributing CogView4 in this PR)
New Memory Optimizations
Layerwise Casting
PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation on many devices due to unimplemented kernel support. However, you can still use these dtypes to store model weights in FP8 precision and upcast them on the fly to a widely supported dtype such as `torch.float16` or `torch.bfloat16` when the layers are used in the forward pass. This is known as layerwise weight-casting, and it can potentially cut down the VRAM requirements of a model by 50%.
Code
```python
import torch
from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers.utils import export_to_video

model_id = "THUDM/CogVideoX-5b"

# Load the model in bfloat16 and enable layerwise casting
transformer = CogVideoXTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)

# Load the pipeline
pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```
Group Offloading
Group offloading is the middle ground between sequential and model offloading. It works by offloading groups of internal layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`), which uses less memory than model-level offloading. It is also faster than sequential-level offloading because the number of device synchronizations is reduced.
On CUDA devices, we also provide the option to enable layer prefetching with CUDA streams: the next layer to be executed is loaded onto the accelerator device while the current layer is being executed, which makes inference substantially faster while keeping VRAM requirements very low. In other words, computation is overlapped with data transfer.
One thing to note is that using CUDA streams can cause a considerable spike in CPU RAM usage. Please ensure that the available CPU RAM is twice the size of the model if you set `use_stream=True`. You can reduce CPU RAM usage by setting `low_cpu_mem_usage=True`, which should limit the CPU RAM used to roughly the size of the model but will introduce slight latency in the inference process. You can also use `record_stream=True` together with `use_stream=True` to obtain more speedups at the expense of slightly increased memory usage.
Code
```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Load the pipeline
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# We can utilize the enable_group_offload method for Diffusers model implementations
pipe.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
)

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]

# This utilized about 14.79 GB. It can be further reduced by using tiling and using leaf_level offloading throughout the pipeline.
print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
export_to_video(video, "output.mp4", fps=8)
```
Group offloading can also be applied to non-Diffusers models, such as text encoders from the `transformers` library.
Code
```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video

# Load the pipeline
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# For any other model implementations, the apply_group_offloading function can be used
apply_group_offloading(
    pipe.text_encoder,
    onload_device=onload_device,
    offload_type="block_level",
    num_blocks_per_group=2,
)
```
Remote Components
Remote components are an experimental feature designed to offload memory-intensive steps of t...
v0.32.2
Fixes for Flux Single File loading, LoRA loading for 4bit BnB Flux, Hunyuan Video
This patch release:
- Fixes a regression in loading Comfy UI format single file checkpoints for Flux
- Fixes a regression in loading LoRAs with bitsandbytes 4-bit quantized Flux models
- Adds `unload_lora_weights` for Flux Control
- Fixes a bug that prevents Hunyuan Video from running with batch size > 1
- Allows Hunyuan Video to load LoRAs created from the original repository code
All commits
- [Single File] Fix loading Flux Dev finetunes with Comfy Prefix by @DN6 in #10545
- [CI] Update HF Token on Fast GPU Model Tests by @DN6 #10570
- [CI] Update HF Token in Fast GPU Tests by @DN6 #10568
- Fix batch > 1 in HunyuanVideo by @hlky in #10548
- Fix HunyuanVideo produces NaN on PyTorch<2.5 by @hlky in #10482
- Fix hunyuan video attention mask dim by @a-r-r-o-w in #10454
- [LoRA] Support original format loras for HunyuanVideo by @a-r-r-o-w in #10376
- [LoRA] feat: support loading loras into 4bit quantized Flux models. by @sayakpaul in #10578
- [LoRA] clean up `load_lora_into_text_encoder()` and `fuse_lora()` copied from by @sayakpaul in #10495
- [LoRA] feat: support `unload_lora_weights()` for Flux Control. by @sayakpaul in #10206
- Fix Flux multiple Lora loading bug by @maxs-kan in #10388
- [LoRA] fix: lora unloading when using expanded Flux LoRAs. by @sayakpaul in #10397
v0.32.1
TorchAO Quantizer fixes
This patch release fixes a few bugs related to the TorchAO Quantizer introduced in v0.32.0.
- Importing Diffusers would raise an error in PyTorch versions lower than 2.3.0. This should no longer be a problem.
- Device Map does not work as expected when using the quantizer. We now raise an error if it is used. Support for using device maps with different quantization backends will be added in the near future.
- Quantization was not performed due to faulty logic. This is now fixed and better tested.
Refer to our documentation to learn more about how to use different quantization backends.
All commits
- make style for #10368 by @yiyixuxu in #10370
- fix test pypi installation in the release workflow by @sayakpaul in #10360
- Fix TorchAO related bugs; revert device_map changes by @a-r-r-o-w in #10371
Diffusers 0.32.0: New video pipelines, new image pipelines, new quantization backends, new training scripts, and more
hunyuan-output.mp4
This release took a while, but it has many exciting updates. It contains several new pipelines for image and video generation, new quantization backends, and more.
Going forward, to provide more transparency to the community about ongoing developments and releases in Diffusers, we will be making use of a roadmap tracker.
New Video Generation Pipelines 📹
Open video generation models are on the rise, and we’re pleased to provide comprehensive integration support for all of them. The following video pipelines are bundled in this release:
Check out this section to learn more about the fine-tuning options available for these new video models.
New Image Generation Pipelines
- SANA
- Flux Control (including Control LoRA)
- Flux Redux
- Flux Fill Inpainting / Outpainting
- Flux RF-Inversion
- SD3.5 ControlNet
- ControlNet Union XL
- SD3.5 IP Adapter
- Flux IP adapter
Important Note about the new Flux Models
We can combine the regular Flux.1 Dev LoRAs with Flux Control LoRAs, Flux Control, and Flux Fill. For example, you can enable few-steps inference with Flux Fill using:
```python
import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")

adapter_id = "alimama-creative/FLUX.1-Turbo-Alpha"
pipe.load_lora_weights(adapter_id)

image = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup.png")
mask = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup_mask.png")

image = pipe(
    prompt="a white paper cup",
    image=image,
    mask_image=mask,
    height=1632,
    width=1232,
    guidance_scale=30,
    num_inference_steps=8,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("flux-fill-dev.png")
```
To learn more, check out the documentation.
Note
SANA is a small model compared to other models like Flux: Sana-0.6B can be deployed on a 16GB laptop GPU and takes less than 1 second to generate a 1024×1024 image. We also support LoRA fine-tuning of SANA. Check out this section for more details.
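For reference, a minimal SANA inference sketch (the checkpoint id is an assumption):

```python
import torch
from diffusers import SanaPipeline

# Repository id is an assumption; pick the SANA checkpoint that fits your GPU.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = pipe(prompt="a cyberpunk cityscape at dusk, ultra detailed").images[0]
image.save("sana.png")
```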
Acknowledgements
- Shoutout to @lawrence-cj and @chenjy2003 for contributing SANA in this PR. SANA also features a Deep Compression Autoencoder, which was contributed by @lawrence-cj in this PR.
- Shoutout to @guiyrt for contributing SD3.5 IP Adapter in this PR.
New Quantization Backends
Please be aware of the following caveats:
- TorchAO quantized checkpoints cannot currently be serialized in `safetensors`. This may change in the future.
- GGUF currently only supports loading pre-quantized checkpoints into models in this release. Support for saving models with GGUF quantization will be added in the future.
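To illustrate the pre-quantized loading path mentioned above, here is a hedged sketch of loading a GGUF Flux transformer; the GGUF file URL is an assumption, and any compatible Flux GGUF checkpoint should work.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

# The GGUF file URL below is an assumption; substitute any compatible Flux GGUF checkpoint.
ckpt_url = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_url,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

image = pipe("a capybara wearing a top hat", num_inference_steps=30).images[0]
image.save("flux_gguf.png")
```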
New training scripts
This release features many new training scripts for the community to play with:
All commits
- post-release 0.31.0 by @sayakpaul in #9742
- fix bug in `require_accelerate_version_greater` by @faaany in #9746
- [Official callbacks] SDXL Controlnet CFG Cutoff by @asomoza in #9311
- [SD3-5 dreambooth lora] update model cards by @linoytsaban in #9749
- config attribute not foud error for FluxImagetoImage Pipeline for multi controlnet solved by @rshah240 in #9586
- Some minor updates to the nightly and push workflows by @sayakpaul in #9759
- [Docs] fix docstring typo in SD3 pipeline by @shenzhiy21 in #9765
- [bugfix] bugfix for npu free memory by @leisuzz in #9640
- [research_projects] add flux training script with quantization by @sayakpaul in #9754
- Add a doc for AWS Neuron in Diffusers by @JingyaHuang in #9766
- [refactor] enhance readability of flux related pipelines by @Luciennnnnnn in #9711
- Added Support of Xlabs controlnet to FluxControlNetInpaintPipeline by @SahilCarterr in #9770
- [research_projects] Update README.md to include a note about NF5 T5-xxl by @sayakpaul in #9775
- [Fix] train_dreambooth_lora_flux_advanced ValueError: unexpected save model: <class 'transformers.models.t5.modeling_t5.T5EncoderModel'> by @rootonchair in #9777
- [Fix] remove setting lr for T5 text encoder when using prodigy in flux dreambooth lora script by @biswaroop1547 in #9473
- [SD 3.5 Dreambooth LoRA] support configurable training block & layers by @linoytsaban in #9762
- [flux dreambooth lora training] make LoRA target modules configurable + small bug fix by @linoytsaban in #9646
- adds the pipeline for pixart alpha controlnet by @raulc0399 in #8857
- [core] Allegro T2V by @a-r-r-o-w in #9736
- Allegro VAE fix by @a-r-r-o-w in #9811
- [CI] add new runner for testing by @sayakpaul in #9699
- [training] fixes to the quantization training script and add AdEMAMix optimizer as an option by @sayakpaul in #9806
- [training] use the lr when using 8bit adam. by @sayakpaul in #9796
- [Tests] clean up and refactor gradient checkpointing tests by @sayakpaul in #9494
- [CI] add a big GPU marker to run memory-intensive tests separately on CI by @sayakpaul in #9691
- [LoRA] fix: lora loading when using with a device_mapped model. by @sayakpaul in #9449
- Revert "[LoRA] fix: lora loading when using with a device_mapped mode… by @yiyixuxu in #9823
- [Model Card] standardize advanced diffusion training sd15 lora by @chiral-carbon in #7613
- NPU Adaption for FLUX by @leisuzz in #9751
- Fixes EMAModel "from_pretrained" method by @SahilCarterr in #9779
- Update train_controlnet_flux.py,Fix size mismatch issue in validation by @ScilenceForest in #9679
- Handling mixed precision for dreambooth flux lora training by @icsl-Jeon in #9565
- Reduce Memory Cost in Flux Training by @leisuzz in #9829
- Add Diffusion Policy for Reinforcement Learning by @DorsaRoh in #9824
- [feat] add `load_lora_adapter()` for compatible models by @sayakpaul in #9712
- Refac training utils.py by @RogerSinghChugh in #9815
- [core] Mochi T2V by @a-r-r-o-w in #9769
- [Fix] Test of sd3 lora by @SahilCarterr in #9843
- Fix: Remove duplicated comma in distributed_inference.md by @vahidaskari in #9868
- Add new community pipeline for 'Adaptive Mask Inpainting', introduced in [ECCV2024] ComA by @jellyheadandrew in #9228
- Updated _encode_prompt_with_clip and encode_prompt in train_dreamboth_sd3 by @SahilCarterr in #9800
- [Core] introduce `controlnet` module by @sayakpaul in #8768
- [Flux] reduce explicit device transfers and typecasting in flux. by @sayakpaul in #9817
- Improve downloads of sharded variants by @DN6 in #9869
- [fix] Replaced shutil.copy with shutil.copyfile by @SahilCarterr in #9885
- Enabling gradient checkpointing in eval() mode by @MikeTkachuk in #9878
- [FIX] Fix TypeError in DreamBooth SDXL when use_dora is False by @SahilCarterr in #9879
- [Advanced LoRA v1.5] fix: gradient unscaling problem by @sayakpaul in #7018
- Revert "[Flux] reduce explicit device transfers and typecasting in flux." by @sayakpaul in #9896
- Feature IP Adapter Xformers Attention Processor by @elismasilva in #9881
- Notebooks for Community Scripts Examples by @ParagEkbote in #9905
- Fix Progress Bar Updates in SD 1.5 PAG Img2Img pipeline by @painebenjamin in #9925
- Update pipeline_flux_img2img.py by @example-git in #9928
- add de...
v0.31.0
v0.31.0: Stable Diffusion 3.5 Large, CogView3, Quantization, Training Scripts, and more
Stable Diffusion 3.5 Large
Stability AI’s latest text-to-image generation model is Stable Diffusion 3.5 Large. SD3.5 Large is the next iteration of Stable Diffusion 3. It comes with two checkpoints (both of which have 8B params):
- A regular one
- A timestep-distilled one enabling few-step inference
Make sure to fill out the form on the model page, and then run `huggingface-cli login` before running the code below.
```bash
# make sure to update diffusers
pip install -U diffusers
```

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt="a photo of a cat holding a sign that says hello world",
    negative_prompt="",
    num_inference_steps=40,
    height=1024,
    width=1024,
    guidance_scale=4.5,
).images[0]

image.save("sd3_hello_world.png")
```
Follow the documentation to know more.
Cogview3-plus
We added a new text-to-image model, Cogview3-plus, from the THUDM team! The model is DiT-based and supports image generation from 512 to 2048px. Thanks to @zRzRzRzRzRzRzR for contributing it!
```python
import torch
from diffusers import CogView3PlusPipeline

pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.float16).to("cuda")

# Enable it to reduce GPU memory usage
pipe.enable_model_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

prompt = "A vibrant cherry red sports car sits proudly under the gleaming sun, its polished exterior smooth and flawless, casting a mirror-like reflection. The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, and a set of black, high-gloss racing rims that contrast starkly with the red. A subtle hint of chrome embellishes the grille and exhaust, while the tinted windows suggest a luxurious and private interior. The scene conveys a sense of speed and elegance, the car appearing as if it's about to burst into a sprint along a coastal road, with the ocean's azure waves crashing in the background."

image = pipe(
    prompt=prompt,
    guidance_scale=7.0,
    num_images_per_prompt=1,
    num_inference_steps=50,
    width=1024,
    height=1024,
).images[0]

image.save("cogview3.png")
```
Refer to the documentation to know more.
Quantization
We have landed native quantization support in Diffusers, starting with `bitsandbytes` as its first quantization backend. With this, we hope to see large diffusion models becoming much more accessible to run on consumer hardware.
The example below shows how to run Flux.1 Dev with the NF4 data-type. Make sure you install the libraries:
```bash
pip install -Uq git+https://github.com/huggingface/transformers@main
pip install -Uq bitsandbytes
pip install -Uq diffusers
```
```python
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

ckpt_id = "black-forest-labs/FLUX.1-dev"
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model_nf4 = FluxTransformer2DModel.from_pretrained(
    ckpt_id,
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)
```
Then, we use `model_nf4` to instantiate the `FluxPipeline`:
```python
from diffusers import FluxPipeline

pipeline = FluxPipeline.from_pretrained(
    ckpt_id, transformer=model_nf4, torch_dtype=torch.bfloat16
)
pipeline.enable_model_cpu_offload()

prompt = "A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus, basking in a river of melted butter amidst a breakfast-themed landscape. It features the distinctive, bulky body shape of a hippo. However, instead of the usual grey skin, the creature's body resembles a golden-brown, crispy waffle fresh off the griddle. The skin is textured with the familiar grid pattern of a waffle, each square filled with a glistening sheen of syrup. The environment combines the natural habitat of a hippo with elements of a breakfast table setting, a river of warm, melted butter, with oversized utensils or plates peeking out from the lush, pancake-like foliage in the background, a towering pepper mill standing in for a tree. As the sun rises in this fantastical world, it casts a warm, buttery glow over the scene. The creature, content in its butter river, lets out a yawn. Nearby, a flock of birds take flight"

image = pipeline(
    prompt=prompt,
    negative_prompt="",
    num_inference_steps=50,
    guidance_scale=4.5,
    max_sequence_length=512,
).images[0]
image.save("whimsical.png")
```
Follow the documentation here to know more. Additionally, check out this Colab Notebook that runs Flux.1 Dev in an end-to-end manner with NF4 quantization.
Training scripts
We have a fresh bucket of training scripts with this release:
Video model fine-tuning can be quite expensive. So, we have worked on a repository, cogvideox-factory, which provides memory-optimized scripts to fine-tune the Cog family of models.
Misc
- We now support the loading of different kinds of Flux LoRAs, including Kohya, TheLastBen, and Xlabs.
- Loading of Xlabs Flux ControlNets is also now supported. Thanks to @Anghellia for contributing it!
All commits
- Feature flux controlnet img2img and inpaint pipeline by @ighoshsubho in #9408
- Remove CogVideoX mentions from single file docs; Test updates by @a-r-r-o-w in #9444
- set max_shard_size to None for pipeline save_pretrained by @a-r-r-o-w in #9447
- adapt masked im2im pipeline for SDXL by @noskill in #7790
- [Flux] add lora integration tests. by @sayakpaul in #9353
- [training] CogVideoX Lora by @a-r-r-o-w in #9302
- Several fixes to Flux ControlNet pipelines by @vladmandic in #9472
- [refactor] LoRA tests by @a-r-r-o-w in #9481
- [CI] fix nightly model tests by @sayakpaul in #9483
- [Cog] some minor fixes and nits by @sayakpaul in #9466
- [Tests] Reduce the model size in the lumina test by @saqlain2204 in #8985
- Fix the bug of sd3 controlnet training when using gradient checkpointing. by @pibbo88 in #9498
- [Schedulers] Add exponential sigmas / exponential noise schedule by @hlky in #9499
- Allow DDPMPipeline half precision by @sbinnee in #9222
- Add Noise Schedule/Schedule Type to Schedulers Overview documentation by @hlky in #9504
- fix bugs for sd3 controlnet training by @xduzhangjiayu in #9489
- [Doc] Fix path and and also import imageio by @LukeLIN-web in #9506
- [CI] allow faster downloads from the Hub in CI. by @sayakpaul in #9478
- a few fix for SingleFile tests by @yiyixuxu in #9522
- Add exponential sigmas to other schedulers and update docs by @hlky in #9518
- [Community Pipeline] Batched implementation of Flux with CFG by @sayakpaul in #9513
- Update community_projects.md by @lee101 in #9266
- [docs] Model sharding by @stevhliu in #9521
- update get_parameter_dtype by @yiyixuxu in #9526
- [Doc] Improved level of clarity for latents_to_rgb. by @LagPixelLOL in #9529
- [Schedulers] Add beta sigmas / beta noise schedule by @hlky in #9509
- flux controlnet fix (control_modes batch & others) by @yiyixuxu in #9507
- [Tests] Fix ChatGLMTokenizer by @asomoza in #9536
- [bug] Precedence of operations in VAE should be slicing -> tiling by @a-r-r-o-w in #9342
- [LoRA] make set_adapters() method more robust. by @sayakpaul in #9535
- [examples] add train flux-controlnet scripts in example. by @PromeAIpro in #9324
- [Tests] [LoRA] clean up the serialization stuff. by @sayakpaul in #9512
- [Core] fix variant-identification. by @sayakpaul in #9253
- [refactor] remove conv_cache from CogVideoX VAE by @a-r-r-o-w in #9524
- [train_instruct_pix2pix.py] Fix the LR schedulers when `num_train_epochs` is passed in a distributed training env by @AnandK27 in #9316
- [chore] fix: retain memory utility. by @sayakpaul in #9543
- [LoRA] support Kohya Flux LoRAs that have text encoders as well by @sayakpaul in #9542
- Add beta sigmas to other schedulers and update docs by @hlky in #9538
- Add PAG support to StableDiffusionControlNetPAGInpaintPipeline by @juancopi81 in #8875
- Support bfloat16 for Upsample2D by @darhsu in #9480
- fix cogvideox autoencoder decode by @Xiang-cd in #9569
- [sd3] make sure height and size are divisible by `16` by @yiyixuxu in #9573
- fix xlabs FLUX lora conversion typo by @Clement-Lelievre in #9581
- [Chore] add a note on the versions in Flux LoRA integration tests by @sayakpaul in #9598
- fix vae dtype when accelerate config using --mixed_precision="fp16" by @xduzhangjiayu in #9601
- refac: docstrings in import_utils.py by @yijun-lee in #9583
- Fix for use_safetensors parameters, allow use of parameter on loading submodels by @elismasilva in #9576)
- Update distributed_inference.md to include `transformer.device_map` by @sayakpaul in #9553
- fix: CogVideox train dataset _preprocess_data crop video by @glide-the in #9574
- [LoRA] Handle DoRA better by @sayakpau...
v0.30.3: CogVideoX Image-to-Video and Video-to-Video
This patch release adds Diffusers support for the upcoming CogVideoX-5B-I2V release (an image-to-video generation model)! The model weights will be available by the end of the week on the HF Hub at `THUDM/CogVideoX-5b-I2V` (link). Stay tuned for the release!
This release features two new pipelines:
- CogVideoXImageToVideoPipeline
- CogVideoXVideoToVideoPipeline
Additionally, we now have support for tiled encoding in the CogVideoX VAE. This can be enabled by calling the `vae.enable_tiling()` method, and it is used in the new video-to-video pipeline to encode sample videos to latents in a memory-efficient manner.
CogVideoXImageToVideoPipeline
The code below demonstrates how to use the new image-to-video pipeline:
```python
import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Optionally, enable memory optimizations.
# If enabling CPU offloading, remember to remove `pipe.to("cuda")` above
pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()

prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
video = pipe(image, prompt, use_dynamic_cfg=True)
export_to_video(video.frames[0], "output.mp4", fps=8)
```
CogVideoXImageToVideoExample.mp4
CogVideoXVideoToVideoPipeline
The code below demonstrates how to use the new video-to-video pipeline:
```python
import torch
from diffusers import CogVideoXDPMScheduler, CogVideoXVideoToVideoPipeline
from diffusers.utils import export_to_video, load_video

# Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
pipe = CogVideoXVideoToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-trial", torch_dtype=torch.bfloat16)
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

input_video = load_video(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
)
prompt = (
    "An astronaut stands triumphantly at the peak of a towering mountain. Panorama of rugged peaks and "
    "valleys. Very futuristic vibe and animated aesthetic. Highlights of purple and golden colors in "
    "the scene. The sky is looks like an animated/cartoonish dream of galaxies, nebulae, stars, planets, "
    "moons, but the remainder of the scene is mostly realistic."
)

video = pipe(
    video=input_video, prompt=prompt, strength=0.8, guidance_scale=6, num_inference_steps=50
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```
CogVideoXVideoToVideoExample.mp4
Shoutout to @tin2tin for the awesome demonstration!
Refer to our documentation to learn more about it.
All commits
- [core] Support VideoToVideo with CogVideoX by @a-r-r-o-w in #9333
- [core] CogVideoX memory optimizations in VAE encode by @a-r-r-o-w in #9340
- [CI] Quick fix for Cog Video Test by @DN6 in #9373
- [refactor] move positional embeddings to patch embed layer for CogVideoX by @a-r-r-o-w in #9263
- CogVideoX-5b-I2V support by @zRzRzRzRzRzRzR in #9418
v0.30.2: Update from single file default repository
All commits
- update runway repo for single_file by @yiyixuxu in #9323
- Fix Flux CLIP prompt embeds repeat for num_images_per_prompt > 1 by @DN6 in #9280
- [IP Adapter] Fix cache_dir and local_files_only for image encoder by @asomoza in #9272
V0.30.1: CogVideoX-5B & Bug fixes
CogVideoX-5B
This patch release adds diffusers support for the upcoming CogVideoX-5B release! The model weights will be available next week on the Hugging Face Hub at `THUDM/CogVideoX-5b`. Stay tuned for the release!
Additionally, we have implemented a VAE tiling feature, which reduces the memory requirement for CogVideoX models. With this update, the total memory requirement is now 12GB for CogVideoX-2B and 21GB for CogVideoX-5B (with CPU offloading). To enable this feature, simply call `enable_tiling()` on the VAE.
The code below shows how to generate a video with CogVideoX-5B
```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

prompt = "Tracking shot,late afternoon light casting long shadows,a cyclist in athletic gear pedaling down a scenic mountain road,winding path with trees and a lake in the background,invigorating and adventurous atmosphere."

pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
)

pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()

video = pipe(
    prompt=prompt,
    num_videos_per_prompt=1,
    num_inference_steps=50,
    num_frames=49,
    guidance_scale=6,
).frames[0]

export_to_video(video, "output.mp4", fps=8)
```
000000.mp4
Refer to our documentation to learn more about it.
All commits
- Update Video Loading/Export to use `imageio` by @DN6 in #9094
- [refactor] CogVideoX followups + tiled decoding support by @a-r-r-o-w in #9150
- Add Learned PE selection for Auraflow by @cloneofsimo in #9182
- [Single File] Fix configuring scheduler via legacy kwargs by @DN6 in #9229
- [Flux LoRA] support parsing alpha from a flux lora state dict. by @sayakpaul in #9236
- [tests] fix broken xformers tests by @a-r-r-o-w in #9206
- Cogvideox-5B Model adapter change by @zRzRzRzRzRzRzR in #9203
- [Single File] Support loading Comfy UI Flux checkpoints by @DN6 in #9243
v0.30.0: New Pipelines (Flux, Stable Audio, Kolors, CogVideoX, Latte, and more), New Methods (FreeNoise, SparseCtrl), and New Refactors
New pipelines
Image taken from Lumina's GitHub.
This release features many new pipelines. Below, we provide a list:
Audio pipelines 🎼
Video pipelines 📹
- Latte (thanks to @maxin-cn for the contribution through #8404)
- CogVideoX (thanks to @zRzRzRzRzRzRzR for the contribution through #9082)
Image pipelines 🎇
Be sure to check out the respective docs to know more about these pipelines. Some additional pointers are below for curious minds:
- Lumina introduces a new DiT architecture that is multilingual in nature.
- Kolors is inspired by SDXL and is also multilingual in nature.
- Flux introduces the largest (more than 12B parameters!) open-sourced DiT variant available to date. For efficient DreamBooth + LoRA training, we recommend @bghira’s guide here.
- We have worked on a guide that shows how to quantize these large pipelines for memory efficiency with `optimum.quanto`. Check it out here.
- CogVideoX introduces a novel and truly 3D VAE into Diffusers.
Perturbed Attention Guidance (PAG)
Without PAG vs. with PAG (comparison images)
We already had community pipelines for PAG, but given its usefulness, we decided to make it a first-class citizen of the library. We have a central usage guide for PAG here, which should be the entry point for a user interested in understanding and using PAG for their use cases. We currently support the following pipelines with PAG:
- `StableDiffusionPAGPipeline`
- `StableDiffusion3PAGPipeline`
- `StableDiffusionControlNetPAGPipeline`
- `StableDiffusionXLPAGPipeline`
- `StableDiffusionXLPAGImg2ImgPipeline`
- `StableDiffusionXLPAGInpaintPipeline`
- `StableDiffusionXLControlNetPAGPipeline`
- `PixArtSigmaPAGPipeline`
- `HunyuanDiTPAGPipeline`
- `AnimateDiffPAGPipeline`
- `KolorsPAGPipeline`
If you’re interested in helping us extend our PAG support for other pipelines, please check out this thread.
Special thanks to Ahn Donghoon (@sunovivid), the author of PAG, for helping us with the integration and adding PAG support to SD3.
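As a quick reference, here is a minimal sketch of enabling PAG through `AutoPipelineForText2Image`; the SDXL checkpoint, `pag_applied_layers` value, and `pag_scale` below are illustrative choices.

```python
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    enable_pag=True,
    pag_applied_layers=["mid"],  # which attention layers receive perturbed attention
    torch_dtype=torch.float16,
)
pipe.to("cuda")

image = pipe(
    prompt="an insect robot preparing a delicious meal, anime style",
    num_inference_steps=25,
    guidance_scale=7.0,
    pag_scale=3.0,  # strength of perturbed attention guidance
).images[0]
image.save("sdxl_pag.png")
```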
AnimateDiff with SparseCtrl
SparseCtrl introduces controllability into text-to-video diffusion models by leveraging signals such as line/edge sketches, depth maps, and RGB images. It incorporates an additional condition encoder, inspired by ControlNet, to process these signals within the AnimateDiff framework. It can be applied to a diverse set of applications such as interpolation or video prediction (filling in the gaps between a sequence of images for animation), personalized image animation, sketch-to-video, depth-to-video, and more. It was introduced in SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion Models.
There are two SparseCtrl-specific checkpoints and a Motion LoRA made available by the authors namely:
Scribble Interpolation Example:
```python
import torch
from diffusers import AnimateDiffSparseControlNetPipeline, AutoencoderKL, MotionAdapter, SparseControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image

device = "cuda"

motion_adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-3", torch_dtype=torch.float16).to(device)
controlnet = SparseControlNetModel.from_pretrained("guoyww/animatediff-sparsectrl-scribble", torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to(device)
pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE",
    motion_adapter=motion_adapter,
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
).to(device)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, beta_schedule="linear", algorithm_type="dpmsolver++", use_karras_sigmas=True
)
pipe.load_lora_weights("guoyww/animatediff-motion-lora-v1-5-3", adapter_name="motion_lora")
pipe.fuse_lora(lora_scale=1.0)

prompt = "an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality"
negative_prompt = "low quality, worst quality, letterboxed"

image_files = [
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-1.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-2.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-3.png",
]
condition_frame_indices = [0, 8, 15]
conditioning_frames = [load_image(img_file) for img_file in image_files]

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=25,
    conditioning_frames=conditioning_frames,
    controlnet_conditioning_scale=1.0,
    controlnet_frame_indices=condition_frame_indices,
    generator=torch.Generator().manual_seed(1337),
).frames[0]
export_to_gif(video, "output.gif")
```
📜 Check out the docs here.
FreeNoise for AnimateDiff
FreeNoise is a training-free method that allows extending the generative capabilities of pretrained video diffusion models beyond their existing context/frame limits.
Instead of initializing noises for all frames, FreeNoise reschedules a sequence of noises for long-range correlation and performs temporal attention over them using a window-based function. We have added FreeNoise to the AnimateDiff family of models in Diffusers, allowing them to generate videos beyond their default 32 frame limit.
```python
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerAncestralDiscreteScheduler
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16)
pipe.scheduler = EulerAncestralDiscreteScheduler(
    beta_schedule="linear",
    beta_start=0.00085,
    beta_end=0.012,
)

pipe.enable_free_noise()
pipe.vae.enable_slicing()
pipe.enable_model_cpu_offload()

frames = pipe(
    "An astronaut riding a horse on Mars.",
    num_frames=64,
    num_inference_steps=20,
    guidance_scale=7.0,
    decode_chunk_size=2,
).frames[0]

export_to_gif(frames, "freenoise-64.gif")
```
LoRA refactor
We have significantly refactored the loader classes associated with LoRA. Going forward, this will help in adding LoRA support for new pipelines and models. We now have a `LoraBaseMixin` class, which is subclassed by the different pipeline-level LoRA loading classes such as `StableDiffusionXLLoraLoaderMixin`. This document provides an overview of the available classes.
Additionally, we have increased the coverage of methods within the `PeftAdapterMixin` class. This refactoring allows all the supported models to share common LoRA functionalities such as `set_adapters()`, `add_adapter()`, and so on.
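A minimal sketch of the shared adapter API is shown below; the LoRA repositories and adapter names are placeholders, not real checkpoints.

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load two LoRAs under distinct adapter names (placeholder repositories).
pipe.load_lora_weights("your-org/sdxl-style-lora", adapter_name="style")
pipe.load_lora_weights("your-org/sdxl-detail-lora", adapter_name="detail")

# Activate both adapters with per-adapter weights via the shared mixin API.
pipe.set_adapters(["style", "detail"], adapter_weights=[0.8, 0.5])

image = pipe("a castle in the clouds, intricate details").images[0]
image.save("lora_blend.png")
```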
To learn more details, please follow this PR. If you see any LoRA-related iss...