SanaPipeline


SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers from NVIDIA and MIT HAN Lab, by Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Haotian Tang, Yujun Lin, Zhekai Zhang, Muyang Li, Ligeng Zhu, Yao Lu, Song Han.

The abstract from the paper is:

We introduce Sana, a text-to-image framework that can efficiently generate images up to 4096×4096 resolution. Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU. Core designs include: (1) Deep compression autoencoder: unlike traditional AEs, which compress images only 8×, we trained an AE that can compress images 32×, effectively reducing the number of latent tokens. (2) Linear DiT: we replace all vanilla attention in DiT with linear attention, which is more efficient at high resolutions without sacrificing quality. (3) Decoder-only text encoder: we replaced T5 with modern decoder-only small LLM as the text encoder and designed complex human instruction with in-context learning to enhance the image-text alignment. (4) Efficient training and sampling: we propose Flow-DPM-Solver to reduce sampling steps, with efficient caption labeling and selection to accelerate convergence. As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.g. Flux-12B), being 20 times smaller and 100+ times faster in measured throughput. Moreover, Sana-0.6B can be deployed on a 16GB laptop GPU, taking less than 1 second to generate a 1024×1024 resolution image. Sana enables content creation at low cost. Code and model will be publicly released.

Make sure to check out the Schedulers guide to learn how to explore the tradeoff between scheduler speed and quality, and see the reuse components across pipelines section to learn how to efficiently load the same components into multiple pipelines.

This pipeline was contributed by lawrence-cj and chenjy2003. The original codebase can be found here. The original weights can be found under hf.co/Efficient-Large-Model.

Available models:

| Model | Recommended dtype |
|---|---|
| Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers | torch.bfloat16 |
| Efficient-Large-Model/Sana_1600M_1024px_diffusers | torch.float16 |
| Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers | torch.float16 |
| Efficient-Large-Model/Sana_1600M_512px_diffusers | torch.float16 |
| Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers | torch.float16 |
| Efficient-Large-Model/Sana_600M_1024px_diffusers | torch.float16 |
| Efficient-Large-Model/Sana_600M_512px_diffusers | torch.float16 |

Refer to this collection for more information.

Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in torch.bfloat16 or torch.float32 for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.

Make sure to pass the variant argument for downloaded checkpoints to save disk space. Set it to "fp16" for models with recommended dtype torch.float16, and "bf16" for models with recommended dtype torch.bfloat16. By default, torch.float32 weights are downloaded, which use twice the amount of disk storage. Additionally, torch.float32 weights can be downcast on the fly by specifying the torch_dtype argument. Read about it in the docs.
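For example, a minimal loading sketch for a checkpoint whose recommended dtype is torch.float16; the model ID and prompt are illustrative, so adjust them to the checkpoint you actually use:

```py
import torch
from diffusers import SanaPipeline

# Download the smaller half-precision weights via the "fp16" variant.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    variant="fp16",
    torch_dtype=torch.float16,
)
pipe.to("cuda")

# Keep the text encoder and VAE in bfloat16/float32 as noted above.
pipe.text_encoder.to(torch.bfloat16)
pipe.vae.to(torch.bfloat16)

image = pipe(prompt="a tiny astronaut hatching from an egg on the moon").images[0]
image.save("sana.png")
```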

Quantization

Quantization helps reduce the memory requirements of very large models by storing model weights in a lower-precision data type. However, quantization may have a varying impact on image quality depending on the model.

Refer to the Quantization overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized SanaPipeline for inference with bitsandbytes.

```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaPipeline
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel

quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = AutoModel.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    subfolder="text_encoder",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

pipeline = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
)

prompt = "a tiny astronaut hatching from an egg on the moon"
image = pipeline(prompt).images[0]
image.save("sana.png")
```

SanaPipeline

class diffusers.SanaPipeline


```
(
  tokenizer: typing.Union[transformers.models.gemma.tokenization_gemma.GemmaTokenizer, transformers.models.gemma.tokenization_gemma_fast.GemmaTokenizerFast]
  text_encoder: Gemma2PreTrainedModel
  vae: AutoencoderDC
  transformer: SanaTransformer2DModel
  scheduler: DPMSolverMultistepScheduler
)
```

Pipeline for text-to-image generation using Sana.

__call__

< source >

```
(
  prompt: typing.Union[str, typing.List[str]] = None
  negative_prompt: str = ''
  num_inference_steps: int = 20
  timesteps: typing.List[int] = None
  sigmas: typing.List[float] = None
  guidance_scale: float = 4.5
  num_images_per_prompt: typing.Optional[int] = 1
  height: int = 1024
  width: int = 1024
  eta: float = 0.0
  generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None
  latents: typing.Optional[torch.Tensor] = None
  prompt_embeds: typing.Optional[torch.Tensor] = None
  prompt_attention_mask: typing.Optional[torch.Tensor] = None
  negative_prompt_embeds: typing.Optional[torch.Tensor] = None
  negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None
  output_type: typing.Optional[str] = 'pil'
  return_dict: bool = True
  clean_caption: bool = False
  use_resolution_binning: bool = True
  attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None
  callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None
  callback_on_step_end_tensor_inputs: typing.List[str] = ['latents']
  max_sequence_length: int = 300
  complex_human_instruction: typing.List[str] = ["Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:", '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.', '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.', 'Here are examples of how to transform or refine prompts:', '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.', '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.', 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:', 'User Prompt: ']
) → SanaPipelineOutput or tuple
```

Returns

If return_dict is True, SanaPipelineOutput is returned; otherwise a tuple is returned where the first element is a list with the generated images.

Function invoked when calling the pipeline for generation.

Examples:

```py
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32
)
pipe.to("cuda")
pipe.text_encoder.to(torch.bfloat16)
pipe.transformer = pipe.transformer.to(torch.bfloat16)

image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
image[0].save("output.png")
```

disable_vae_slicing

Disable sliced VAE decoding. If enable_vae_slicing was previously enabled, this method will go back to computing decoding in one step.

disable_vae_tiling

Disable tiled VAE decoding. If enable_vae_tiling was previously enabled, this method will go back to computing decoding in one step.

enable_vae_slicing

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.

enable_vae_tiling

Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images.
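Assuming a SanaPipeline already loaded as pipe (for example, as in the example above), a minimal sketch of toggling these memory-saving options:

```py
# Decode the batch slice by slice and each image tile by tile to reduce peak VRAM usage.
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

image = pipe(prompt="a tiny astronaut hatching from an egg on the moon").images[0]

# Revert to single-pass decoding once memory pressure is no longer a concern.
pipe.disable_vae_slicing()
pipe.disable_vae_tiling()
```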

encode_prompt


```
(
  prompt: typing.Union[str, typing.List[str]]
  do_classifier_free_guidance: bool = True
  negative_prompt: str = ''
  num_images_per_prompt: int = 1
  device: typing.Optional[torch.device] = None
  prompt_embeds: typing.Optional[torch.Tensor] = None
  negative_prompt_embeds: typing.Optional[torch.Tensor] = None
  prompt_attention_mask: typing.Optional[torch.Tensor] = None
  negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None
  clean_caption: bool = False
  max_sequence_length: int = 300
  complex_human_instruction: typing.Optional[typing.List[str]] = None
  lora_scale: typing.Optional[float] = None
)
```


Encodes the prompt into text encoder hidden states.
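A hedged sketch of precomputing embeddings with encode_prompt and reusing them across calls. The four-value return order assumed here (prompt embeddings, prompt attention mask, negative prompt embeddings, negative attention mask) is inferred from the __call__ arguments above rather than stated on this page:

```py
# Assumed return order: (prompt_embeds, prompt_attention_mask,
# negative_prompt_embeds, negative_prompt_attention_mask).
prompt_embeds, prompt_attention_mask, neg_embeds, neg_attention_mask = pipe.encode_prompt(
    prompt="a cyberpunk cat",
    do_classifier_free_guidance=True,
    device="cuda",
)

# Reuse the precomputed embeddings instead of passing the raw prompt.
image = pipe(
    prompt_embeds=prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_embeds=neg_embeds,
    negative_prompt_attention_mask=neg_attention_mask,
).images[0]
```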

SanaPAGPipeline

class diffusers.SanaPAGPipeline


```
(
  tokenizer: typing.Union[transformers.models.gemma.tokenization_gemma.GemmaTokenizer, transformers.models.gemma.tokenization_gemma_fast.GemmaTokenizerFast]
  text_encoder: Gemma2PreTrainedModel
  vae: AutoencoderDC
  transformer: SanaTransformer2DModel
  scheduler: FlowMatchEulerDiscreteScheduler
  pag_applied_layers: typing.Union[str, typing.List[str]] = 'transformer_blocks.0'
)
```

Pipeline for text-to-image generation using Sana. This pipeline supports the use of Perturbed Attention Guidance (PAG).

__call__


```
(
  prompt: typing.Union[str, typing.List[str]] = None
  negative_prompt: str = ''
  num_inference_steps: int = 20
  timesteps: typing.List[int] = None
  sigmas: typing.List[float] = None
  guidance_scale: float = 4.5
  num_images_per_prompt: typing.Optional[int] = 1
  height: int = 1024
  width: int = 1024
  eta: float = 0.0
  generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None
  latents: typing.Optional[torch.Tensor] = None
  prompt_embeds: typing.Optional[torch.Tensor] = None
  prompt_attention_mask: typing.Optional[torch.Tensor] = None
  negative_prompt_embeds: typing.Optional[torch.Tensor] = None
  negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None
  output_type: typing.Optional[str] = 'pil'
  return_dict: bool = True
  clean_caption: bool = False
  use_resolution_binning: bool = True
  callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None
  callback_on_step_end_tensor_inputs: typing.List[str] = ['latents']
  max_sequence_length: int = 300
  complex_human_instruction: typing.List[str] = ["Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:", '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.', '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.', 'Here are examples of how to transform or refine prompts:', '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.', '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.', 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:', 'User Prompt: ']
  pag_scale: float = 3.0
  pag_adaptive_scale: float = 0.0
) → ImagePipelineOutput or tuple
```

Returns

If return_dict is True, ImagePipelineOutput is returned; otherwise a tuple is returned where the first element is a list with the generated images.

Function invoked when calling the pipeline for generation.

Examples:

```py
import torch
from diffusers import SanaPAGPipeline

pipe = SanaPAGPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    pag_applied_layers=["transformer_blocks.8"],
    torch_dtype=torch.float32,
)
pipe.to("cuda")
pipe.text_encoder.to(torch.bfloat16)
pipe.transformer = pipe.transformer.to(torch.bfloat16)

image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
image[0].save("output.png")
```
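The PAG strength can also be adjusted per call through the pag_scale and pag_adaptive_scale arguments listed in the signature above; a brief sketch continuing the example:

```py
# Increase the perturbed-attention guidance strength for this particular call.
image = pipe(
    prompt="a cyberpunk cat under neon signs",
    guidance_scale=4.5,
    pag_scale=5.0,
).images[0]
image.save("sana_pag.png")
```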

disable_vae_slicing

Disable sliced VAE decoding. If enable_vae_slicing was previously enabled, this method will go back to computing decoding in one step.

disable_vae_tiling

Disable tiled VAE decoding. If enable_vae_tiling was previously enabled, this method will go back to computing decoding in one step.

enable_vae_slicing

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.

enable_vae_tiling

Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images.

encode_prompt


```
(
  prompt: typing.Union[str, typing.List[str]]
  do_classifier_free_guidance: bool = True
  negative_prompt: str = ''
  num_images_per_prompt: int = 1
  device: typing.Optional[torch.device] = None
  prompt_embeds: typing.Optional[torch.Tensor] = None
  negative_prompt_embeds: typing.Optional[torch.Tensor] = None
  prompt_attention_mask: typing.Optional[torch.Tensor] = None
  negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None
  clean_caption: bool = False
  max_sequence_length: int = 300
  complex_human_instruction: typing.Optional[typing.List[str]] = None
)
```


Encodes the prompt into text encoder hidden states.

SanaPipelineOutput

class diffusers.pipelines.sana.pipeline_output.SanaPipelineOutput


( images: typing.Union[typing.List[PIL.Image.Image], numpy.ndarray] )


Output class for Sana pipelines.
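As a brief illustration (assuming a loaded pipe as in the examples above), the output object exposes the generated images through its images field, while return_dict=False yields a plain tuple:

```py
# Default call: a SanaPipelineOutput whose `images` field holds the generated images.
output = pipe(prompt="a tiny astronaut hatching from an egg on the moon")
image = output.images[0]

# With return_dict=False a plain tuple comes back; its first element is the image list.
images = pipe(prompt="a tiny astronaut hatching from an egg on the moon", return_dict=False)[0]
image = images[0]
```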
