AuraFlow

AuraFlow is inspired by Stable Diffusion 3 and is by far the largest text-to-image generation model that comes with an Apache 2.0 license. This model achieves state-of-the-art results on the GenEval benchmark.

It was developed by the Fal team, and more details about it can be found in this blog post.

AuraFlow can be quite expensive to run on consumer hardware. However, you can apply a suite of optimizations to run it faster and in a more memory-friendly manner. Check out this section for more details.
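As a quick taste of those optimizations, the sketch below uses the generic enable_model_cpu_offload() helper from diffusers to keep only the active sub-model on the GPU. The checkpoint name and dtype mirror the examples later on this page; whether offloading is the right trade-off depends on your hardware.

import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)

# Move each sub-model to the GPU only while it is needed, then back to the CPU
pipe.enable_model_cpu_offload()

image = pipe("a tiny astronaut hatching from an egg on the moon").images[0]
image.save("auraflow.png")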

Quantization

Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have a varying impact on output quality depending on the model.

Refer to the Quantization overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized AuraFlowPipeline for inference with bitsandbytes.

import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, AuraFlowTransformer2DModel, AuraFlowPipeline
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel

# Quantize the T5 text encoder to 8-bit with the transformers bitsandbytes config
quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = T5EncoderModel.from_pretrained(
    "fal/AuraFlow",
    subfolder="text_encoder",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

# Quantize the transformer to 8-bit with the diffusers bitsandbytes config
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = AuraFlowTransformer2DModel.from_pretrained(
    "fal/AuraFlow",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

# Assemble the pipeline from the quantized components
pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow",
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
)

prompt = "a tiny astronaut hatching from an egg on the moon"
image = pipeline(prompt).images[0]
image.save("auraflow.png")

Loading GGUF checkpoints is also supported:

import torch
from diffusers import (
    AuraFlowPipeline,
    GGUFQuantizationConfig,
    AuraFlowTransformer2DModel,
)

transformer = AuraFlowTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)

prompt = "a cute pony in a field of flowers"
image = pipeline(prompt).images[0]
image.save("auraflow.png")

Support for torch.compile()

AuraFlow can be compiled with torch.compile() to reduce inference latency, even across different resolutions. First, install PyTorch nightly following the instructions from here.
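The snippet below is a minimal sketch of the changes needed to enable this, assuming a pipeline already loaded as in the examples above. Note that use_duck_shape lives in a private PyTorch config module, so the exact import path and flag availability may vary across nightlies; the compile arguments shown follow the usual torch.compile pattern rather than values confirmed by this page.

import torch
from diffusers import AuraFlowPipeline
from torch.fx.experimental import _config as fx_config

# Give equal-valued dynamic dimensions their own symbolic variables (see the note below)
fx_config.use_duck_shape = False

pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow", torch_dtype=torch.bfloat16
).to("cuda")

# dynamic=True keeps the compiled transformer usable across different resolutions
pipeline.transformer = torch.compile(
    pipeline.transformer, fullgraph=True, dynamic=True
)

image = pipeline("a tiny astronaut hatching from an egg on the moon").images[0]
image.save("auraflow.png")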

Setting use_duck_shape to False instructs the compiler not to reuse the same symbolic variable for input sizes that happen to be equal. For more details, check out this comment.

This yields speed improvements ranging from 100% (at low resolutions) to 30% (at 1536x1536 resolution).

Thanks to AstraliteHeart, who helped us rewrite the AuraFlowTransformer2DModel class so that the above works across different resolutions (PR).

AuraFlowPipeline

class diffusers.AuraFlowPipeline

( tokenizer: T5Tokenizer text_encoder: UMT5EncoderModel vae: AutoencoderKL transformer: AuraFlowTransformer2DModel scheduler: FlowMatchEulerDiscreteScheduler )

__call__

( prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str]] = None num_inference_steps: int = 50 sigmas: typing.List[float] = None guidance_scale: float = 3.5 num_images_per_prompt: typing.Optional[int] = 1 height: typing.Optional[int] = 1024 width: typing.Optional[int] = 1024 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None max_sequence_length: int = 256 output_type: typing.Optional[str] = 'pil' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Union[typing.Callable[[int, int, typing.Dict], NoneType], diffusers.callbacks.PipelineCallback, diffusers.callbacks.MultiPipelineCallbacks, NoneType] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] )

Function invoked when calling the pipeline for generation.

Examples:

import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt).images[0]
image.save("aura_flow.png")

Returns: ImagePipelineOutput or tuple: If return_dict is True, ImagePipelineOutput is returned, otherwise a tuple is returned where the first element is a list with the generated images.
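For example, with return_dict=False the call returns a plain tuple that can be unpacked directly; a short sketch reusing the same checkpoint as the example above:

import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

# The first element of the returned tuple is the list of generated images
images = pipe("A cat holding a sign that says hello world", return_dict=False)[0]
images[0].save("aura_flow.png")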

encode_prompt

( prompt: typing.Union[str, typing.List[str]] negative_prompt: typing.Union[str, typing.List[str]] = None do_classifier_free_guidance: bool = True num_images_per_prompt: int = 1 device: typing.Optional[torch.device] = None prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None max_sequence_length: int = 256 lora_scale: typing.Optional[float] = None )

Encodes the prompt into text encoder hidden states.
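A minimal sketch of calling encode_prompt() on its own and reusing the embeddings in a later pipeline call. The four-element return order shown here is an assumption inferred from the parameter names above, not confirmed by this page:

import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

# Assumed return order: prompt embeddings/mask followed by negative prompt embeddings/mask
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = pipe.encode_prompt(
    prompt="a tiny astronaut hatching from an egg on the moon",
    negative_prompt="blurry, low quality",
)

# Reuse the precomputed embeddings instead of passing the raw prompt strings
image = pipe(
    prompt_embeds=prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_embeds=negative_prompt_embeds,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
).images[0]
image.save("auraflow.png")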
