Releases · huggingface/diffusers

v0.33.1: fix ftfy import

Diffusers 0.33.0: New Image and Video Models, Memory Optimizations, Caching Methods, Remote VAEs, New Training Scripts, and more

New Pipelines for Video Generation

Wan 2.1

Wan2.1 is a comprehensive, open suite of video foundation models that pushes the boundaries of video generation. The release includes four model variants and three pipelines: Text-to-Video, Image-to-Video, and Video-to-Video.

Check out the docs here to learn more.
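For a quick orientation, here is a minimal text-to-video sketch; the checkpoint id and call parameters below are illustrative assumptions rather than values taken from this release note:

import torch
from diffusers import WanPipeline
from diffusers.utils import export_to_video

# Illustrative checkpoint id; pick the Wan 2.1 variant you actually want to run.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat walks on the grass, realistic style."
frames = pipe(prompt=prompt, num_frames=81, guidance_scale=5.0).frames[0]
export_to_video(frames, "wan_t2v.mp4", fps=15)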

LTX Video 0.9.5

LTX Video 0.9.5 is the updated version of the super-fast LTX Video model series. The latest model introduces additional conditioning options, such as keyframe-based animation and video extension (both forward and backward).

To support these additional conditioning inputs, we’ve introduced the LTXConditionPipeline and LTXVideoCondition object.

To learn more about the usage, check out the docs here.
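As a rough sketch of the new conditioning flow, the snippet below anchors part of an existing clip at frame 0; treat the checkpoint id, the LTXVideoCondition import path, and the call signature as assumptions and defer to the linked docs for the canonical example:

import torch
from diffusers import LTXConditionPipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import export_to_video, load_video

pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Use the first frames of an existing clip as a keyframe-style condition anchored at frame 0 (path is illustrative).
clip = load_video("input.mp4")
condition = LTXVideoCondition(video=clip[:9], frame_index=0)

prompt = "A calm mountain lake at sunrise, soft mist drifting over the water"
frames = pipe(
    conditions=[condition],
    prompt=prompt,
    width=768,
    height=512,
    num_frames=121,
    num_inference_steps=40,
).frames[0]
export_to_video(frames, "ltx_conditioned.mp4", fps=24)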

Hunyuan Image to Video

Hunyuan utilizes a pre-trained Multimodal Large Language Model (MLLM) with a Decoder-Only architecture as the text encoder. The input image is processed by the MLLM to generate semantic image tokens. These tokens are then concatenated with the video latent tokens, enabling comprehensive full-attention computation across the combined data and seamlessly integrating information from both the image and its associated caption.

To learn more, check out the docs here.
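A minimal image-to-video sketch; the community checkpoint id and parameters are assumptions, so check the docs for the canonical example:

import torch
from diffusers import HunyuanVideoImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Checkpoint id is an assumption; use the official HunyuanVideo I2V weights from the Hub.
pipe = HunyuanVideoImageToVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-I2V", torch_dtype=torch.bfloat16
)
pipe.vae.enable_tiling()
pipe.to("cuda")

image = load_image("input.png")
prompt = "A cat walks slowly toward the camera in a sunlit garden."
frames = pipe(image=image, prompt=prompt).frames[0]
export_to_video(frames, "hunyuan_i2v.mp4", fps=15)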

Others

New Pipelines for Image Generation

Sana-Sprint

SANA-Sprint is an efficient diffusion model for ultra-fast text-to-image generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4, rivaling the quality of models like Flux.

Shoutout to @lawrence-cj for their help and guidance on this PR.

Check out the pipeline docs of SANA-Sprint to learn more.
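A minimal few-step inference sketch, assuming an illustrative checkpoint id:

import torch
from diffusers import SanaSprintPipeline

# Checkpoint id is an assumption; see the SANA-Sprint model cards on the Hub.
pipe = SanaSprintPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# SANA-Sprint is distilled for very few steps (1-4).
image = pipe(prompt="a tiny astronaut hatching from an egg on the moon", num_inference_steps=2).images[0]
image.save("sana_sprint.png")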

Lumina2

Lumina-Image-2.0 is a 2B parameter flow-based diffusion transformer for text-to-image generation released under the Apache 2.0 license.

Check out the docs to learn more. Thanks to @zhuole1025 for contributing this through this PR.

One can also LoRA fine-tune Lumina2, taking advantage of its Apache 2.0 licensing. Check out the guide for more details.

Omnigen

OmniGen is a unified image generation model that can handle multiple tasks including text-to-image, image editing, subject-driven generation, and various computer vision tasks within a single framework. The model consists of a VAE, and a single transformer based on Phi-3 that handles text and image encoding as well as the diffusion process.

Check out the docs to learn more about OmniGen. Thanks to @staoxiao for contributing OmniGen in this PR.
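A minimal text-to-image sketch with OmniGenPipeline; the checkpoint id and guidance value are assumptions, and the docs cover editing and subject-driven usage:

import torch
from diffusers import OmniGenPipeline

# Checkpoint id is an assumption; OmniGen also accepts input images for editing and subject-driven tasks.
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="A realistic photo of a red fox standing in a snowy forest",
    height=1024,
    width=1024,
    guidance_scale=2.5,
).images[0]
image.save("omnigen_t2i.png")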

Others

New Memory Optimizations

Layerwise Casting

PyTorch supports torch.float8_e4m3fn and torch.float8_e5m2 as weight storage dtypes, but they can’t be used for computation on many devices due to unimplemented kernel support.

However, you can still use these dtypes to store model weights in FP8 precision and upcast them to a widely supported dtype such as torch.float16 or torch.bfloat16 on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting. This can potentially cut down the VRAM requirements of a model by 50%.

Code

import torch
from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers.utils import export_to_video

model_id = "THUDM/CogVideoX-5b"

# Load the model in bfloat16 and enable layerwise casting
transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)

# Load the pipeline
pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)

Group Offloading

Group offloading is the middle ground between sequential and model offloading. It works by offloading groups of internal layers (either torch.nn.ModuleList or torch.nn.Sequential), which uses less memory than model-level offloading. It is also faster than sequential-level offloading because the number of device synchronizations is reduced.

On CUDA devices, we also have the option to enable using layer prefetching with CUDA Streams. The next layer to be executed is loaded onto the accelerator device while the current layer is being executed which makes inference substantially faster while still keeping VRAM requirements very low. With this, we introduce the idea of overlapping computation with data transfer.

One thing to note is that using CUDA streams can cause a considerable spike in CPU RAM usage. Please ensure that the available CPU RAM is 2 times the size of the model if you choose to set use_stream=True. You can reduce CPU RAM usage by setting low_cpu_mem_usage=True. This should limit the CPU RAM used to be roughly the same as the size of the model, but will introduce slight latency in the inference process.

You can also use record_stream=True when using use_stream=True to obtain more speedups at the expense of slightly increased memory usage.

Code

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Load the pipeline
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# We can utilize the enable_group_offload method for Diffusers model implementations
pipe.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
)

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]

# This utilized about 14.79 GB. It can be further reduced by applying tiling and leaf_level offloading
# throughout the pipeline.
print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
export_to_video(video, "output.mp4", fps=8)

Group offloading can also be applied to non-Diffusers models such as text encoders from the transformers library.

Code

import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video

# Load the pipeline
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# For any other model implementations, the apply_group_offloading function can be used
apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)

Remote Components

Remote components are an experimental feature designed to offload memory-intensive steps of t...

Read more

v0.32.2

Fixes for Flux Single File loading, LoRA loading for 4bit BnB Flux, Hunyuan Video

This patch release

All commits

v0.32.1

TorchAO Quantizer fixes

This patch release fixes a few bugs related to the TorchAO Quantizer introduced in v0.32.0.

Refer to our documentation to learn more about how to use different quantization backends.

All commits

Diffusers 0.32.0: New video pipelines, new image pipelines, new quantization backends, new training scripts, and more

(Sample output video: hunyuan-output.mp4)

This release took a while, but it has many exciting updates. It contains several new pipelines for image and video generation, new quantization backends, and more.

Going forward, to provide more transparency to the community about ongoing developments and releases in Diffusers, we will be making use of a roadmap tracker.

New Video Generation Pipelines 📹

Open video generation models are on the rise, and we’re pleased to provide comprehensive integration support for all of them. The following video pipelines are bundled in this release:

Check out this section to learn more about the fine-tuning options available for these new video models.

New Image Generation Pipelines

Important Note about the new Flux Models

We can combine the regular Flux.1 Dev LoRAs with Flux Control LoRAs, Flux Control, and Flux Fill. For example, you can enable few-steps inference with Flux Fill using:

from diffusers import FluxFillPipeline
from diffusers.utils import load_image
import torch

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")

adapter_id = "alimama-creative/FLUX.1-Turbo-Alpha"
pipe.load_lora_weights(adapter_id)

image = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup.png")
mask = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup_mask.png")

image = pipe(
    prompt="a white paper cup",
    image=image,
    mask_image=mask,
    height=1632,
    width=1232,
    guidance_scale=30,
    num_inference_steps=8,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("flux-fill-dev.png")

To learn more, check out the documentation.

Note

SANA is a small model compared to other models like Flux; Sana-0.6B can be deployed on a 16GB laptop GPU and takes less than 1 second to generate a 1024×1024 resolution image. We support LoRA fine-tuning of SANA. Check out this section for more details.

Acknowledgements

New Quantization Backends

Please be aware of the following caveats:

New training scripts

This release features many new training scripts for the community to play with:

All commits

Read more

v0.31.0

v0.31.0: Stable Diffusion 3.5 Large, CogView3, Quantization, Training Scripts, and more

Stable Diffusion 3.5 Large

Stability AI’s latest text-to-image generation model is Stable Diffusion 3.5 Large. SD3.5 Large is the next iteration of Stable Diffusion 3. It comes with two checkpoints (both of which have 8B params):

Make sure to fill out the form on the model page, and then run huggingface-cli login before running the code below.

# make sure to update diffusers
pip install -U diffusers

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt="a photo of a cat holding a sign that says hello world",
    negative_prompt="",
    num_inference_steps=40,
    height=1024,
    width=1024,
    guidance_scale=4.5,
).images[0]

image.save("sd3_hello_world.png")

Follow the documentation to know more.

CogView3-Plus

We added a new text-to-image model, CogView3-Plus, from the THUDM team! The model is DiT-based and supports image generation from 512 to 2048px. Thanks to @zRzRzRzRzRzRzR for contributing it!

from diffusers import CogView3PlusPipeline
import torch

pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.float16).to("cuda")

# Enable it to reduce GPU memory usage
pipe.enable_model_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

prompt = "A vibrant cherry red sports car sits proudly under the gleaming sun, its polished exterior smooth and flawless, casting a mirror-like reflection. The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, and a set of black, high-gloss racing rims that contrast starkly with the red. A subtle hint of chrome embellishes the grille and exhaust, while the tinted windows suggest a luxurious and private interior. The scene conveys a sense of speed and elegance, the car appearing as if it's about to burst into a sprint along a coastal road, with the ocean's azure waves crashing in the background."

image = pipe(
    prompt=prompt,
    guidance_scale=7.0,
    num_images_per_prompt=1,
    num_inference_steps=50,
    width=1024,
    height=1024,
).images[0]

image.save("cogview3.png")

Refer to the documentation to know more.

Quantization

We have landed native quantization support in Diffusers, starting with bitsandbytes as its first quantization backend. With this, we hope to see large diffusion models becoming much more accessible to run on consumer hardware.

The example below shows how to run Flux.1 Dev with the NF4 data-type. Make sure you install the libraries:

pip install -Uq git+https://github.com/huggingface/transformers@main
pip install -Uq bitsandbytes
pip install -Uq diffusers

from diffusers import BitsAndBytesConfig, FluxTransformer2DModel
import torch

ckpt_id = "black-forest-labs/FLUX.1-dev"
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
model_nf4 = FluxTransformer2DModel.from_pretrained(
    ckpt_id,
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16
)

Then, we use model_nf4 to instantiate the FluxPipeline:

from diffusers import FluxPipeline

pipeline = FluxPipeline.from_pretrained(
    ckpt_id, transformer=model_nf4, torch_dtype=torch.bfloat16
)
pipeline.enable_model_cpu_offload()

prompt = "A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus, basking in a river of melted butter amidst a breakfast-themed landscape. It features the distinctive, bulky body shape of a hippo. However, instead of the usual grey skin, the creature's body resembles a golden-brown, crispy waffle fresh off the griddle. The skin is textured with the familiar grid pattern of a waffle, each square filled with a glistening sheen of syrup. The environment combines the natural habitat of a hippo with elements of a breakfast table setting, a river of warm, melted butter, with oversized utensils or plates peeking out from the lush, pancake-like foliage in the background, a towering pepper mill standing in for a tree. As the sun rises in this fantastical world, it casts a warm, buttery glow over the scene. The creature, content in its butter river, lets out a yawn. Nearby, a flock of birds take flight"

image = pipeline(
    prompt=prompt,
    num_inference_steps=50,
    guidance_scale=4.5,
    max_sequence_length=512,
).images[0]
image.save("whimsical.png")

Follow the documentation here to know more. Additionally, check out this Colab Notebook that runs Flux.1 Dev in an end-to-end manner with NF4 quantization.

Training scripts

We have a fresh bucket of training scripts with this release:

Video model fine-tuning can be quite expensive. So, we have worked on a repository, cogvideox-factory, which provides memory-optimized scripts to fine-tune the Cog family of models.

Misc

All commits

Read more

v0.30.3: CogVideoX Image-to-Video and Video-to-Video

This patch release adds Diffusers support for the upcoming CogVideoX-5B-I2V release (an Image-to-Video generation model)! The model weights will be available by the end of the week on the HF Hub at THUDM/CogVideoX-5b-I2V. Stay tuned for the release!

This release features two new pipelines:

Additionally, we now have support for tiled encoding in the CogVideoX VAE. This can be enabled by calling the vae.enable_tiling() method, and it is used in the new Video-to-Video pipeline to encode sample videos to latents in a memory-efficient manner.

CogVideoXImageToVideoPipeline

The code below demonstrates how to use the new image-to-video pipeline:

import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Optionally, enable memory optimizations.
# If enabling CPU offloading, remember to remove pipe.to("cuda") above
pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()

prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
video = pipe(image, prompt, use_dynamic_cfg=True)
export_to_video(video.frames[0], "output.mp4", fps=8)

(Sample output video: CogVideoXImageToVideoExample.mp4)

CogVideoXVideoToVideoPipeline

The code below demonstrates how to use the new video-to-video pipeline:

import torch
from diffusers import CogVideoXDPMScheduler, CogVideoXVideoToVideoPipeline
from diffusers.utils import export_to_video, load_video

# Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
pipe = CogVideoXVideoToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-trial", torch_dtype=torch.bfloat16)
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

input_video = load_video(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
)
prompt = (
    "An astronaut stands triumphantly at the peak of a towering mountain. Panorama of rugged peaks and "
    "valleys. Very futuristic vibe and animated aesthetic. Highlights of purple and golden colors in "
    "the scene. The sky is looks like an animated/cartoonish dream of galaxies, nebulae, stars, planets, "
    "moons, but the remainder of the scene is mostly realistic."
)

video = pipe(
    video=input_video, prompt=prompt, strength=0.8, guidance_scale=6, num_inference_steps=50
).frames[0]
export_to_video(video, "output.mp4", fps=8)

(Sample output video: CogVideoXVideoToVideoExample.mp4)

Shoutout to @tin2tin for the awesome demonstration!

Refer to our documentation to learn more about it.

All commits

v0.30.2: Update from single file default repository

All commits

v0.30.1: CogVideoX-5B & Bug fixes

CogVideoX-5B

This patch release adds Diffusers support for the upcoming CogVideoX-5B release! The model weights will be available next week on the Hugging Face Hub at THUDM/CogVideoX-5b. Stay tuned for the release!

Additionally, we have implemented a VAE tiling feature, which reduces the memory requirement for CogVideoX models. With this update, the total memory requirement is now 12GB for CogVideoX-2B and 21GB for CogVideoX-5B (with CPU offloading). To enable this feature, simply call enable_tiling() on the VAE.

The code below shows how to generate a video with CogVideoX-5B

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

prompt = "Tracking shot,late afternoon light casting long shadows,a cyclist in athletic gear pedaling down a scenic mountain road,winding path with trees and a lake in the background,invigorating and adventurous atmosphere."

pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
)

pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()

video = pipe(
    prompt=prompt,
    num_videos_per_prompt=1,
    num_inference_steps=50,
    num_frames=49,
    guidance_scale=6,
).frames[0]

export_to_video(video, "output.mp4", fps=8)


Refer to our documentation to learn more about it.

All commits

v0.30.0: New Pipelines (Flux, Stable Audio, Kolors, CogVideoX, Latte, and more), New Methods (FreeNoise, SparseCtrl), and New Refactors

New pipelines


Image taken from Lumina's GitHub.

This release features many new pipelines. Below, we provide a list:

Audio pipelines 🎼

Video pipelines 📹

Image pipelines 🎇

Be sure to check out the respective docs to know more about these pipelines. Some additional pointers are below for curious minds:

Perturbed Attention Guidance (PAG)

(Comparison images: without PAG vs. with PAG)

We already had community pipelines for PAG, but given its usefulness, we decided to make it a first-class citizen of the library. We have a central usage guide for PAG here, which should be the entry point for a user interested in understanding and using PAG for their use cases. We currently support the following pipelines with PAG:

If you’re interested in helping us extend our PAG support for other pipelines, please check out this thread.
Special thanks to Ahn Donghoon (@sunovivid), the author of PAG, for helping us with the integration and adding PAG support to SD3.
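As an illustration of the first-class API, the sketch below enables PAG on SDXL through AutoPipelineForText2Image; the model id, applied layers, and scales are illustrative choices:

import torch
from diffusers import AutoPipelineForText2Image

# Enable PAG when loading a supported pipeline (SDXL here as an example).
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    enable_pag=True,
    pag_applied_layers=["mid"],
    torch_dtype=torch.float16,
).to("cuda")

image = pipe(
    prompt="an insect robot preparing a delicious meal, anime style",
    num_inference_steps=25,
    guidance_scale=7.0,
    pag_scale=3.0,
).images[0]
image.save("pag_sdxl.png")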

AnimateDiff with SparseCtrl

SparseCtrl introduces controllability into text-to-video diffusion models by leveraging signals such as line/edge sketches, depth maps, and RGB images. Inspired by ControlNet, it incorporates an additional condition encoder to process these signals within the AnimateDiff framework. It can be applied to a diverse set of applications such as interpolation or video prediction (filling in the gaps between a sequence of images for animation), personalized image animation, sketch-to-video, depth-to-video, and more. It was introduced in SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion Models.

There are two SparseCtrl-specific checkpoints and a Motion LoRA made available by the authors, namely:

Scribble Interpolation Example:

import torch

from diffusers import AnimateDiffSparseControlNetPipeline, AutoencoderKL, MotionAdapter, SparseControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image

device = "cuda"

motion_adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-3", torch_dtype=torch.float16).to(device)
controlnet = SparseControlNetModel.from_pretrained("guoyww/animatediff-sparsectrl-scribble", torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to(device)
pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE",
    motion_adapter=motion_adapter,
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
).to(device)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, beta_schedule="linear", algorithm_type="dpmsolver++", use_karras_sigmas=True)
pipe.load_lora_weights("guoyww/animatediff-motion-lora-v1-5-3", adapter_name="motion_lora")
pipe.fuse_lora(lora_scale=1.0)

prompt = "an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality"
negative_prompt = "low quality, worst quality, letterboxed"

image_files = [
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-1.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-2.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-3.png"
]
condition_frame_indices = [0, 8, 15]
conditioning_frames = [load_image(img_file) for img_file in image_files]

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=25,
    conditioning_frames=conditioning_frames,
    controlnet_conditioning_scale=1.0,
    controlnet_frame_indices=condition_frame_indices,
    generator=torch.Generator().manual_seed(1337),
).frames[0]
export_to_gif(video, "output.gif")

📜 Check out the docs here.

FreeNoise for AnimateDiff

FreeNoise is a training-free method that allows extending the generative capabilities of pretrained video diffusion models beyond their existing context/frame limits.

Instead of initializing noises for all frames, FreeNoise reschedules a sequence of noises for long-range correlation and performs temporal attention over them using a window-based function. We have added FreeNoise to the AnimateDiff family of models in Diffusers, allowing them to generate videos beyond their default 32 frame limit.

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerAncestralDiscreteScheduler
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16)
pipe.scheduler = EulerAncestralDiscreteScheduler(
    beta_schedule="linear",
    beta_start=0.00085,
    beta_end=0.012,
)

pipe.enable_free_noise()
pipe.vae.enable_slicing()
pipe.enable_model_cpu_offload()

frames = pipe(
    "An astronaut riding a horse on Mars.",
    num_frames=64,
    num_inference_steps=20,
    guidance_scale=7.0,
    decode_chunk_size=2,
).frames[0]

export_to_gif(frames, "freenoise-64.gif")

LoRA refactor

We have significantly refactored the loader classes associated with LoRA. Going forward, this will help in adding LoRA support for new pipelines and models. We now have a LoraBaseMixin class which is subclassed by the different pipeline-level LoRA loading classes such as StableDiffusionXLLoraLoaderMixin. This document provides an overview of the available classes.

Additionally, we have increased the coverage of methods within the PeftAdapterMixin class. This refactoring allows all the supported models to share common LoRA functionalities such as set_adapter(), add_adapter(), and so on.
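For example, at the pipeline level this makes it straightforward to load several adapters and mix them; a sketch using two community LoRAs (repository and file names are illustrative):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load two LoRAs under named adapters, then combine them with per-adapter weights.
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipe.set_adapters(["toy", "pixel"], adapter_weights=[1.0, 0.8])

image = pipe("toy_face of a hacker with a hoodie, pixel art", num_inference_steps=30).images[0]
image.save("mixed_loras.png")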

To learn more details, please follow this PR. If you see any LoRA-related iss...

Read more