PixArt-α

PixArt-α: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis is by Junsong Chen, Jincheng Yu, Chongjian Ge, Lewei Yao, Enze Xie, Yue Wu, Zhongdao Wang, James Kwok, Ping Luo, Huchuan Lu, and Zhenguo Li.

The abstract from the paper is:

The most advanced text-to-image (T2I) models require significant training costs (e.g., millions of GPU hours), seriously hindering the fundamental innovation for the AIGC community while increasing CO2 emissions. This paper introduces PIXART-α, a Transformer-based T2I diffusion model whose image generation quality is competitive with state-of-the-art image generators (e.g., Imagen, SDXL, and even Midjourney), reaching near-commercial application standards. Additionally, it supports high-resolution image synthesis up to 1024px resolution with low training cost, as shown in Figure 1 and 2. To achieve this goal, three core designs are proposed: (1) Training strategy decomposition: We devise three distinct training steps that separately optimize pixel dependency, text-image alignment, and image aesthetic quality; (2) Efficient T2I Transformer: We incorporate cross-attention modules into Diffusion Transformer (DiT) to inject text conditions and streamline the computation-intensive class-condition branch; (3) High-informative data: We emphasize the significance of concept density in text-image pairs and leverage a large Vision-Language model to auto-label dense pseudo-captions to assist text-image alignment learning. As a result, PIXART-α’s training speed markedly surpasses existing large-scale T2I models, e.g., PIXART-α only takes 10.8% of Stable Diffusion v1.5’s training time (675 vs. 6,250 A100 GPU days), saving nearly $300,000 ($26,000 vs. $320,000) and reducing 90% CO2 emissions. Moreover, compared with a larger SOTA model, RAPHAEL, our training cost is merely 1%. Extensive experiments demonstrate that PIXART-α excels in image quality, artistry, and semantic control. We hope PIXART-α will provide new insights to the AIGC community and startups to accelerate building their own high-quality yet low-cost generative models from scratch.
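The cost figures quoted in the abstract can be checked with a few lines of arithmetic. The sketch below uses only the numbers stated there; the variable names are illustrative.

```python
# Figures as quoted in the abstract.
pixart_gpu_days = 675
sd15_gpu_days = 6_250
pixart_cost_usd = 26_000
sd15_cost_usd = 320_000

# Share of Stable Diffusion v1.5's training time, in percent.
time_fraction_pct = 100 * pixart_gpu_days / sd15_gpu_days

# Difference between the two quoted training costs.
savings_usd = sd15_cost_usd - pixart_cost_usd

print(f"Training time relative to SD v1.5: {time_fraction_pct:.1f}%")
print(f"Cost difference: ${savings_usd:,}")
```

The ratio comes out at 10.8%, and the cost difference is $294,000, which the abstract rounds to "nearly $300,000".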

You can find the original codebase at PixArt-alpha/PixArt-alpha and all the available checkpoints at PixArt-alpha.

Some notes about this pipeline:

Make sure to check out the Schedulers guide to learn how to explore the tradeoff between scheduler speed and quality, and see the reuse components across pipelines section to learn how to efficiently load the same components into multiple pipelines.
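As a rough illustration of both ideas, the sketch below swaps the scheduler on a loaded pipeline and then builds a second pipeline from the first one's components. It relies on the generic `from_config` and `components` mechanisms that Diffusers pipelines expose. The function name is hypothetical, and the choice of `EulerDiscreteScheduler` is only an assumption for the example; whether a given scheduler suits this model should be verified by comparing outputs.

```python
MODEL_ID = "PixArt-alpha/PixArt-XL-2-1024-MS"


def load_with_alternate_scheduler(model_id: str = MODEL_ID):
    """Load the pipeline, swap its scheduler, and reuse its components."""
    # Imported lazily so that defining this helper needs neither a GPU nor a download.
    import torch
    from diffusers import EulerDiscreteScheduler, PixArtAlphaPipeline

    pipe = PixArtAlphaPipeline.from_pretrained(model_id, torch_dtype=torch.float16)

    # Build the new scheduler from the existing scheduler's configuration.
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

    # `components` is a dict of the modules that are already loaded; passing it
    # to the constructor creates a second pipeline without reloading weights.
    second_pipe = PixArtAlphaPipeline(**pipe.components)
    return pipe, second_pipe
```

Both returned pipelines share the same underlying modules, so the second one adds almost no memory overhead.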

Inference with under 8GB GPU VRAM

Run the PixArtAlphaPipeline with under 8GB GPU VRAM by loading the text encoder in 8-bit precision. Let’s walk through a full-fledged example.

First, install the bitsandbytes library:

pip install -U bitsandbytes

Then load the text encoder in 8-bit:

from transformers import T5EncoderModel
from diffusers import PixArtAlphaPipeline
import torch

text_encoder = T5EncoderModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    subfolder="text_encoder",
    load_in_8bit=True,
    device_map="auto",
)
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    text_encoder=text_encoder,
    transformer=None,
    device_map="auto",
)

Now, use the pipe to encode a prompt:

with torch.no_grad():
    prompt = "cute cat"
    prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)

Now that the text embeddings have been computed, delete text_encoder and pipe and free up some GPU VRAM:

import gc

def flush():
    gc.collect()
    torch.cuda.empty_cache()

del text_encoder
del pipe
flush()

Then compute the latents with the prompt embeddings as inputs:

pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    text_encoder=None,
    torch_dtype=torch.float16,
).to("cuda")

latents = pipe(
    negative_prompt=None,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    num_images_per_prompt=1,
    output_type="latent",
).images

del pipe.transformer
flush()

Notice that while initializing pipe, you’re setting text_encoder to None so that it’s not loaded.

Once the latents are computed, pass them to the VAE to decode into a real image:

with torch.no_grad():
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = pipe.image_processor.postprocess(image, output_type="pil")[0]
image.save("cat.png")

By deleting components you aren’t using and flushing the GPU VRAM, you should be able to run PixArtAlphaPipeline with under 8GB GPU VRAM.

If you want a report of your memory usage, run this script.
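A report of this kind can also be produced by hand with PyTorch's CUDA memory statistics. The helper below is a minimal sketch and not the script referenced above; the function names are hypothetical. It reads `torch.cuda.max_memory_allocated()`, which tracks the peak memory allocated by tensors since the start of the program or since the last call to `torch.cuda.reset_peak_memory_stats()`.

```python
def bytes_to_gib(num_bytes: int) -> float:
    """Convert a byte count to gibibytes."""
    return num_bytes / 1024**3


def report_peak_cuda_memory(label: str = "peak") -> float:
    """Print and return the peak CUDA memory allocated so far, in GiB."""
    # Imported lazily so the conversion helper works without PyTorch installed.
    import torch

    if not torch.cuda.is_available():
        print(f"[{label}] CUDA is not available; nothing to report")
        return 0.0

    peak_gib = bytes_to_gib(torch.cuda.max_memory_allocated())
    print(f"[{label}] peak allocated: {peak_gib:.2f} GiB")
    return peak_gib
```

Calling `report_peak_cuda_memory("after text encoding")` after each stage, with `torch.cuda.reset_peak_memory_stats()` in between, gives a per-stage view of where the memory goes.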

Text embeddings computed in 8-bit can impact the quality of the generated images because of the information loss in the representation space caused by the reduced precision. It’s recommended to compare the outputs with and without 8-bit.

While loading the text_encoder, you set load_in_8bit to True. You could also specify load_in_4bit to bring your memory requirements down even further to under 7GB.
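The 4-bit variant only changes one keyword argument compared with the 8-bit loading code earlier in this guide. The sketch below makes that explicit; the helper names are hypothetical. Recent transformers releases may prefer a `BitsAndBytesConfig` passed through `quantization_config` over the bare `load_in_4bit` flag, so treat the keyword form as an assumption tied to the version this guide was written for.

```python
MODEL_ID = "PixArt-alpha/PixArt-XL-2-1024-MS"


def quantized_encoder_kwargs(bits: int) -> dict:
    """Build the from_pretrained keyword arguments for an 8-bit or 4-bit text encoder."""
    if bits not in (4, 8):
        raise ValueError("bits must be 4 or 8")
    return {
        "subfolder": "text_encoder",
        f"load_in_{bits}bit": True,
        "device_map": "auto",
    }


def load_text_encoder(bits: int = 4, model_id: str = MODEL_ID):
    """Load the T5 text encoder with bitsandbytes quantization."""
    # Imported lazily so the kwargs helper can be used without transformers installed.
    from transformers import T5EncoderModel

    return T5EncoderModel.from_pretrained(model_id, **quantized_encoder_kwargs(bits))
```

The returned encoder can be passed to `PixArtAlphaPipeline.from_pretrained` through the `text_encoder` argument, exactly as in the 8-bit example.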

PixArtAlphaPipeline

class diffusers.PixArtAlphaPipeline

(
    tokenizer: T5Tokenizer
    text_encoder: T5EncoderModel
    vae: AutoencoderKL
    transformer: PixArtTransformer2DModel
    scheduler: DPMSolverMultistepScheduler
)

Pipeline for text-to-image generation using PixArt-Alpha.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).

__call__

(
    prompt: typing.Union[str, typing.List[str]] = None
    negative_prompt: str = ''
    num_inference_steps: int = 20
    timesteps: typing.List[int] = None
    sigmas: typing.List[float] = None
    guidance_scale: float = 4.5
    num_images_per_prompt: typing.Optional[int] = 1
    height: typing.Optional[int] = None
    width: typing.Optional[int] = None
    eta: float = 0.0
    generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None
    latents: typing.Optional[torch.Tensor] = None
    prompt_embeds: typing.Optional[torch.Tensor] = None
    prompt_attention_mask: typing.Optional[torch.Tensor] = None
    negative_prompt_embeds: typing.Optional[torch.Tensor] = None
    negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None
    output_type: typing.Optional[str] = 'pil'
    return_dict: bool = True
    callback: typing.Optional[typing.Callable[[int, int, torch.Tensor], NoneType]] = None
    callback_steps: int = 1
    clean_caption: bool = True
    use_resolution_binning: bool = True
    max_sequence_length: int = 120
    **kwargs
) → ImagePipelineOutput or tuple

If return_dict is True, ImagePipelineOutput is returned; otherwise a tuple is returned whose first element is a list of the generated images.

Function invoked when calling the pipeline for generation.
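The `use_resolution_binning` flag in the signature suggests that the requested height and width are snapped to a predefined resolution before generation. The sketch below illustrates that idea under the assumption of a nearest-aspect-ratio rule. The table is a small illustrative subset and the function name is hypothetical; neither is taken from the library source.

```python
from typing import Dict, Tuple

# Illustrative aspect-ratio table: ratio (height / width) -> (height, width).
# These entries are assumptions chosen for the example.
EXAMPLE_BINS: Dict[float, Tuple[int, int]] = {
    0.5: (704, 1408),
    1.0: (1024, 1024),
    2.0: (1408, 704),
}


def closest_bin(
    height: int,
    width: int,
    bins: Dict[float, Tuple[int, int]] = EXAMPLE_BINS,
) -> Tuple[int, int]:
    """Return the binned (height, width) whose aspect ratio is nearest to the request."""
    aspect_ratio = height / width
    nearest = min(bins, key=lambda ratio: abs(ratio - aspect_ratio))
    return bins[nearest]


print(closest_bin(1000, 1050))  # (1024, 1024)
print(closest_bin(600, 1300))   # (704, 1408)
```

Under this assumption, a request close to square is generated at 1024x1024 and a wide request at 704x1408; the image would then be resized to the size the caller asked for.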

Examples:

import torch
from diffusers import PixArtAlphaPipeline

pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)

pipe.enable_model_cpu_offload()

prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]
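Several of the arguments in the signature above are worth setting explicitly. The following sketch wraps the same call in a hypothetical `generate` helper that fixes the random seed through `generator`, passes `num_inference_steps` and `guidance_scale` at their documented defaults, and uses `return_dict=False` to receive a tuple instead of an ImagePipelineOutput.

```python
def generate(prompt: str, seed: int = 0, steps: int = 20, guidance_scale: float = 4.5):
    """Generate one image with a fixed seed and explicit sampling settings."""
    # Imported lazily so that defining this helper needs neither a GPU nor a download.
    import torch
    from diffusers import PixArtAlphaPipeline

    pipe = PixArtAlphaPipeline.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
    )
    pipe.enable_model_cpu_offload()

    # A seeded generator makes the initial noise, and so the image, reproducible.
    generator = torch.Generator(device="cpu").manual_seed(seed)

    # With return_dict=False the pipeline returns a tuple whose first element
    # is the list of generated images.
    images = pipe(
        prompt,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        generator=generator,
        return_dict=False,
    )[0]
    return images[0]
```

For example, `generate("A small cactus with a happy face in the Sahara desert.", seed=42).save("cactus.png")` should produce the same image on repeated runs with the same hardware and library versions.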

encode_prompt

(
    prompt: typing.Union[str, typing.List[str]]
    do_classifier_free_guidance: bool = True
    negative_prompt: str = ''
    num_images_per_prompt: int = 1
    device: typing.Optional[torch.device] = None
    prompt_embeds: typing.Optional[torch.Tensor] = None
    negative_prompt_embeds: typing.Optional[torch.Tensor] = None
    prompt_attention_mask: typing.Optional[torch.Tensor] = None
    negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None
    clean_caption: bool = False
    max_sequence_length: int = 120
    **kwargs
)

Encodes the prompt into text encoder hidden states.
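When `do_classifier_free_guidance` is enabled, the method produces embeddings for both the prompt and the negative prompt. The sketch below shows how two noise predictions conditioned on those embeddings are typically combined, assuming the pipeline follows the standard classifier-free guidance formula. It uses plain Python lists in place of tensors, and the function name is hypothetical.

```python
from typing import List


def apply_guidance(
    noise_uncond: List[float],
    noise_text: List[float],
    guidance_scale: float,
) -> List[float]:
    """Combine unconditional and text-conditioned predictions element by element."""
    # Standard classifier-free guidance: start from the unconditional prediction
    # and move toward the text-conditioned one, scaled by guidance_scale.
    return [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]


print(apply_guidance([0.0, 1.0], [1.0, 1.0], 4.5))  # [4.5, 1.0]
```

With a scale of 1.0 the result equals the text-conditioned prediction; larger values push the output further away from the unconditional (negative prompt) prediction.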
