Reduce memory usage

Modern diffusion models like Flux and Wan have billions of parameters that take up a lot of memory on your hardware for inference. This is challenging because common GPUs often don’t have sufficient memory. To overcome the memory limitations, you can use more than one GPU (if available), offload some of the pipeline components to the CPU, and more.

This guide will show you how to reduce your memory usage.

Keep in mind these techniques may need to be adjusted depending on the model! For example, a transformer-based diffusion model may not benefit equally from these memory optimizations as a UNet-based model.

Multiple GPUs

If you have access to more than one GPU, there are a few options for efficiently loading and distributing a large model across your hardware. These features are supported by the Accelerate library, so make sure it is installed first.

pip install -U accelerate

Sharded checkpoints

Loading a large checkpoint in several shards is useful because the shards are loaded one at a time. This keeps memory usage low, requiring only enough memory for the model size plus the largest shard size. We recommend sharding when the fp32 checkpoint is greater than 5GB. The default shard size is 5GB.

Shard a checkpoint in save_pretrained() with the max_shard_size parameter.

from diffusers import AutoModel

unet = AutoModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet.save_pretrained("sdxl-unet-sharded", max_shard_size="5GB")

Now you can use the sharded checkpoint, instead of the regular checkpoint, to save memory.

import torch
from diffusers import AutoModel, StableDiffusionXLPipeline

unet = AutoModel.from_pretrained(
    "username/sdxl-unet-sharded", torch_dtype=torch.float16
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    unet=unet,
    torch_dtype=torch.float16
).to("cuda")
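To make the memory argument concrete, here is a minimal sketch of how a checkpoint could be packed greedily into shards and what the resulting peak load memory looks like. The function names (plan_shards, peak_load_memory) and the tensor sizes are hypothetical and only illustrate the idea; this is not necessarily how save_pretrained() splits a checkpoint internally.

```python
def plan_shards(tensor_sizes_gb, max_shard_size_gb):
    """Greedily pack tensors (in order) into shards no larger than the limit."""
    shards, current, current_size = [], [], 0.0
    for name, size in tensor_sizes_gb.items():
        # Start a new shard when adding this tensor would exceed the limit.
        if current and current_size + size > max_shard_size_gb:
            shards.append(current)
            current, current_size = [], 0.0
        current.append(name)
        current_size += size
    if current:
        shards.append(current)
    return shards


def peak_load_memory(tensor_sizes_gb, shards):
    """Peak memory while loading: the full model plus the largest shard."""
    model_size = sum(tensor_sizes_gb.values())
    largest_shard = max(sum(tensor_sizes_gb[n] for n in shard) for shard in shards)
    return model_size + largest_shard


# Hypothetical tensor groups of a ~10GB fp32 model.
sizes = {"down_blocks": 3.0, "mid_block": 2.0, "up_blocks": 4.0, "conv_out": 1.0}
shards = plan_shards(sizes, max_shard_size_gb=5.0)
print(shards)
print(peak_load_memory(sizes, shards))
```

With a single unsharded file, the "largest shard" is the whole checkpoint, so the peak in this simplified model would be roughly twice the model size.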

Device placement

Device placement is an experimental feature and the API may change. Only the balanced strategy is supported at the moment. We plan to support additional mapping strategies in the future.

The device_map parameter controls how the model components in a pipeline are distributed across devices. The balanced device placement strategy evenly splits the pipeline across all available devices.

import torch
from diffusers import AutoModel, StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    device_map="balanced"
)

You can inspect a pipeline’s device map with hf_device_map.

print(pipeline.hf_device_map)
# {'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}
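As a rough intuition for what an even split means, the sketch below assigns components to devices with a simple greedy rule: largest component first, always onto the least-loaded device. The helper name balanced_device_map and the component sizes are assumptions made for illustration; the actual placement logic in Diffusers and Accelerate may differ.

```python
def balanced_device_map(component_sizes_gb, num_devices):
    """Assign each component to the currently least-loaded device."""
    load = [0.0] * num_devices
    device_map = {}
    # Place the largest components first so the split stays as even as possible.
    for name, size in sorted(component_sizes_gb.items(), key=lambda kv: -kv[1]):
        device = load.index(min(load))
        device_map[name] = device
        load[device] += size
    return device_map, load


# Hypothetical component sizes in GB.
sizes = {"unet": 5.0, "text_encoder": 0.5, "text_encoder_2": 2.5, "vae": 0.3}
device_map, load = balanced_device_map(sizes, num_devices=2)
print(device_map)
print(load)
```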

The device_map parameter also works at the model level. This is useful for loading large models, such as the Flux diffusion transformer, which has 12.5B parameters. Instead of "balanced", set it to "auto" to automatically distribute a model across the fastest device first before moving to slower devices. Refer to the Model sharding docs for more details.

import torch
from diffusers import AutoModel

transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    device_map="auto",
    torch_dtype=torch.bfloat16
)

For more fine-grained control, pass a dictionary to enforce the maximum GPU memory to use on each device. If a device is not in max_memory, it is ignored and pipeline components won’t be distributed to it.

import torch
from diffusers import AutoModel, StableDiffusionXLPipeline

max_memory = {0: "1GB", 1: "1GB"}
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    device_map="balanced",
    max_memory=max_memory
)

Diffusers uses the maximum memory of all devices by default, but if the models don’t fit on the GPUs, then you’ll need to use a single GPU and offload to the CPU with the methods below.
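A quick way to reason about whether a pipeline can fit is to add up the budgets of the devices listed in max_memory and compare the total with the model size. The sketch below does this with two hypothetical helpers (parse_size_gb, usable_budget_gb) and a simplified unit conversion; it is a back-of-the-envelope check, not the logic Diffusers or Accelerate applies.

```python
def parse_size_gb(size):
    """Convert strings such as "1GB" or "512MB" to a number of GB (simplified)."""
    units = {"GB": 1.0, "MB": 1.0 / 1024}
    for suffix, factor in units.items():
        if size.upper().endswith(suffix):
            return float(size[: -len(suffix)]) * factor
    raise ValueError(f"Unsupported size string: {size}")


def usable_budget_gb(max_memory, available_devices):
    """Sum the budgets of available devices; devices missing from max_memory are ignored."""
    return sum(
        parse_size_gb(max_memory[device])
        for device in available_devices
        if device in max_memory
    )


max_memory = {0: "1GB", 1: "1GB"}
# Device 2 exists but is not listed in max_memory, so it contributes nothing.
budget = usable_budget_gb(max_memory, available_devices=[0, 1, 2])
model_size_gb = 7.0  # hypothetical pipeline size
print(budget, budget >= model_size_gb)
```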

Use the reset_device_map() method to reset the device_map. This is necessary if you want to use methods like .to(), enable_sequential_cpu_offload(), and enable_model_cpu_offload() on a pipeline that was device-mapped.

pipeline.reset_device_map()

VAE slicing

VAE slicing saves memory by splitting a large batch of inputs into single-item batches and processing them separately. This method works best when generating more than one image at a time.

For example, if you’re generating 4 images at once, decoding would increase peak activation memory by 4x. VAE slicing reduces this by only decoding 1 image at a time instead of all 4 images at once.
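The idea can be sketched without a real VAE. In the toy example below, decode is a stand-in for the decoder and sliced_decode is a hypothetical helper that feeds it one latent at a time; both names are assumptions for illustration only.

```python
def decode(latents):
    """Stand-in for a VAE decoder: returns one "image" per latent."""
    return [f"image({latent})" for latent in latents]


def sliced_decode(latents):
    """Decode one latent at a time and concatenate the results."""
    images = []
    peak_batch = 0  # largest batch the decoder ever sees
    for latent in latents:
        batch = [latent]  # a slice of size 1
        peak_batch = max(peak_batch, len(batch))
        images.extend(decode(batch))
    return images, peak_batch


latents = ["z0", "z1", "z2", "z3"]
images, peak_batch = sliced_decode(latents)
print(images == decode(latents), peak_batch)
```

The output matches a full-batch decode, but the decoder only ever holds activations for a single image.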

Call enable_vae_slicing() to enable sliced VAE. You can expect a small increase in performance when decoding multi-image batches and no performance impact for single-image batches.

import torch
from diffusers import AutoModel, StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_vae_slicing()
pipeline(["An astronaut riding a horse on Mars"]*32).images[0]
print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")

Warning: [AutoencoderKLWan](/docs/diffusers/main/en/api/models/autoencoder_kl_wan#diffusers.AutoencoderKLWan) and AsymmetricAutoencoderKL don’t support slicing.

VAE tiling

VAE tiling saves memory by dividing an image into smaller overlapping tiles instead of processing the entire image at once. This also reduces peak memory usage because the GPU is only processing a tile at a time.

Call enable_vae_tiling() to enable VAE tiling. The generated image may have some tone variation from tile-to-tile because they’re decoded separately, but there shouldn’t be any obvious seams between the tiles. Tiling is disabled for resolutions lower than a pre-specified (but configurable) limit. For example, this limit is 512x512 for the VAE in StableDiffusionPipeline.
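To illustrate how an image can be covered by overlapping tiles, here is a minimal sketch that computes tile coordinates. The helpers tile_starts and tile_boxes, as well as the tile size and overlap values, are hypothetical; the real implementation blends the overlapping regions and may choose tiles differently.

```python
def tile_starts(length, tile_size, overlap):
    """Start offsets of overlapping tiles that cover `length` pixels."""
    if length <= tile_size:
        return [0]  # small inputs are processed in one piece
    stride = tile_size - overlap
    starts = list(range(0, length - tile_size, stride))
    starts.append(length - tile_size)  # align the last tile with the edge
    return starts


def tile_boxes(height, width, tile_size=512, overlap=128):
    """Return (top, left, bottom, right) boxes for every tile."""
    return [
        (top, left, min(top + tile_size, height), min(left + tile_size, width))
        for top in tile_starts(height, tile_size, overlap)
        for left in tile_starts(width, tile_size, overlap)
    ]


boxes = tile_boxes(1024, 1024)
print(len(boxes), boxes[0], boxes[-1])
```

Each tile is at most 512x512 in this sketch, so the peak activation memory depends on the tile size rather than on the full image resolution.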

import torch
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image

pipeline = AutoPipelineForImage2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.enable_vae_tiling()

init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
pipeline(prompt, image=init_image, strength=0.5).images[0]
print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")

Warning: [AutoencoderKLWan](/docs/diffusers/main/en/api/models/autoencoder_kl_wan#diffusers.AutoencoderKLWan) and AsymmetricAutoencoderKL don’t support tiling.

CPU offloading

CPU offloading selectively moves weights from the GPU to the CPU. When a component is required, it is transferred to the GPU and when it isn’t required, it is moved to the CPU. This method works on submodules rather than whole models. It saves memory by avoiding storing the entire model on the GPU.

CPU offloading dramatically reduces memory usage, but it is also extremely slow because submodules are passed back and forth multiple times between devices. It can often be impractical due to how slow it is.

Don’t move the pipeline to CUDA before calling enable_sequential_cpu_offload(), otherwise the amount of memory saved is only minimal (refer to this issue for more details). This is a stateful operation that installs hooks on the model.

Call enable_sequential_cpu_offload() to enable it on a pipeline.

import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipeline.enable_sequential_cpu_offload()

pipeline(
    prompt="An astronaut riding a horse on Mars",
    guidance_scale=0.,
    height=768,
    width=1360,
    num_inference_steps=4,
    max_sequence_length=256,
).images[0]
print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")

Model offloading

Model offloading moves entire models to the GPU instead of selectively moving some layers or model components. Only one of the main pipeline models (usually the text encoder, UNet, or VAE) is placed on the GPU at a time while the other components are held on the CPU. Components like the UNet that run multiple times stay on the GPU until they are completely finished and no longer needed. This avoids the per-submodule communication overhead of CPU offloading and makes model offloading a faster alternative. The tradeoff is that the memory savings won’t be as large.
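The behavior described above can be mimicked with a small simulation. ModelOffloadSimulator is a hypothetical toy class and the component sizes are made up; it only tracks which whole model sits on the GPU and ignores activations and the real hook mechanics.

```python
class ModelOffloadSimulator:
    """Toy model of model offloading: at most one model lives on the GPU."""

    def __init__(self, sizes_gb):
        self.sizes_gb = sizes_gb
        self.on_gpu = None   # name of the model currently on the GPU
        self.transfers = 0   # number of CPU -> GPU moves
        self.peak_gb = 0.0   # largest amount of weights on the GPU at once

    def run(self, name):
        # Move the model to the GPU only if it is not already there, which
        # implicitly sends the previously loaded model back to the CPU.
        if self.on_gpu != name:
            self.on_gpu = name
            self.transfers += 1
        self.peak_gb = max(self.peak_gb, self.sizes_gb[name])


# Hypothetical component sizes in GB.
sizes_gb = {"text_encoder": 0.5, "unet": 5.0, "vae": 0.3}
sim = ModelOffloadSimulator(sizes_gb)
sim.run("text_encoder")
for _ in range(50):  # the UNet runs once per denoising step
    sim.run("unet")
sim.run("vae")
print(sim.transfers, sim.peak_gb)
```

Even though the UNet runs 50 times, it is moved to the GPU only once, and the peak weight memory is set by the largest single component rather than by the sum of all components.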

Keep in mind that if models are reused outside the pipeline after hooks have been installed (see Removing Hooks for more details), you need to run the entire pipeline and models in the expected order to properly offload them. This is a stateful operation that installs hooks on the model.

Call enable_model_cpu_offload() to enable it on a pipeline.

import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipeline.enable_model_cpu_offload()

pipeline(
    prompt="An astronaut riding a horse on Mars",
    guidance_scale=0.,
    height=768,
    width=1360,
    num_inference_steps=4,
    max_sequence_length=256,
).images[0]
print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")

enable_model_cpu_offload() also helps when you’re using the encode_prompt() method on its own to generate the text encoder’s hidden states.

Group offloading

Group offloading moves groups of internal layers (torch.nn.ModuleList or torch.nn.Sequential) to the CPU. It uses less memory than model offloading and it is faster than CPU offloading because it reduces communication overhead.

Group offloading may not work with all models if the forward implementation contains weight-dependent device casting of inputs because it may clash with group offloading’s device casting mechanism.

Call enable_group_offload() to enable it for standard Diffusers model components that inherit from ModelMixin. For other model components that don’t inherit from ModelMixin, such as a generic torch.nn.Module, use apply_group_offloading() instead.

The offload_type parameter can be set to block_level or leaf_level.
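As a rough mental model, block_level moves several blocks at a time (controlled by num_blocks_per_group), while leaf_level moves one layer at a time. The sketch below shows the difference on a plain list, treating each entry as one layer. The helpers make_groups and peak_group_gb and the per-layer size are hypothetical and are not the Diffusers implementation.

```python
def make_groups(layers, offload_type, num_blocks_per_group=1):
    """Split a list of layers into the groups that are moved together."""
    if offload_type == "leaf_level":
        size = 1  # every layer is its own group
    elif offload_type == "block_level":
        size = num_blocks_per_group
    else:
        raise ValueError(f"Unknown offload_type: {offload_type}")
    return [layers[i:i + size] for i in range(0, len(layers), size)]


def peak_group_gb(groups, layer_size_gb):
    """Weights on the GPU at once if only one group is onloaded at a time."""
    return max(len(group) for group in groups) * layer_size_gb


layers = [f"block_{i}" for i in range(5)]
block_groups = make_groups(layers, "block_level", num_blocks_per_group=2)
leaf_groups = make_groups(layers, "leaf_level")
print(block_groups)
print(peak_group_gb(block_groups, 0.5), peak_group_gb(leaf_groups, 0.5))
```

Larger groups mean fewer transfers but more weights resident on the GPU at once, which is the tradeoff num_blocks_per_group lets you tune.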

import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level")
pipeline.vae.enable_group_offload(onload_device=onload_device, offload_type="leaf_level")

apply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
export_to_video(video, "output.mp4", fps=8)

CUDA stream

The use_stream parameter can be activated for CUDA devices that support asynchronous data transfer streams to reduce overall execution time compared to CPU offloading. It overlaps data transfer and computation by using layer prefetching. The next layer to be executed is loaded onto the GPU while the current layer is still being executed. It can increase CPU memory significantly so ensure you have 2x the amount of memory as the model size.
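The benefit of prefetching can be estimated with an idealized timeline in which the transfer of the next layer fully overlaps with the computation of the current one. The function total_time_ms and the timings are assumptions for illustration; real speedups depend on hardware and memory pinning.

```python
def total_time_ms(num_layers, transfer_ms, compute_ms, use_stream):
    """Idealized execution time for running layers that must be onloaded first."""
    if not use_stream:
        # Each layer is transferred, then executed, strictly one after another.
        return num_layers * (transfer_ms + compute_ms)
    # With prefetching, the transfer of layer i+1 overlaps the compute of layer i,
    # so only the first transfer and the last compute are fully exposed.
    return transfer_ms + (num_layers - 1) * max(transfer_ms, compute_ms) + compute_ms


sequential = total_time_ms(num_layers=4, transfer_ms=3, compute_ms=5, use_stream=False)
overlapped = total_time_ms(num_layers=4, transfer_ms=3, compute_ms=5, use_stream=True)
print(sequential, overlapped)
```

When transfers are slower than compute, the overlapped time is dominated by the transfers instead, so prefetching hides compute rather than data movement.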

Set record_stream=True for more of a speedup at the cost of slightly increased memory usage. Refer to the torch.Tensor.record_stream docs to learn more.

When use_stream=True on VAEs with tiling enabled, make sure to do a dummy forward pass (possible with dummy inputs as well) before inference to avoid device mismatch errors. This may not work on all implementations, so feel free to open an issue if you encounter any problems.

If you’re using block_level group offloading with use_stream enabled, the num_blocks_per_group parameter should be set to 1, otherwise a warning will be raised.

pipeline.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
    record_stream=True,
)

The low_cpu_mem_usage parameter can be set to True to reduce CPU memory usage when using streams during group offloading. It is best for leaf_level offloading and when CPU memory is bottlenecked. Memory is saved by creating pinned tensors on the fly instead of pre-pinning them. However, this may increase overall execution time.

Layerwise casting

Layerwise casting stores weights in a smaller data format (for example, torch.float8_e4m3fn and torch.float8_e5m2) to use less memory and upcasts those weights to a higher precision like torch.float16 or torch.bfloat16 for computation. Certain layers (normalization and modulation related weights) are skipped because storing them in fp8 can degrade generation quality.
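A back-of-the-envelope calculation shows where the savings come from: fp8 formats use 1 byte per parameter, while float16 and bfloat16 use 2. The helper weight_memory_gb, the parameter count, and the skipped fraction below are hypothetical values chosen for illustration.

```python
BYTES_PER_PARAM = {
    "float32": 4,
    "float16": 2,
    "bfloat16": 2,
    "float8_e4m3fn": 1,
    "float8_e5m2": 1,
}


def weight_memory_gb(num_params, storage_dtype, compute_dtype="bfloat16", skipped_fraction=0.0):
    """Weight memory when a fraction of parameters stays in compute_dtype."""
    skipped = num_params * skipped_fraction  # e.g. normalization layers
    cast = num_params - skipped
    total_bytes = cast * BYTES_PER_PARAM[storage_dtype] + skipped * BYTES_PER_PARAM[compute_dtype]
    return total_bytes / 1024**3


num_params = 5_000_000_000  # hypothetical 5B-parameter transformer
baseline = weight_memory_gb(num_params, "bfloat16")
fp8 = weight_memory_gb(num_params, "float8_e4m3fn")
fp8_with_skips = weight_memory_gb(num_params, "float8_e4m3fn", skipped_fraction=0.02)
print(f"{baseline:.2f} GB, {fp8:.2f} GB, {fp8_with_skips:.2f} GB")
```

Skipping a small fraction of layers costs little memory, which is why keeping the sensitive layers in higher precision is usually a reasonable trade.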

Layerwise casting may not work with all models if the forward implementation contains internal typecasting of weights. The current implementation of layerwise casting assumes the forward pass is independent of the weight precision and the input datatypes are always specified in compute_dtype (see here for an incompatible implementation).

Layerwise casting may also fail on custom modeling implementations with PEFT layers. There are some checks available but they are not extensively tested or guaranteed to work in all cases.

Call enable_layerwise_casting() to set the storage and computation datatypes.

import torch
from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers.utils import export_to_video

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b",
    subfolder="transformer",
    torch_dtype=torch.bfloat16
)
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)

pipeline = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-5b",
    transformer=transformer,
    torch_dtype=torch.bfloat16
).to("cuda")
prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
export_to_video(video, "output.mp4", fps=8)

The apply_layerwise_casting() method can also be used if you need more control and flexibility. It can be partially applied to model layers by calling it on specific internal modules. Use the skip_modules_pattern or skip_modules_classes parameters to specify modules to avoid, such as the normalization and modulation layers.

import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_layerwise_casting

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b",
    subfolder="transformer",
    torch_dtype=torch.bfloat16
)

# "norm" is a name pattern, so it is passed to skip_modules_pattern;
# skip_modules_classes expects module classes rather than strings
apply_layerwise_casting(
    transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=["norm"],
    non_blocking=True,
)

torch.channels_last

torch.channels_last flips how tensors are stored from (batch size, channels, height, width) to (batch size, height, width, channels). This aligns the tensors with how the hardware sequentially accesses the tensors stored in memory and avoids skipping around in memory to access the pixel values.

Not all operators currently support the channels-last format, so it may result in worse performance, but it is still worth trying.

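# strides before the conversion
print(pipeline.unet.conv_out.state_dict()["weight"].stride())
# convert the UNet to the channels-last memory format
pipeline.unet.to(memory_format=torch.channels_last)
# the strides change if the conversion took effect
print(pipeline.unet.conv_out.state_dict()["weight"].stride())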
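The change in strides can be worked out by hand. The sketch below computes the strides of a 4D tensor in both layouts with two illustrative helper functions (the names are not part of PyTorch); the logical shape stays (batch size, channels, height, width) and only the memory order changes.

```python
def contiguous_strides(n, c, h, w):
    """Strides of a default (NCHW) contiguous tensor: width varies fastest."""
    return (c * h * w, h * w, w, 1)


def channels_last_strides(n, c, h, w):
    """Strides of the same tensor stored as NHWC: channels vary fastest."""
    return (h * w * c, 1, w * c, c)


shape = (2, 3, 4, 5)
print(contiguous_strides(*shape))     # (60, 20, 5, 1)
print(channels_last_strides(*shape))  # (60, 1, 15, 3)
```

A channel stride of 1 means that all channel values of a pixel sit next to each other in memory.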

torch.jit.trace

torch.jit.trace records the operations a model performs on a sample input and creates a new, optimized representation of the model based on the recorded execution path. During tracing, the model is optimized to reduce overhead from Python and dynamic control flow, and operations are fused together for more efficiency. The returned executable or ScriptFunction is optimized with just-in-time compilation.

import time
import torch
from diffusers import StableDiffusionPipeline
import functools

# disable gradient tracking
torch.set_grad_enabled(False)

# benchmark settings
n_experiments = 2
unet_runs_per_experiment = 50

# create random sample inputs for the UNet
def generate_inputs():
    sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
    timestep = torch.rand(1, device="cuda", dtype=torch.float16) * 999
    encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
    return sample, timestep, encoder_hidden_states

pipeline = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")
unet = pipeline.unet
unet.eval()
unet.to(memory_format=torch.channels_last)  # use the channels-last memory format
unet.forward = functools.partial(unet.forward, return_dict=False)  # return tuples by default

# warm up
for _ in range(3):
    with torch.inference_mode():
        inputs = generate_inputs()
        orig_output = unet(*inputs)

# trace
print("tracing..")
unet_traced = torch.jit.trace(unet, inputs)
unet_traced.eval()
print("done tracing")

# warm up the traced model
for _ in range(5):
    with torch.inference_mode():
        inputs = generate_inputs()
        orig_output = unet_traced(*inputs)

# benchmark the traced and the original UNet
with torch.inference_mode():
    for _ in range(n_experiments):
        torch.cuda.synchronize()
        start_time = time.time()
        for _ in range(unet_runs_per_experiment):
            orig_output = unet_traced(*inputs)
        torch.cuda.synchronize()
        print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
    for _ in range(n_experiments):
        torch.cuda.synchronize()
        start_time = time.time()
        for _ in range(unet_runs_per_experiment):
            orig_output = unet(*inputs)
        torch.cuda.synchronize()
        print(f"unet inference took {time.time() - start_time:.2f} seconds")

# save the traced model
unet_traced.save("unet_traced.pt")

Replace the pipeline’s UNet with the traced version.

import torch
from diffusers import StableDiffusionPipeline
from dataclasses import dataclass

@dataclass
class UNet2DConditionOutput:
    sample: torch.Tensor

pipeline = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")

# load the traced UNet saved in the previous step
unet_traced = torch.jit.load("unet_traced.pt")

class TracedUNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.in_channels = pipeline.unet.config.in_channels
        self.device = pipeline.unet.device

    def forward(self, latent_model_input, t, encoder_hidden_states):
        sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
        return UNet2DConditionOutput(sample=sample)

pipeline.unet = TracedUNet()

# example prompt (an assumption; the original snippet leaves prompt undefined)
prompt = "An astronaut riding a horse on Mars"

with torch.inference_mode():
    image = pipeline([prompt] * 1, num_inference_steps=50).images[0]

Memory-efficient attention

Memory-efficient attention optimizes for memory usage and [inference speed](./fp16#scaled-dot-product-attention)!

The Transformers attention mechanism is memory-intensive, especially for long sequences, so you can try using different and more memory-efficient attention types.
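To see why long sequences are costly, consider a naive attention implementation that materializes the full attention score matrix, which has one entry per pair of tokens for every head. The sketch below estimates its size; the function name attention_scores_gb and the example dimensions are assumptions for illustration, and memory-efficient kernels avoid storing this matrix in full.

```python
def attention_scores_gb(batch, heads, seq_len, bytes_per_element=2):
    """Memory of a fully materialized (seq_len x seq_len) score matrix per head."""
    return batch * heads * seq_len * seq_len * bytes_per_element / 1024**3


# Hypothetical sizes: 24 heads, fp16 scores.
short = attention_scores_gb(batch=1, heads=24, seq_len=4096)
long = attention_scores_gb(batch=1, heads=24, seq_len=8192)
print(short, long)  # 0.75 3.0
```

Doubling the sequence length quadruples the memory of the score matrix, which is the quadratic growth that memory-efficient attention is designed to sidestep.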

By default, if PyTorch >= 2.0 is installed, scaled dot-product attention (SDPA) is used. You don’t need to make any additional changes to your code.

SDPA supports FlashAttention and xFormers as well as a native C++ PyTorch implementation. It automatically selects the optimal implementation based on your input.

You can explicitly use xFormers with the enable_xformers_memory_efficient_attention() method.

import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_xformers_memory_efficient_attention()

Call disable_xformers_memory_efficient_attention() to disable it.

pipeline.disable_xformers_memory_efficient_attention()
