Allegro

Allegro: Open the Black Box of Commercial-Level Video Generation Model is from RhymesAI, by Yuan Zhou, Qiuyue Wang, Yuxuan Cai, and Huan Yang.

The abstract from the paper is:

Significant advancements have been made in the field of video generation, with the open-source community contributing a wealth of research papers and tools for training high-quality models. However, despite these efforts, the available information and resources remain insufficient for achieving commercial-level performance. In this report, we open the black box and introduce Allegro, an advanced video generation model that excels in both quality and temporal consistency. We also highlight the current limitations in the field and present a comprehensive methodology for training high-performance, commercial-level video generation models, addressing key aspects such as data, model architecture, training pipeline, and evaluation. Our user study shows that Allegro surpasses existing open-source models and most commercial models, ranking just behind Hailuo and Kling. Code: https://github.com/rhymes-ai/Allegro , Model: https://huggingface.co/rhymes-ai/Allegro , Gallery: https://rhymes.ai/allegro_gallery .

Make sure to check out the Schedulers guide to learn how to explore the tradeoff between scheduler speed and quality, and see the reuse components across pipelines section to learn how to efficiently load the same components into multiple pipelines.
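For instance, a different scheduler can be swapped in while reusing the existing scheduler's configuration. The snippet below is a minimal sketch; DPMSolverMultistepScheduler is chosen purely for illustration, and the Schedulers guide should be consulted for which schedulers pair well with Allegro.

```python
import torch
from diffusers import AllegroPipeline, DPMSolverMultistepScheduler

pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.bfloat16)

# Swap the scheduler while keeping the original scheduler configuration.
# DPMSolverMultistepScheduler here is illustrative, not a recommendation.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
```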

Quantization

Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.

Refer to the Quantization overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized AllegroPipeline for inference with bitsandbytes.

```python
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, AllegroTransformer3DModel, AllegroPipeline
from diffusers.utils import export_to_video
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel

quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = T5EncoderModel.from_pretrained(
    "rhymes-ai/Allegro",
    subfolder="text_encoder",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = AllegroTransformer3DModel.from_pretrained(
    "rhymes-ai/Allegro",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

pipeline = AllegroPipeline.from_pretrained(
    "rhymes-ai/Allegro",
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
)

prompt = (
    "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, "
    "the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this "
    "location might be a popular spot for docking fishing boats."
)
video = pipeline(prompt, guidance_scale=7.5, max_sequence_length=512).frames[0]
export_to_video(video, "harbor.mp4", fps=15)
```

AllegroPipeline

class diffusers.AllegroPipeline


( tokenizer: T5Tokenizer, text_encoder: T5EncoderModel, vae: AutoencoderKLAllegro, transformer: AllegroTransformer3DModel, scheduler: KarrasDiffusionSchedulers )


Pipeline for text-to-video generation using Allegro.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all pipelines (such as downloading or saving, or running on a particular device).
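As a minimal sketch of those generic methods (the local save path below is illustrative):

```python
import torch
from diffusers import AllegroPipeline

# Download the full pipeline from the Hub (generic DiffusionPipeline method).
pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.bfloat16)

# Run on a particular device.
pipe.to("cuda")

# Save all components locally for later reuse (path is illustrative).
pipe.save_pretrained("./allegro-local")
```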

__call__


```
__call__(
    prompt: Union[str, List[str]] = None,
    negative_prompt: str = "",
    num_inference_steps: int = 100,
    timesteps: List[int] = None,
    guidance_scale: float = 7.5,
    num_frames: Optional[int] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_videos_per_prompt: int = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback_on_step_end: Optional[Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    clean_caption: bool = True,
    max_sequence_length: int = 512,
) → AllegroPipelineOutput or tuple
```

Returns

If return_dict is True, AllegroPipelineOutput is returned, otherwise a tuple is returned where the first element is a list with the generated videos.

Function invoked when calling the pipeline for generation.

Examples:

```python
import torch
from diffusers import AutoencoderKLAllegro, AllegroPipeline
from diffusers.utils import export_to_video

vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32)
pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", vae=vae, torch_dtype=torch.bfloat16).to("cuda")
pipe.enable_vae_tiling()

prompt = (
    "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, "
    "the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this "
    "location might be a popular spot for docking fishing boats."
)
video = pipe(prompt, guidance_scale=7.5, max_sequence_length=512).frames[0]
export_to_video(video, "output.mp4", fps=15)
```
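As a follow-up sketch, passing a seeded torch.Generator makes generation reproducible, and return_dict=False returns a plain tuple whose first element is the list of generated videos (reusing pipe and prompt from the example above):

```python
# Seeded generator for reproducible sampling on the same device as the pipeline.
generator = torch.Generator(device="cuda").manual_seed(42)

# With return_dict=False the pipeline returns a tuple; the first element is the
# list of generated videos, so [0][0] selects the first video.
frames = pipe(
    prompt,
    guidance_scale=7.5,
    max_sequence_length=512,
    generator=generator,
    return_dict=False,
)[0][0]
export_to_video(frames, "output_seeded.mp4", fps=15)
```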

disable_vae_slicing

Disable sliced VAE decoding. If enable_vae_slicing was previously enabled, this method will go back to computing decoding in one step.

disable_vae_tiling

Disable tiled VAE decoding. If enable_vae_tiling was previously enabled, this method will go back to computing decoding in one step.

enable_vae_slicing

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor into slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.

enable_vae_tiling

Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for processing larger images.
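These switches can be toggled on a loaded pipeline; a minimal sketch, reusing the pipe object from the example above:

```python
# Trade decoding speed for lower peak memory.
pipe.enable_vae_slicing()  # decode the batch one slice at a time
pipe.enable_vae_tiling()   # decode each frame in tiles

video = pipe(prompt, guidance_scale=7.5, max_sequence_length=512).frames[0]

# Restore single-pass decoding afterwards.
pipe.disable_vae_slicing()
pipe.disable_vae_tiling()
```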

encode_prompt


```
encode_prompt(
    prompt: Union[str, List[str]],
    do_classifier_free_guidance: bool = True,
    negative_prompt: str = "",
    num_videos_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    clean_caption: bool = False,
    max_sequence_length: int = 512,
    **kwargs
)
```


Encodes the prompt into text encoder hidden states.
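A sketch of pre-computing the text embeddings and feeding them back into the pipeline call; it assumes encode_prompt returns the prompt embeddings, prompt attention mask, negative prompt embeddings, and negative prompt attention mask in that order, and reuses pipe from the example above:

```python
# Assumption: encode_prompt returns (prompt_embeds, prompt_attention_mask,
# negative_prompt_embeds, negative_prompt_attention_mask).
(
    prompt_embeds,
    prompt_attention_mask,
    negative_prompt_embeds,
    negative_prompt_attention_mask,
) = pipe.encode_prompt(
    prompt="A seaside harbor with many boats, seen from an aerial view",
    do_classifier_free_guidance=True,
    max_sequence_length=512,
)

# Pass the precomputed tensors to skip prompt encoding on later generations.
video = pipe(
    prompt_embeds=prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_embeds=negative_prompt_embeds,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    guidance_scale=7.5,
).frames[0]
```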

AllegroPipelineOutput

class diffusers.pipelines.allegro.pipeline_output.AllegroPipelineOutput


( frames: Union[torch.Tensor, numpy.ndarray, List[List[PIL.Image.Image]]] )


Output class for Allegro pipelines.
