Sana-Video


SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer is from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.

The abstract from the paper is:

We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation.
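The constant-memory KV cache described above relies on the cumulative form of linear attention: past keys and values can be folded into a fixed-size running state instead of being stored individually. The toy PyTorch sketch below is not the SANA-Video implementation; it only illustrates, under an assumed ELU+1 feature map and a simplified block-causal setup, why the cached state does not grow with the number of processed video tokens.

```python
import torch

# Toy illustration of the cumulative ("constant-memory KV cache") form of
# linear attention. This is NOT the SANA-Video code; it only shows that the
# cached state stores sums over past tokens, so its size is fixed.
def phi(x):
    # Non-negative feature map (an assumption made for this sketch).
    return torch.nn.functional.elu(x) + 1

d = 64
state_kv = torch.zeros(d, d)   # running sum of phi(k_i) v_i^T -> fixed (d, d)
state_k = torch.zeros(d)       # running sum of phi(k_i)       -> fixed (d,)

for block in range(10):                      # e.g. successive blocks of video tokens
    q = torch.randn(256, d)
    k = torch.randn(256, d)
    v = torch.randn(256, d)

    # Attend to everything seen so far using only the fixed-size state.
    num = phi(q) @ state_kv                              # (256, d)
    den = (phi(q) @ state_k).clamp_min(1e-6).unsqueeze(-1)
    out = num / den                                       # (256, d)

    # Fold the current block into the state; memory stays (d*d + d) forever.
    state_kv += phi(k).T @ v
    state_k += phi(k).sum(dim=0)
```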

This pipeline was contributed by the SANA Team. The original codebase can be found at https://github.com/NVlabs/Sana, and the original weights are available under hf.co/Efficient-Large-Model.

Available models:

| Model | Recommended dtype |
|---|---|
| Efficient-Large-Model/SANA-Video_2B_480p_diffusers | torch.bfloat16 |

Refer to this collection for more information.

Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in torch.bfloat16 or torch.float32 for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.

Generation Pipelines

Text-to-Video


The example below demonstrates how to use the text-to-video pipeline to generate a video using a text description.

```python
import torch
from diffusers import SanaVideoPipeline
from diffusers.utils import export_to_video

pipe = SanaVideoPipeline.from_pretrained(
    "Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.text_encoder.to(torch.bfloat16)
pipe.vae.to(torch.float32)
pipe.to("cuda")

prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
motion_scale = 30
motion_prompt = f" motion score: {motion_scale}."
prompt = prompt + motion_prompt

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=832,
    frames=81,
    guidance_scale=6,
    num_inference_steps=50,
    generator=torch.Generator(device="cuda").manual_seed(0),
).frames[0]

export_to_video(video, "sana_video.mp4", fps=16)
```
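Image-to-Video

The image-to-video pipeline additionally conditions generation on an input image. The example below is adapted from the SanaImageToVideoPipeline reference further down this page and assumes the same checkpoint and dtype setup as the text-to-video example above.

```python
import torch
from diffusers import SanaImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = SanaImageToVideoPipeline.from_pretrained(
    "Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.text_encoder.to(torch.bfloat16)
pipe.vae.to(torch.float32)
pipe.to("cuda")

image = load_image("https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png")
prompt = "A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, jump cuts, jerky movements, temporal artifacts, jitter, and ghosting effects."
motion_score = 30
prompt = prompt + f" motion score: {motion_score}."

video = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=832,
    frames=81,
    guidance_scale=6,
    num_inference_steps=50,
    generator=torch.Generator(device="cuda").manual_seed(0),
).frames[0]

export_to_video(video, "sana_i2v.mp4", fps=16)
```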

Quantization

Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.

Refer to the Quantization overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized SanaVideoPipeline for inference with bitsandbytes.

```python
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaVideoTransformer3DModel, SanaVideoPipeline
from diffusers.utils import export_to_video
from transformers import BitsAndBytesConfig, AutoModel

quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = AutoModel.from_pretrained(
    "Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
    subfolder="text_encoder",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = SanaVideoTransformer3DModel.from_pretrained(
    "Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

pipeline = SanaVideoPipeline.from_pretrained(
    "Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
)

motion_score = 30
prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
motion_prompt = f" motion score: {motion_score}."
prompt = prompt + motion_prompt

output = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=832,
    frames=81,
    guidance_scale=6.0,
    num_inference_steps=50,
).frames[0]
export_to_video(output, "sana-video-output.mp4", fps=16)
```
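The same config classes also expose 4-bit (NF4) settings, which reduce memory further. The snippet below is a possible variant of the 8-bit configs above (its effect on SANA-Video quality is untested here); the rest of the loading code stays the same.

```python
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig

# 4-bit NF4 configs; swap these in for the 8-bit configs above.
text_encoder_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer_quant_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
```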

SanaVideoPipeline

class diffusers.SanaVideoPipeline


( tokenizer: typing.Union[transformers.models.gemma.tokenization_gemma.GemmaTokenizer, transformers.models.gemma.tokenization_gemma_fast.GemmaTokenizerFast] text_encoder: Gemma2PreTrainedModel vae: typing.Union[diffusers.models.autoencoders.autoencoder_dc.AutoencoderDC, diffusers.models.autoencoders.autoencoder_kl_wan.AutoencoderKLWan] transformer: SanaVideoTransformer3DModel scheduler: DPMSolverMultistepScheduler )


Pipeline for text-to-video generation using Sana. This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).

__call__


( prompt: typing.Union[str, typing.List[str]] = None negative_prompt: str = '' num_inference_steps: int = 50 timesteps: typing.List[int] = None sigmas: typing.List[float] = None guidance_scale: float = 6.0 num_videos_per_prompt: typing.Optional[int] = 1 height: int = 480 width: int = 832 frames: int = 81 eta: float = 0.0 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True clean_caption: bool = False use_resolution_binning: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 300 complex_human_instruction: typing.List[str] = ["Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for video generation. Evaluate the level of detail in the user prompt:", '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, motion, and temporal relationships to create vivid and dynamic scenes.', '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.', 'Here are examples of how to transform or refine prompts:', '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat slowly settling into a curled position, peacefully falling asleep on a warm sunny windowsill, with gentle sunlight filtering through surrounding pots of blooming red flowers.', '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps gradually lighting up, a diverse crowd of people in colorful clothing walking past, and a double-decker bus smoothly passing by towering glass skyscrapers.', 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:', 'User Prompt: '] ) → SanaVideoPipelineOutput or tuple

Returns

SanaVideoPipelineOutput or tuple: If return_dict is True, SanaVideoPipelineOutput is returned, otherwise a tuple is returned where the first element is a list with the generated videos.

Function invoked when calling the pipeline for generation.

Examples:

```python
import torch
from diffusers import SanaVideoPipeline
from diffusers.utils import export_to_video

pipe = SanaVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers")
pipe.transformer.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)
pipe.vae.to(torch.float32)
pipe.to("cuda")

motion_score = 30
prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
motion_prompt = f" motion score: {motion_score}."
prompt = prompt + motion_prompt

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=832,
    frames=81,
    guidance_scale=6,
    num_inference_steps=50,
    generator=torch.Generator(device="cuda").manual_seed(42),
).frames[0]

export_to_video(output, "sana-video-output.mp4", fps=16)
```

encode_prompt


( prompt: typing.Union[str, typing.List[str]] do_classifier_free_guidance: bool = True negative_prompt: str = '' num_videos_per_prompt: int = 1 device: typing.Optional[torch.device] = None prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None clean_caption: bool = False max_sequence_length: int = 300 complex_human_instruction: typing.Optional[typing.List[str]] = None lora_scale: typing.Optional[float] = None )


Encodes the prompt into text encoder hidden states.
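Because __call__ accepts prompt_embeds and prompt_attention_mask directly, the text encoder only needs to run once for a prompt that is reused across several generations. The sketch below continues from the text-to-video example above and assumes encode_prompt returns the positive and negative embeddings with their attention masks in the order shown, as in the other Sana pipelines.

```python
# Encode once (assumed return order: prompt_embeds, prompt_attention_mask,
# negative_prompt_embeds, negative_prompt_attention_mask).
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = pipe.encode_prompt(
    prompt=prompt,
    negative_prompt=negative_prompt,
    device="cuda",
)

# Reuse the precomputed embeddings for generation; no prompt strings needed.
video = pipe(
    prompt_embeds=prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_embeds=negative_prompt_embeds,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    height=480,
    width=832,
    frames=81,
    guidance_scale=6.0,
    num_inference_steps=50,
).frames[0]
```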

SanaImageToVideoPipeline

class diffusers.SanaImageToVideoPipeline


( tokenizer: typing.Union[transformers.models.gemma.tokenization_gemma.GemmaTokenizer, transformers.models.gemma.tokenization_gemma_fast.GemmaTokenizerFast] text_encoder: Gemma2PreTrainedModel vae: typing.Union[diffusers.models.autoencoders.autoencoder_dc.AutoencoderDC, diffusers.models.autoencoders.autoencoder_kl_wan.AutoencoderKLWan] transformer: SanaVideoTransformer3DModel scheduler: FlowMatchEulerDiscreteScheduler )


Pipeline for image/text-to-video generation using Sana. This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).

__call__


( image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] prompt: typing.Union[str, typing.List[str]] = None negative_prompt: str = '' num_inference_steps: int = 50 timesteps: typing.List[int] = None sigmas: typing.List[float] = None guidance_scale: float = 6.0 num_videos_per_prompt: typing.Optional[int] = 1 height: int = 480 width: int = 832 frames: int = 81 eta: float = 0.0 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True clean_caption: bool = False use_resolution_binning: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 300 complex_human_instruction: typing.List[str] = ["Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for video generation. Evaluate the level of detail in the user prompt:", '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, motion, and temporal relationships to create vivid and dynamic scenes.', '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.', 'Here are examples of how to transform or refine prompts:', '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat slowly settling into a curled position, peacefully falling asleep on a warm sunny windowsill, with gentle sunlight filtering through surrounding pots of blooming red flowers.', '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps gradually lighting up, a diverse crowd of people in colorful clothing walking past, and a double-decker bus smoothly passing by towering glass skyscrapers.', 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:', 'User Prompt: '] ) → SanaVideoPipelineOutput or tuple

Returns

SanaVideoPipelineOutput or tuple: If return_dict is True, SanaVideoPipelineOutput is returned, otherwise a tuple is returned where the first element is a list with the generated videos.

Function invoked when calling the pipeline for generation.

Examples:

```python
import torch
from diffusers import SanaImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = SanaImageToVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers")
pipe.transformer.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)
pipe.vae.to(torch.float32)
pipe.to("cuda")

motion_score = 30
prompt = "A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
motion_prompt = f" motion score: {motion_score}."
prompt = prompt + motion_prompt
image = load_image("https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png")

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=832,
    frames=81,
    guidance_scale=6,
    num_inference_steps=50,
    generator=torch.Generator(device="cuda").manual_seed(42),
).frames[0]

export_to_video(output, "sana-ti2v-output.mp4", fps=16)
```

encode_prompt


( prompt: typing.Union[str, typing.List[str]] do_classifier_free_guidance: bool = True negative_prompt: str = '' num_videos_per_prompt: int = 1 device: typing.Optional[torch.device] = None prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None clean_caption: bool = False max_sequence_length: int = 300 complex_human_instruction: typing.Optional[typing.List[str]] = None lora_scale: typing.Optional[float] = None )


Encodes the prompt into text encoder hidden states.

SanaVideoPipelineOutput

class diffusers.pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput


( frames: Tensor )


Output class for Sana-Video pipelines.
