LTX-Video


LTX-Video is a diffusion transformer designed for fast, real-time generation of high-resolution videos from text and images. Its main feature is the Video-VAE, which has a higher pixel-to-latent compression ratio (1:192), enabling more efficient video data processing and faster generation. To prevent finer details from being lost during generation, the Video-VAE decoder performs both the latent-to-pixel conversion and the last denoising step.
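
As a quick sanity check of that ratio, the arithmetic below assumes the published Video-VAE factors (32×32 spatial and 8× temporal compression, 3 RGB input channels, 128 latent channels); those factors are an assumption, not stated on this page:

```py
# Assumed VAE factors: 32x32 spatial, 8x temporal, 3 RGB channels -> 128 latent channels
pixel_values_per_latent = 32 * 32 * 8 * 3   # 24576 pixel values feed one latent site
latent_channels = 128
print(pixel_values_per_latent / latent_channels)  # 192.0 -> the 1:192 ratio
```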

You can find all the original LTX-Video checkpoints under the Lightricks organization.

Click on the LTX-Video models in the right sidebar for more examples of other video generation tasks.

The example below demonstrates how to generate a video optimized for memory or inference speed.

Refer to the Reduce memory usage guide for more details about the various memory saving techniques.

The LTX-Video model below requires ~10GB of VRAM.

```py
import torch
from diffusers import LTXPipeline, AutoModel
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video

# fp8 layerwise weight-casting
transformer = AutoModel.from_pretrained(
    "Lightricks/LTX-Video",
    subfolder="transformer",
    torch_dtype=torch.bfloat16
)
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)

pipeline = LTXPipeline.from_pretrained(
    "Lightricks/LTX-Video",
    transformer=transformer,
    torch_dtype=torch.bfloat16
)

# group-offload the transformer, text encoder, and VAE
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
pipeline.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True
)
apply_group_offloading(
    pipeline.text_encoder,
    onload_device=onload_device,
    offload_type="block_level",
    num_blocks_per_group=2
)
apply_group_offloading(pipeline.vae, onload_device=onload_device, offload_type="leaf_level")

prompt = """
A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and
has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is
warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage
"""
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

video = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=768,
    height=512,
    num_frames=161,
    decode_timestep=0.03,
    decode_noise_scale=0.025,
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```


LTXPipeline

class diffusers.LTXPipeline


( scheduler: FlowMatchEulerDiscreteScheduler vae: AutoencoderKLLTXVideo text_encoder: T5EncoderModel tokenizer: T5TokenizerFast transformer: LTXVideoTransformer3DModel )


Pipeline for text-to-video generation.

Reference: https://github.com/Lightricks/LTX-Video

__call__


( prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str], NoneType] = None height: int = 512 width: int = 704 num_frames: int = 161 frame_rate: int = 25 num_inference_steps: int = 50 timesteps: typing.List[int] = None guidance_scale: float = 3 guidance_rescale: float = 0.0 num_videos_per_prompt: typing.Optional[int] = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None decode_timestep: typing.Union[float, typing.List[float]] = 0.0 decode_noise_scale: typing.Union[float, typing.List[float], NoneType] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 128 ) → ~pipelines.ltx.LTXPipelineOutput or tuple


Returns

~pipelines.ltx.LTXPipelineOutput or tuple

If return_dict is True, ~pipelines.ltx.LTXPipelineOutput is returned, otherwise a tuple is returned where the first element is a list with the generated images.
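
As a minimal illustration of the two return forms (assuming a `pipe` constructed as in the example below):

```py
out = pipe(prompt=prompt)                              # LTXPipelineOutput
video = out.frames[0]                                  # first generated video
video = pipe(prompt=prompt, return_dict=False)[0][0]   # same data via a plain tuple
```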

Function invoked when calling the pipeline for generation.

Examples:

```py
>>> import torch
>>> from diffusers import LTXPipeline
>>> from diffusers.utils import export_to_video

>>> pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")

>>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

>>> video = pipe(
...     prompt=prompt,
...     negative_prompt=negative_prompt,
...     width=704,
...     height=480,
...     num_frames=161,
...     num_inference_steps=50,
... ).frames[0]
>>> export_to_video(video, "output.mp4", fps=24)
```

encode_prompt


( prompt: typing.Union[str, typing.List[str]] negative_prompt: typing.Union[str, typing.List[str], NoneType] = None do_classifier_free_guidance: bool = True num_videos_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None max_sequence_length: int = 128 device: typing.Optional[torch.device] = None dtype: typing.Optional[torch.dtype] = None )


Encodes the prompt into text encoder hidden states.
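
A hedged sketch of pre-computing embeddings once and reusing them; the four-tuple return order shown here mirrors the argument order in the signature above and is an assumption:

```py
prompt_embeds, prompt_mask, negative_embeds, negative_mask = pipe.encode_prompt(
    prompt="A red fox running through fresh snow",
    negative_prompt="worst quality, blurry",
)
video = pipe(
    prompt_embeds=prompt_embeds,
    prompt_attention_mask=prompt_mask,
    negative_prompt_embeds=negative_embeds,
    negative_prompt_attention_mask=negative_mask,
).frames[0]
```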

LTXImageToVideoPipeline

class diffusers.LTXImageToVideoPipeline


( scheduler: FlowMatchEulerDiscreteScheduler vae: AutoencoderKLLTXVideo text_encoder: T5EncoderModel tokenizer: T5TokenizerFast transformer: LTXVideoTransformer3DModel )


Pipeline for image-to-video generation.

Reference: https://github.com/Lightricks/LTX-Video

__call__


( image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] = None prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str], NoneType] = None height: int = 512 width: int = 704 num_frames: int = 161 frame_rate: int = 25 num_inference_steps: int = 50 timesteps: typing.List[int] = None guidance_scale: float = 3 guidance_rescale: float = 0.0 num_videos_per_prompt: typing.Optional[int] = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None decode_timestep: typing.Union[float, typing.List[float]] = 0.0 decode_noise_scale: typing.Union[float, typing.List[float], NoneType] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 128 ) → ~pipelines.ltx.LTXPipelineOutput or tuple


Returns

~pipelines.ltx.LTXPipelineOutput or tuple

If return_dict is True, ~pipelines.ltx.LTXPipelineOutput is returned, otherwise a tuple is returned where the first element is a list with the generated images.

Function invoked when calling the pipeline for generation.

Examples:

```py
>>> import torch
>>> from diffusers import LTXImageToVideoPipeline
>>> from diffusers.utils import export_to_video, load_image

>>> pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")

>>> image = load_image(
...     "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
... )
>>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background. Flames engulf the structure, with smoke billowing into the air. Firefighters in protective gear rush to the scene, a fire truck labeled '38' visible behind them. The girl's neutral expression contrasts sharply with the chaos of the fire, creating a poignant and emotionally charged scene."
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

>>> video = pipe(
...     image=image,
...     prompt=prompt,
...     negative_prompt=negative_prompt,
...     width=704,
...     height=480,
...     num_frames=161,
...     num_inference_steps=50,
... ).frames[0]
>>> export_to_video(video, "output.mp4", fps=24)
```

encode_prompt


( prompt: typing.Union[str, typing.List[str]] negative_prompt: typing.Union[str, typing.List[str], NoneType] = None do_classifier_free_guidance: bool = True num_videos_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None max_sequence_length: int = 128 device: typing.Optional[torch.device] = None dtype: typing.Optional[torch.dtype] = None )


Encodes the prompt into text encoder hidden states.

LTXConditionPipeline

class diffusers.LTXConditionPipeline


( scheduler: FlowMatchEulerDiscreteScheduler vae: AutoencoderKLLTXVideo text_encoder: T5EncoderModel tokenizer: T5TokenizerFast transformer: LTXVideoTransformer3DModel )


Pipeline for text/image/video-to-video generation.

Reference: https://github.com/Lightricks/LTX-Video

__call__


( conditions: typing.Union[diffusers.pipelines.ltx.pipeline_ltx_condition.LTXVideoCondition, typing.List[diffusers.pipelines.ltx.pipeline_ltx_condition.LTXVideoCondition]] = None image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor], typing.List[typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]]]] = None video: typing.List[typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]]] = None frame_index: typing.Union[int, typing.List[int]] = 0 strength: typing.Union[float, typing.List[float]] = 1.0 denoise_strength: float = 1.0 prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str], NoneType] = None height: int = 512 width: int = 704 num_frames: int = 161 frame_rate: int = 25 num_inference_steps: int = 50 timesteps: typing.List[int] = None guidance_scale: float = 3 guidance_rescale: float = 0.0 image_cond_noise_scale: float = 0.15 num_videos_per_prompt: typing.Optional[int] = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None decode_timestep: typing.Union[float, typing.List[float]] = 0.0 decode_noise_scale: typing.Union[float, typing.List[float], NoneType] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 256 ) → ~pipelines.ltx.LTXPipelineOutput or tuple


Returns

~pipelines.ltx.LTXPipelineOutput or tuple

If return_dict is True, ~pipelines.ltx.LTXPipelineOutput is returned, otherwise a tuple is returned where the first element is a list with the generated images.

Function invoked when calling the pipeline for generation.

Examples:

```py
>>> import torch
>>> from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXConditionPipeline, LTXVideoCondition
>>> from diffusers.utils import export_to_video, load_video, load_image

>>> pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")

>>> video = load_video(
...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
... )
>>> image = load_image(
...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg"
... )

>>> condition1 = LTXVideoCondition(
...     image=image,
...     frame_index=0,
... )
>>> condition2 = LTXVideoCondition(
...     video=video,
...     frame_index=80,
... )

>>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day. And then the camera switch to a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

>>> generator = torch.Generator("cuda").manual_seed(0)

>>> video = pipe(
...     conditions=[condition1, condition2],
...     prompt=prompt,
...     negative_prompt=negative_prompt,
...     width=768,
...     height=512,
...     num_frames=161,
...     num_inference_steps=40,
...     generator=generator,
... ).frames[0]

>>> export_to_video(video, "output.mp4", fps=24)
```

add_noise_to_image_conditioning_latents


( t: float init_latents: Tensor latents: Tensor noise_scale: float conditioning_mask: Tensor generator eps = 1e-06 )

Add timestep-dependent noise to the hard-conditioning latents. This helps with motion continuity, especially when conditioned on a single frame.
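
The rough idea, sketched below as an illustrative re-implementation (not the library source): Gaussian noise, scaled by the noise scale and the timestep, is blended into the latents only at hard-conditioned positions.

```py
import torch

def add_noise_to_conditioning_latents(t, init_latents, latents, noise_scale,
                                      conditioning_mask, generator=None, eps=1e-6):
    # Hypothetical sketch for illustration; the exact scaling in the
    # library may differ.
    noise = torch.randn(latents.shape, generator=generator,
                        device=latents.device, dtype=latents.dtype)
    # Noise strength grows with the (continuous) timestep t.
    noised = init_latents + noise_scale * noise * t
    # Only positions the mask marks as hard-conditioned receive noise.
    hard = (conditioning_mask > 1.0 - eps).unsqueeze(-1)
    return torch.where(hard, noised, latents)
```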

encode_prompt


( prompt: typing.Union[str, typing.List[str]] negative_prompt: typing.Union[str, typing.List[str], NoneType] = None do_classifier_free_guidance: bool = True num_videos_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None max_sequence_length: int = 256 device: typing.Optional[torch.device] = None dtype: typing.Optional[torch.dtype] = None )


Encodes the prompt into text encoder hidden states.

trim_conditioning_sequence


( start_frame: int sequence_num_frames: int target_num_frames: int ) → int

Returns

int

The updated sequence length.

Trim a conditioning sequence to the allowed number of frames.
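
A plausible sketch of the trimming rule, assuming the VAE's 8× temporal compression (so valid lengths have the form 8k + 1, matching num_frames=161 in the examples above); the helper below is illustrative, not the library source:

```py
def trim_conditioning_sequence(start_frame: int, sequence_num_frames: int,
                               target_num_frames: int, temporal_scale: int = 8) -> int:
    # Clamp to the frames that still fit after start_frame...
    num_frames = min(sequence_num_frames, target_num_frames - start_frame)
    # ...then round down to a multiple of the temporal compression
    # ratio plus one (8k + 1 frames).
    return (num_frames - 1) // temporal_scale * temporal_scale + 1
```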

LTXLatentUpsamplePipeline

class diffusers.LTXLatentUpsamplePipeline


( vae: AutoencoderKLLTXVideo latent_upsampler: LTXLatentUpsamplerModel )

__call__


( video: typing.Optional[typing.List[typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]]]] = None height: int = 512 width: int = 704 latents: typing.Optional[torch.Tensor] = None decode_timestep: typing.Union[float, typing.List[float]] = 0.0 decode_noise_scale: typing.Union[float, typing.List[float], NoneType] = None adain_factor: float = 0.0 tone_map_compression_ratio: float = 0.0 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True )

adain_filter_latent


( latents: Tensor reference_latents: Tensor factor: float = 1.0 ) → torch.Tensor

Returns

torch.Tensor

The transformed latent tensor.

Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent tensor.
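
A compact sketch of the AdaIN step under simplifying assumptions (statistics are computed globally here; the library may compute them per channel), with `factor` blending between the original and the normalized latents:

```py
import torch

def adain_filter_latent(latents: torch.Tensor, reference_latents: torch.Tensor,
                        factor: float = 1.0) -> torch.Tensor:
    # Shift/scale latents so their mean and std match the reference,
    # then linearly blend with the untouched input by `factor`.
    normalized = (latents - latents.mean()) / (latents.std() + 1e-6)
    matched = normalized * reference_latents.std() + reference_latents.mean()
    return torch.lerp(latents, matched, factor)
```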

disable_vae_slicing

Disable sliced VAE decoding. If enable_vae_slicing was previously enabled, this method will go back to computing decoding in one step.

disable_vae_tiling

Disable tiled VAE decoding. If enable_vae_tiling was previously enabled, this method will go back to computing decoding in one step.

enable_vae_slicing

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.

enable_vae_tiling

Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images.
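
A short usage sketch (assuming an already-constructed `pipeline`; the toggles follow the method names documented above):

```py
pipeline.enable_vae_slicing()   # split batched inputs into slices during decoding
pipeline.enable_vae_tiling()    # decode/encode each frame in spatial tiles
# ... run the pipeline as usual ...
pipeline.disable_vae_tiling()   # restore single-pass decoding
pipeline.disable_vae_slicing()
```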

tone_map_latents


( latents: Tensor compression: float )


Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually smooth way using a sigmoid-based compression.

This is useful for regularizing high-variance latents or for conditioning outputs during generation, especially when controlling dynamic behavior with a compression factor.
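
The sketch below illustrates one way such a sigmoid-based compression can look; the exact constants and curve used by the library are an assumption here:

```py
import torch

def tone_map_latents(latents: torch.Tensor, compression: float) -> torch.Tensor:
    # compression = 0 is the identity; larger values scale down
    # high-magnitude latents, smoothly reducing dynamic range.
    if compression == 0.0:
        return latents
    gate = torch.sigmoid(4.0 * compression * (latents.abs() - 1.0))
    return latents * (1.0 - 0.8 * compression * gate)
```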

LTXPipelineOutput

class diffusers.pipelines.ltx.pipeline_output.LTXPipelineOutput


( frames: Tensor )


Output class for LTX pipelines.
