Wan


Wan 2.1 by the Alibaba Wan Team.

Generating Videos with Wan 2.1

We will first need to install some additional dependencies.

```bash
pip install -U ftfy imageio-ffmpeg imageio
```

Text to Video Generation

The following example requires 11GB of VRAM to run and uses the smaller Wan-AI/Wan2.1-T2V-1.3B-Diffusers model. You can switch it out for the larger Wan-AI/Wan2.1-T2V-14B-Diffusers checkpoint if you have at least 35GB of VRAM available.

```python
import torch
from diffusers import WanPipeline
from diffusers.utils import export_to_video

model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

pipe = WanPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33

frames = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=num_frames).frames[0]
export_to_video(frames, "wan-t2v.mp4", fps=16)
```

You can improve the quality of the generated video by running the decoding step in full precision.

```python
import torch
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video

model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

# Load the VAE in float32 so decoding runs in full precision
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33

frames = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=num_frames).frames[0]
export_to_video(frames, "wan-t2v.mp4", fps=16)
```

Image to Video Generation

The Image to Video pipeline requires loading the AutoencoderKLWan and the CLIPVisionModel components in full precision. The following example will need at least 35GB of VRAM to run.

```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel

model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)

# Resize the image so both dimensions are multiples of the VAE scale factor and transformer patch size
max_area = 480 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))

prompt = (
    "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
    "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

num_frames = 33

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```

First and Last Frame Interpolation
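The Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers checkpoint generates a video that interpolates between a provided first and last frame. It is used through the WanImageToVideoPipeline by passing the last frame via the last_image argument, as shown below.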

```python
import numpy as np
import torch
import torchvision.transforms.functional as TF
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel

model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")


def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
    aspect_ratio = image.height / image.width
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
    image = image.resize((width, height))
    return image, height, width


def center_crop_resize(image, height, width):
    resize_ratio = max(width / image.width, height / image.height)

    width = round(image.width * resize_ratio)
    height = round(image.height * resize_ratio)
    size = [width, height]
    image = TF.center_crop(image, size)

    return image, height, width


first_frame, height, width = aspect_ratio_resize(first_frame, pipe)
if last_frame.size != first_frame.size:
    last_frame, _, _ = center_crop_resize(last_frame, height, width)

prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."

output = pipe(
    image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.5
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```

Video to Video Generation
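Video to video generation uses the WanVideoToVideoPipeline with a text-to-video checkpoint. The strength argument controls how closely the output follows the input video: lower values stay closer to the source, higher values allow more deviation.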

```python
import torch
from diffusers.utils import load_video, export_to_video
from diffusers import AutoencoderKLWan, WanVideoToVideoPipeline, UniPCMultistepScheduler

model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanVideoToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
flow_shift = 3.0  # 5.0 for 720P, 3.0 for 480P
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
pipe.enable_model_cpu_offload()

prompt = "A robot standing on a mountain top. The sun is setting in the background"
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
video = load_video(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
)
output = pipe(
    video=video,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=512,
    guidance_scale=7.0,
    strength=0.7,
).frames[0]

export_to_video(output, "wan-v2v.mp4", fps=16)
```

Memory Optimizations for Wan 2.1

Base inference with the large 14B Wan 2.1 models can take up to 35GB of VRAM when generating videos at 720p resolution. We’ll outline a few memory optimizations we can apply to reduce the VRAM required to run the model.

We’ll use the Wan-AI/Wan2.1-I2V-14B-720P-Diffusers model in these examples to demonstrate the memory savings, but the techniques are applicable to all model checkpoints.

Group Offloading the Transformer and UMT5 Text Encoder

Find more information about group offloading in the Diffusers memory optimization documentation.

Block Level Group Offloading

We can reduce our VRAM requirements by applying group offloading to the larger model components of the pipeline: the WanTransformer3DModel and the UMT5EncoderModel. Group offloading breaks a model up into its individual modules and offloads/onloads them to the GPU as needed during inference. In this example, we’ll apply block_level offloading, which groups a model's modules into blocks of size num_blocks_per_group and moves each block onto the GPU only when it is needed. Moving weights between the CPU and GPU adds latency to the inference process. You can trade off between latency and memory savings by increasing or decreasing num_blocks_per_group.

With group offloading applied, the following example requires only 14GB of VRAM to run, but takes approximately 30 minutes to generate a video.

```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import UMT5EncoderModel, CLIPVisionModel

model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

# Apply group offloading to the text encoder via the standalone helper
apply_group_offloading(
    text_encoder,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=4,
)
# Diffusers models also expose group offloading as a method
transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=4,
)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    vae=vae,
    transformer=transformer,
    text_encoder=text_encoder,
    image_encoder=image_encoder,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)

max_area = 720 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))

prompt = (
    "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
    "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

num_frames = 33

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    guidance_scale=5.0,
).frames[0]

export_to_video(output, "wan-i2v.mp4", fps=16)
```

Block Level Group Offloading with CUDA Streams

We can speed up group offloading inference by enabling the use of CUDA streams. However, using CUDA streams requires moving the model parameters into pinned memory. This allocation is handled by PyTorch under the hood and can result in a significant spike in CPU RAM usage. Only consider this option if your CPU RAM is at least 2x the size of the model you are group offloading.

In the following example, we use CUDA streams when group offloading the WanTransformer3DModel. When tested on an A100, this example requires 14GB of VRAM and 52GB of CPU RAM, and generates a video in approximately 9 minutes.

```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import UMT5EncoderModel, CLIPVisionModel

model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

apply_group_offloading(
    text_encoder,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=4,
)
# Offload the transformer at leaf level and overlap transfers with compute via CUDA streams
transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    vae=vae,
    transformer=transformer,
    text_encoder=text_encoder,
    image_encoder=image_encoder,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)

max_area = 720 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))

prompt = (
    "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
    "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

num_frames = 33

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    guidance_scale=5.0,
).frames[0]

export_to_video(output, "wan-i2v.mp4", fps=16)
```

Applying Layerwise Casting to the Transformer

Find more information about layerwise casting in the Diffusers memory optimization documentation.

In this example, we will combine model offloading with layerwise casting. Layerwise casting will downcast each layer’s weights to torch.float8_e4m3fn, temporarily upcast to torch.bfloat16 during the forward pass of the layer, then revert to torch.float8_e4m3fn afterward. This approach reduces memory requirements by approximately 50% while introducing a minor quality reduction in the generated video due to the precision trade-off.

This example will require 20GB of VRAM.

```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import UMT5EncoderModel, CLIPVisionModel

model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)

transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
# Store weights in float8_e4m3fn and upcast to bfloat16 only for each layer's forward pass
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)

pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    vae=vae,
    transformer=transformer,
    text_encoder=text_encoder,
    image_encoder=image_encoder,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg")

max_area = 720 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
    "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
    "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    num_inference_steps=50,
    guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```

Using a Custom Scheduler

Wan can be used with many different schedulers, each with their own benefits regarding speed and generation quality. By default, Wan uses the UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0) scheduler. You can use a different scheduler as follows:

```python
from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler, WanPipeline

scheduler_a = FlowMatchEulerDiscreteScheduler(shift=5.0)
scheduler_b = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=4.0)

# Pass the scheduler when loading the pipeline...
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler=scheduler_a)

# ...or replace it on an already-loaded pipeline
pipe.scheduler = scheduler_b
```

Using Single File Loading with Wan 2.1

The WanTransformer3DModel and AutoencoderKLWan models support loading checkpoints in their original format via the from_single_file loading method.

```python
import torch
from diffusers import WanPipeline, WanTransformer3DModel

ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
```
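The AutoencoderKLWan can be loaded the same way. The snippet below is a minimal sketch: the checkpoint filenames are placeholders for whichever single-file transformer and VAE weights you have downloaded locally.

```python
import torch
from diffusers import AutoencoderKLWan, WanPipeline, WanTransformer3DModel

# Placeholder local paths -- point these at the single-file checkpoints you actually have
transformer_ckpt_path = "wan2.1_t2v_1.3B_bf16.safetensors"
vae_ckpt_path = "wan2.1_vae.safetensors"

transformer = WanTransformer3DModel.from_single_file(transformer_ckpt_path, torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_single_file(vae_ckpt_path, torch_dtype=torch.float32)

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16
)
```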

Recommendations for Inference
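The following recommendations are drawn from the examples in this guide:

- Keep the AutoencoderKLWan (and, for image-to-video, the CLIPVisionModel image encoder) in torch.float32 for better decoding quality.
- Use num_frames values of the form 4 * k + 1 (for example 33 or 81), which match the temporal compression of the Wan VAE.
- Tune flow_shift on the scheduler for the target resolution: the examples above use 3.0 for 480P and 5.0 for 720P outputs.
- The 1.3B text-to-video checkpoint runs in roughly 11GB of VRAM; the 14B checkpoints need around 35GB unless you apply the memory optimizations described above.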

WanPipeline

class diffusers.WanPipeline


( tokenizer: AutoTokenizer text_encoder: UMT5EncoderModel transformer: WanTransformer3DModel vae: AutoencoderKLWan scheduler: FlowMatchEulerDiscreteScheduler )

Parameters

Pipeline for text-to-video generation using Wan.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).

__call__


( prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str]] = None height: int = 480 width: int = 832 num_frames: int = 81 num_inference_steps: int = 50 guidance_scale: float = 5.0 num_videos_per_prompt: typing.Optional[int] = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'np' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Union[typing.Callable[[int, int, typing.Dict], NoneType], diffusers.callbacks.PipelineCallback, diffusers.callbacks.MultiPipelineCallbacks, NoneType] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 512 ) → ~WanPipelineOutput or tuple

Parameters

Returns

~WanPipelineOutput or tuple

If return_dict is True, WanPipelineOutput is returned, otherwise a tuple is returned where the first element is a list with the generated video frames.

The call function to the pipeline for generation.

Examples:

```python
import torch
from diffusers.utils import export_to_video
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler

model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
flow_shift = 5.0
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
pipe.to("cuda")

prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=720,
    width=1280,
    num_frames=81,
    guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```

encode_prompt


( prompt: typing.Union[str, typing.List[str]] negative_prompt: typing.Union[str, typing.List[str], NoneType] = None do_classifier_free_guidance: bool = True num_videos_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None max_sequence_length: int = 226 device: typing.Optional[torch.device] = None dtype: typing.Optional[torch.dtype] = None )

Parameters

Encodes the prompt into text encoder hidden states.
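Prompt embeddings can be precomputed and reused across calls. The example below is a minimal sketch, assuming encode_prompt returns a (prompt_embeds, negative_prompt_embeds) pair as in other Diffusers pipelines.

```python
import torch
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Encode once, then reuse the embeddings for several generations
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    prompt="A cat walks on the grass, realistic style.",
    negative_prompt="worst quality, low quality",
    do_classifier_free_guidance=True,
    device=torch.device("cuda"),
)

output = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    num_frames=33,
).frames[0]
```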

WanImageToVideoPipeline

class diffusers.WanImageToVideoPipeline


( tokenizer: AutoTokenizer text_encoder: UMT5EncoderModel image_encoder: CLIPVisionModel image_processor: CLIPImageProcessor transformer: WanTransformer3DModel vae: AutoencoderKLWan scheduler: FlowMatchEulerDiscreteScheduler )

Parameters

Pipeline for image-to-video generation using Wan.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).

__call__


( image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str]] = None height: int = 480 width: int = 832 num_frames: int = 81 num_inference_steps: int = 50 guidance_scale: float = 5.0 num_videos_per_prompt: typing.Optional[int] = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None image_embeds: typing.Optional[torch.Tensor] = None last_image: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'np' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Union[typing.Callable[[int, int, typing.Dict], NoneType], diffusers.callbacks.PipelineCallback, diffusers.callbacks.MultiPipelineCallbacks, NoneType] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 512 ) → ~WanPipelineOutput or tuple

Parameters

Returns

~WanPipelineOutput or tuple

If return_dict is True, WanPipelineOutput is returned, otherwise a tuple is returned where the first element is a list with the generated video frames.

The call function to the pipeline for generation.

Examples:

```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel

model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
max_area = 480 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
    "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
    "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=81,
    guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```

encode_prompt


( prompt: typing.Union[str, typing.List[str]] negative_prompt: typing.Union[str, typing.List[str], NoneType] = None do_classifier_free_guidance: bool = True num_videos_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None max_sequence_length: int = 226 device: typing.Optional[torch.device] = None dtype: typing.Optional[torch.dtype] = None )

Parameters

Encodes the prompt into text encoder hidden states.
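Usage mirrors WanPipeline.encode_prompt shown above: the returned prompt_embeds and negative_prompt_embeds can be passed to __call__ together with the input image instead of raw prompt strings.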

WanPipelineOutput

class diffusers.pipelines.wan.pipeline_output.WanPipelineOutput


( frames: Tensor )

Parameters

Output class for Wan pipelines.
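The frames attribute holds the generated video frames, batched with one entry per generated video, which is why the examples above index .frames[0] before passing the result to export_to_video.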
