Accelerate inference
Diffusion models are slow at inference because generation is an iterative process where noise is gradually refined into an image or video over a certain number of “steps”. To speed up this process, you can experiment with different schedulers, reduce the precision of the model weights for faster computations, use more memory-efficient attention mechanisms, and more.
Combine and use these techniques together to make inference faster than using any single technique on its own.
This guide will go over how to accelerate inference.
Model data type
The precision and data type of the model weights affect inference speed because a higher precision requires more memory to load and more time to perform the computations. PyTorch loads model weights in float32 or full precision by default, so changing the data type is a simple way to quickly get faster inference.
bfloat16
bfloat16 is similar to float16 but it is more robust to numerical errors. Hardware support for bfloat16 varies, but most modern GPUs are capable of supporting bfloat16.
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
pipeline(prompt, num_inference_steps=30).images[0]
Scaled dot product attention
Memory-efficient attention optimizes for inference speed and memory usage!
Scaled dot product attention (SDPA) implements several attention backends: FlashAttention, xFormers (memory-efficient attention), and a native C++ implementation. It automatically selects the most optimal backend for your hardware.
SDPA is enabled by default if you’re using PyTorch >= 2.0, and no additional changes to your code are required. You can still experiment with other attention backends if you’d like to choose one yourself. The example below uses the torch.nn.attention.sdpa_kernel context manager to enable efficient attention.
from torch.nn.attention import SDPBackend, sdpa_kernel
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    image = pipeline(prompt, num_inference_steps=30).images[0]
torch.compile
torch.compile accelerates inference by compiling PyTorch code and operations into optimized kernels. Diffusers typically compiles the more compute-intensive models like the UNet, transformer, or VAE.
Enable the following compiler settings for maximum speed (refer to the full list for more options).
import torch
from diffusers import StableDiffusionXLPipeline

torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
Load and compile the UNet and VAE. There are several different modes you can choose from, but "max-autotune" optimizes for the fastest speed by compiling to a CUDA graph. CUDA graphs effectively reduce overhead by launching multiple GPU operations through a single CPU operation.
With PyTorch 2.3.1 and later, you can control the caching behavior of torch.compile. This is particularly beneficial for compilation modes like "max-autotune", which perform a grid search over several compilation flags to find the optimal configuration. Learn more in the Compile Time Caching in torch.compile tutorial.
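As a minimal sketch (the cache directory path below is just an example), you can persist Inductor's compiled artifacts with the TORCHINDUCTOR_CACHE_DIR environment variable and enable the FX graph cache so later runs can skip recompilation:

import os
import torch

# Persist Inductor's compiled artifacts across processes (example path).
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torchinductor_cache"

# Cache compiled FX graphs so they can be reused on subsequent runs.
torch._inductor.config.fx_graph_cache = True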
Changing the memory layout to channels_last also optimizes memory and inference speed.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.unet.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
pipeline.unet = torch.compile(
    pipeline.unet, mode="max-autotune", fullgraph=True
)
pipeline.vae.decode = torch.compile(
    pipeline.vae.decode, mode="max-autotune", fullgraph=True
)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
pipeline(prompt, num_inference_steps=30).images[0]
Compilation is slow the first time, but once compiled, it is significantly faster. Try to only use the compiled pipeline on the same type of inference operations. Calling the compiled pipeline on a different image size retriggers compilation, which is slow and inefficient.
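For example, building on the pipeline compiled above, a simple warmup pattern pays the compilation cost up front so later calls run at full speed:

# First call: triggers compilation and is slow.
pipeline(prompt, num_inference_steps=30)

# Subsequent calls with the same resolution reuse the compiled kernels.
image = pipeline(prompt, num_inference_steps=30).images[0]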
Graph breaks
It is important to specify fullgraph=True in torch.compile to ensure there are no graph breaks in the underlying model. This allows you to take advantage of torch.compile without any performance degradation. For the UNet and VAE, this changes how you access the return variables.
- latents = unet(
-     latents, timestep=timestep, encoder_hidden_states=prompt_embeds
- ).sample
+ latents = unet(
+     latents, timestep=timestep, encoder_hidden_states=prompt_embeds, return_dict=False
+ )[0]
GPU sync
The scheduler's step() function is called each time after the denoiser makes a prediction, and the sigmas variable is indexed inside it. When sigmas is placed on the GPU, this indexing introduces latency because of the communication sync between the CPU and GPU. The overhead becomes more noticeable when the denoiser has already been compiled.
In general, sigmas should stay on the CPU to avoid the communication sync and latency.
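As a toy illustration (not the scheduler's actual code), reading a value out of a GPU tensor forces a CPU-GPU sync, while keeping the tensor on the CPU avoids the round-trip:

import torch

sigmas_gpu = torch.linspace(1.0, 0.0, 30, device="cuda")
sigmas_cpu = sigmas_gpu.cpu()

step_index = 10

# Reading from the GPU tensor blocks the CPU until the GPU catches up.
sigma = sigmas_gpu[step_index].item()

# Indexing the CPU copy involves no device synchronization.
sigma = sigmas_cpu[step_index].item()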
Dynamic quantization
Dynamic quantization improves inference speed by reducing precision to enable faster math operations. This particular type of quantization determines how to scale the activations based on the data at runtime rather than using a fixed scaling factor. As a result, the scaling factor is more accurately aligned with the data.
The example below applies dynamic int8 quantization to the UNet and VAE with the torchao library.
Refer to our torchao docs to learn more about how to use the Diffusers torchao integration.
Configure the compiler flags for maximum speed.
import torch
from torchao import apply_dynamic_quant
from diffusers import StableDiffusionXLPipeline

torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True
Use the dynamic_quant_filter_fn to filter out some linear layers in the UNet and VAE that don't benefit from dynamic quantization.
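The filter function itself is not shown in this guide; the sketch below is only an illustration of what such a filter might look like, assuming you simply skip small linear layers (the filter used in practice also excludes specific model-dependent layer shapes):

import torch

def dynamic_quant_filter_fn(mod, *args):
    # Only quantize reasonably large nn.Linear layers; small projections
    # typically don't benefit from dynamic int8 quantization.
    return isinstance(mod, torch.nn.Linear) and mod.in_features > 16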
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

apply_dynamic_quant(pipeline.unet, dynamic_quant_filter_fn)
apply_dynamic_quant(pipeline.vae, dynamic_quant_filter_fn)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
pipeline(prompt, num_inference_steps=30).images[0]
Fused projection matrices
The fuse_qkv_projections method is experimental and support is mostly limited to Stable Diffusion pipelines. Take a look at this PR to learn more about how to enable it for other pipelines.
In an attention block, an input is projected into three subspaces, represented by the projection matrices Q, K, and V. These projections are typically calculated separately, but you can horizontally fuse them into a single matrix and perform the projection in a single step. This increases the size of the matrix multiplication for the input projections and also improves the impact of quantization.
pipeline.fuse_qkv_projections()
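For context, here is a minimal sketch of where the call fits, assuming the same SDXL pipeline used throughout this guide:

import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

# Fuse the Q, K, and V projections before running inference
# (and before compiling, if combined with torch.compile).
pipeline.fuse_qkv_projections()

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipeline(prompt, num_inference_steps=30).images[0]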