LTX Video

This pipeline supports LoRA and can run on Apple silicon (MPS).

LTX Video is the first DiT-based video generation model capable of generating high-quality videos in real time. It produces 24 FPS videos at 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide models for both text-to-video and image + text-to-video use cases.

Make sure to check out the Schedulers guide to learn how to explore the tradeoff between scheduler speed and quality, and see the reuse components across pipelines section to learn how to efficiently load the same components into multiple pipelines.
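Since all LTX pipelines share the same component set (scheduler, VAE, text encoder, tokenizer, transformer, as the class signatures later on this page show), a minimal sketch of reusing already-loaded components in a second pipeline could look like the following; treat it as an illustration rather than the only way to share components:

```python
import torch
from diffusers import LTXPipeline, LTXImageToVideoPipeline

# Load the text-to-video pipeline once.
pipe_t2v = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)

# Reuse its components in an image-to-video pipeline instead of loading them a second time.
pipe_i2v = LTXImageToVideoPipeline(
    scheduler=pipe_t2v.scheduler,
    vae=pipe_t2v.vae,
    text_encoder=pipe_t2v.text_encoder,
    tokenizer=pipe_t2v.tokenizer,
    transformer=pipe_t2v.transformer,
)
```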

Available models:

| Model name | Recommended dtype |
|---|---|
| LTX Video 0.9.0 | torch.bfloat16 |
| LTX Video 0.9.1 | torch.bfloat16 |
| LTX Video 0.9.5 | torch.bfloat16 |

Note: The recommended dtype is for the transformer component. The VAE and text encoders can be torch.float32, torch.bfloat16, or torch.float16, but the recommended dtype is torch.bfloat16, as used in the original repository.
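For illustration, a minimal loading sketch with the recommended dtypes; the subfolder names ("transformer", "vae") are assumed to follow the standard layout of the Lightricks/LTX-Video repository:

```python
import torch
from diffusers import AutoencoderKLLTXVideo, LTXPipeline, LTXVideoTransformer3DModel

# Transformer in the recommended torch.bfloat16.
transformer = LTXVideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16
)
# The VAE could also be loaded in torch.float16 or torch.float32 if preferred.
vae = AutoencoderKLLTXVideo.from_pretrained(
    "Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.bfloat16
)
pipe = LTXPipeline.from_pretrained(
    "Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16
)
```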

Loading Single Files

Loading the original LTX Video checkpoints is also possible with ~ModelMixin.from_single_file. We recommend using from_single_file for the Lightricks series of models, as they plan to release multiple models in the single-file format in the future.

```python
import torch
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel

single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
transformer = LTXVideoTransformer3DModel.from_single_file(
    single_file_url, torch_dtype=torch.bfloat16
)
vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
pipe = LTXImageToVideoPipeline.from_pretrained(
    "Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16
)
```

Alternatively, the pipeline can be used to load the weights with ~FromSingleFileMixin.from_single_file.

```python
import torch
from diffusers import LTXImageToVideoPipeline
from transformers import T5EncoderModel, T5Tokenizer

single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
text_encoder = T5EncoderModel.from_pretrained(
    "Lightricks/LTX-Video", subfolder="text_encoder", torch_dtype=torch.bfloat16
)
tokenizer = T5Tokenizer.from_pretrained(
    "Lightricks/LTX-Video", subfolder="tokenizer", torch_dtype=torch.bfloat16
)
pipe = LTXImageToVideoPipeline.from_single_file(
    single_file_url, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=torch.bfloat16
)
```

Loading LTX GGUF checkpoints is also supported:

```python
import torch
from diffusers.utils import export_to_video
from diffusers import LTXPipeline, LTXVideoTransformer3DModel, GGUFQuantizationConfig

ckpt_path = (
    "https://huggingface.co/city96/LTX-Video-gguf/blob/main/ltx-video-2b-v0.9-Q3_K_S.gguf"
)
transformer = LTXVideoTransformer3DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = LTXPipeline.from_pretrained(
    "Lightricks/LTX-Video",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=704,
    height=480,
    num_frames=161,
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output_gguf_ltx.mp4", fps=24)
```

Make sure to read the documentation on GGUF to learn more about our GGUF support.

Loading and running inference with the LTX Video 0.9.1 weights:

```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=768,
    height=512,
    num_frames=161,
    decode_timestep=0.03,
    decode_noise_scale=0.025,
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```

Refer to this section to learn more about optimizing memory consumption.
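As a rough sketch of common memory-saving options (CPU offloading already appears in the GGUF example above; VAE tiling is an assumption and may not be available in every diffusers version):

```python
import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)

# Offload submodules to the CPU between forward passes instead of keeping
# the whole pipeline resident on the GPU.
pipe.enable_model_cpu_offload()

# Decode the video in tiles to reduce peak VAE memory (assumed to be supported
# by AutoencoderKLLTXVideo; verify with your installed diffusers version).
pipe.vae.enable_tiling()
```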

Quantization

Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.

Refer to the Quantization overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized LTXPipeline for inference with bitsandbytes.

```python
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, LTXVideoTransformer3DModel, LTXPipeline
from diffusers.utils import export_to_video
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel

quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = T5EncoderModel.from_pretrained(
    "Lightricks/LTX-Video",
    subfolder="text_encoder",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = LTXVideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-Video",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

pipeline = LTXPipeline.from_pretrained(
    "Lightricks/LTX-Video",
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
)

prompt = "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting."
video = pipeline(prompt=prompt, num_frames=161, num_inference_steps=50).frames[0]
export_to_video(video, "ship.mp4", fps=24)
```

LTXPipeline

class diffusers.LTXPipeline

( scheduler: FlowMatchEulerDiscreteScheduler vae: AutoencoderKLLTXVideo text_encoder: T5EncoderModel tokenizer: T5TokenizerFast transformer: LTXVideoTransformer3DModel )

Pipeline for text-to-video generation.

Reference: https://github.com/Lightricks/LTX-Video

__call__

( prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str], NoneType] = None height: int = 512 width: int = 704 num_frames: int = 161 frame_rate: int = 25 num_inference_steps: int = 50 timesteps: typing.List[int] = None guidance_scale: float = 3 num_videos_per_prompt: typing.Optional[int] = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None decode_timestep: typing.Union[float, typing.List[float]] = 0.0 decode_noise_scale: typing.Union[float, typing.List[float], NoneType] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 128 ) → ~pipelines.ltx.LTXPipelineOutput or tuple

Returns

~pipelines.ltx.LTXPipelineOutput or tuple

If return_dict is True, ~pipelines.ltx.LTXPipelineOutput is returned, otherwise a tuple is returned where the first element is a list with the generated images.

Function invoked when calling the pipeline for generation.

Examples:

```python
>>> import torch
>>> from diffusers import LTXPipeline
>>> from diffusers.utils import export_to_video

>>> pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")

>>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

>>> video = pipe(
...     prompt=prompt,
...     negative_prompt=negative_prompt,
...     width=704,
...     height=480,
...     num_frames=161,
...     num_inference_steps=50,
... ).frames[0]
>>> export_to_video(video, "output.mp4", fps=24)
```

encode_prompt

( prompt: typing.Union[str, typing.List[str]] negative_prompt: typing.Union[str, typing.List[str], NoneType] = None do_classifier_free_guidance: bool = True num_videos_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None max_sequence_length: int = 128 device: typing.Optional[torch.device] = None dtype: typing.Optional[torch.dtype] = None )

Encodes the prompt into text encoder hidden states.
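For illustration only, a sketch of pre-computing embeddings with encode_prompt and passing them back to the pipeline; the return order (embeddings and attention masks for the positive and negative prompts) is an assumption and should be checked against your diffusers version:

```python
import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Assumed return order: positive embeddings/mask, then negative embeddings/mask.
(
    prompt_embeds,
    prompt_attention_mask,
    negative_prompt_embeds,
    negative_prompt_attention_mask,
) = pipe.encode_prompt(
    prompt="A red fox running through a snowy forest",
    negative_prompt="worst quality, blurry, jittery, distorted",
    do_classifier_free_guidance=True,
    device="cuda",
)

# Reuse the precomputed embeddings instead of passing raw text prompts.
video = pipe(
    prompt_embeds=prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_embeds=negative_prompt_embeds,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    num_frames=161,
    num_inference_steps=50,
).frames[0]
```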

LTXImageToVideoPipeline

class diffusers.LTXImageToVideoPipeline

( scheduler: FlowMatchEulerDiscreteScheduler vae: AutoencoderKLLTXVideo text_encoder: T5EncoderModel tokenizer: T5TokenizerFast transformer: LTXVideoTransformer3DModel )

Pipeline for image-to-video generation.

Reference: https://github.com/Lightricks/LTX-Video

__call__

( image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] = None prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str], NoneType] = None height: int = 512 width: int = 704 num_frames: int = 161 frame_rate: int = 25 num_inference_steps: int = 50 timesteps: typing.List[int] = None guidance_scale: float = 3 num_videos_per_prompt: typing.Optional[int] = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None decode_timestep: typing.Union[float, typing.List[float]] = 0.0 decode_noise_scale: typing.Union[float, typing.List[float], NoneType] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 128 ) → ~pipelines.ltx.LTXPipelineOutput or tuple

Returns

~pipelines.ltx.LTXPipelineOutput or tuple

If return_dict is True, ~pipelines.ltx.LTXPipelineOutput is returned, otherwise a tuple is returned where the first element is a list with the generated images.

Function invoked when calling the pipeline for generation.

Examples:

```python
>>> import torch
>>> from diffusers import LTXImageToVideoPipeline
>>> from diffusers.utils import export_to_video, load_image

>>> pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")

>>> image = load_image(
...     "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
... )
>>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background. Flames engulf the structure, with smoke billowing into the air. Firefighters in protective gear rush to the scene, a fire truck labeled '38' visible behind them. The girl's neutral expression contrasts sharply with the chaos of the fire, creating a poignant and emotionally charged scene."
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

>>> video = pipe(
...     image=image,
...     prompt=prompt,
...     negative_prompt=negative_prompt,
...     width=704,
...     height=480,
...     num_frames=161,
...     num_inference_steps=50,
... ).frames[0]
>>> export_to_video(video, "output.mp4", fps=24)
```

encode_prompt

( prompt: typing.Union[str, typing.List[str]] negative_prompt: typing.Union[str, typing.List[str], NoneType] = None do_classifier_free_guidance: bool = True num_videos_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None max_sequence_length: int = 128 device: typing.Optional[torch.device] = None dtype: typing.Optional[torch.dtype] = None )

Encodes the prompt into text encoder hidden states.

LTXConditionPipeline

class diffusers.LTXConditionPipeline

( scheduler: FlowMatchEulerDiscreteScheduler vae: AutoencoderKLLTXVideo text_encoder: T5EncoderModel tokenizer: T5TokenizerFast transformer: LTXVideoTransformer3DModel )

Pipeline for text/image/video-to-video generation.

Reference: https://github.com/Lightricks/LTX-Video

__call__

( conditions: typing.Union[diffusers.pipelines.ltx.pipeline_ltx_condition.LTXVideoCondition, typing.List[diffusers.pipelines.ltx.pipeline_ltx_condition.LTXVideoCondition]] = None image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor], typing.List[typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]]]] = None video: typing.List[typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]]] = None frame_index: typing.Union[int, typing.List[int]] = 0 strength: typing.Union[float, typing.List[float]] = 1.0 prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str], NoneType] = None height: int = 512 width: int = 704 num_frames: int = 161 frame_rate: int = 25 num_inference_steps: int = 50 timesteps: typing.List[int] = None guidance_scale: float = 3 image_cond_noise_scale: float = 0.15 num_videos_per_prompt: typing.Optional[int] = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None decode_timestep: typing.Union[float, typing.List[float]] = 0.0 decode_noise_scale: typing.Union[float, typing.List[float], NoneType] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 256 ) → ~pipelines.ltx.LTXPipelineOutput or tuple

Returns

~pipelines.ltx.LTXPipelineOutput or tuple

If return_dict is True, ~pipelines.ltx.LTXPipelineOutput is returned, otherwise a tuple is returned where the first element is a list with the generated images.

Function invoked when calling the pipeline for generation.

Examples:

```python
>>> import torch
>>> from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXConditionPipeline, LTXVideoCondition
>>> from diffusers.utils import export_to_video, load_video, load_image

>>> pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")

>>> video = load_video(
...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
... )
>>> image = load_image(
...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg"
... )

>>> condition1 = LTXVideoCondition(
...     image=image,
...     frame_index=0,
... )
>>> condition2 = LTXVideoCondition(
...     video=video,
...     frame_index=80,
... )

>>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day. And then the camera switch to a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

>>> generator = torch.Generator("cuda").manual_seed(0)

>>> video = pipe(
...     conditions=[condition1, condition2],
...     prompt=prompt,
...     negative_prompt=negative_prompt,
...     width=768,
...     height=512,
...     num_frames=161,
...     num_inference_steps=40,
...     generator=generator,
... ).frames[0]

>>> export_to_video(video, "output.mp4", fps=24)
```

add_noise_to_image_conditioning_latents

( t: float init_latents: Tensor latents: Tensor noise_scale: float conditioning_mask: Tensor generator eps = 1e-06 )

Add timestep-dependent noise to the hard-conditioning latents. This helps with motion continuity, especially when conditioned on a single frame.

encode_prompt

( prompt: typing.Union[str, typing.List[str]] negative_prompt: typing.Union[str, typing.List[str], NoneType] = None do_classifier_free_guidance: bool = True num_videos_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None max_sequence_length: int = 256 device: typing.Optional[torch.device] = None dtype: typing.Optional[torch.dtype] = None )

Encodes the prompt into text encoder hidden states.

trim_conditioning_sequence

( start_frame: int sequence_num_frames: int target_num_frames: int ) → int

Returns

int

The updated sequence length.

Trim a conditioning sequence to the allowed number of frames.

LTXPipelineOutput

class diffusers.pipelines.ltx.pipeline_output.LTXPipelineOutput

( frames: Tensor )

Output class for LTX pipelines.
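A brief sketch of how the output is typically consumed, mirroring the .frames[0] access used in the examples above (the prompt here is only illustrative):

```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# With return_dict=True (the default) the call returns an LTXPipelineOutput;
# `frames` holds the generated frames for each prompt in the batch.
output = pipe(prompt="A sailboat drifting across a calm lake at sunrise", num_frames=161)
export_to_video(output.frames[0], "output.mp4", fps=24)
```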
