Würstchen

This pipeline is deprecated but it can still be used. However, we won’t test the pipeline anymore and won’t accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.

Wuerstchen: An Efficient Architecture for Large-Scale Text-to-Image Diffusion Models is by Pablo Pernias, Dominic Rampas, Mats L. Richter, Christopher Pal, and Marc Aubreville.

The abstract from the paper is:

We introduce Würstchen, a novel architecture for text-to-image synthesis that combines competitive performance with unprecedented cost-effectiveness for large-scale text-to-image diffusion models. A key contribution of our work is to develop a latent diffusion technique in which we learn a detailed but extremely compact semantic image representation used to guide the diffusion process. This highly compressed representation of an image provides much more detailed guidance compared to latent representations of language and this significantly reduces the computational requirements to achieve state-of-the-art results. Our approach also improves the quality of text-conditioned image generation based on our user preference study. The training requirements of our approach consists of 24,602 A100-GPU hours - compared to Stable Diffusion 2.1’s 200,000 GPU hours. Our approach also requires less training data to achieve these results. Furthermore, our compact latent representations allows us to perform inference over twice as fast, slashing the usual costs and carbon footprint of a state-of-the-art (SOTA) diffusion model significantly, without compromising the end performance. In a broader comparison against SOTA models our approach is substantially more efficient and compares favorably in terms of image quality. We believe that this work motivates more emphasis on the prioritization of both performance and computational accessibility.

Würstchen Overview

Würstchen is a diffusion model whose text-conditional model works in a highly compressed latent space of images. Why is this important? Compressing data can reduce computational costs for both training and inference by orders of magnitude. Training on 1024x1024 images is far more expensive than training on 32x32. Other works usually use a relatively small compression, in the range of 4x - 8x spatial compression. Würstchen takes this to an extreme. Through its novel design, we achieve a 42x spatial compression. This had not been seen before, because common methods fail to faithfully reconstruct detailed images after 16x spatial compression. Würstchen employs a two-stage compression, which we call Stage A and Stage B. Stage A is a VQGAN, and Stage B is a Diffusion Autoencoder (more details can be found in the paper). A third model, Stage C, is learned in that highly compressed latent space. This training requires a fraction of the compute used for current top-performing models, while also allowing cheaper and faster inference.
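
To make the compression figures concrete, the sketch below works through the latent sizes for a given image resolution. It borrows the default constants that appear in the pipeline signatures further down this page (resolution_multiple = 42.67 for the prior, latent_dim_scale = 10.67 for the decoder). The function names and the rounding rules (round up for Stage C, truncate for Stage B) are assumptions made for illustration, not a copy of the library code.

```python
import math

# Defaults taken from the pipeline signatures documented on this page.
RESOLUTION_MULTIPLE = 42.67  # pixels covered by one Stage C latent cell
LATENT_DIM_SCALE = 10.67     # Stage B latent cells per Stage C latent cell


def stage_c_latent_size(height, width, resolution_multiple=RESOLUTION_MULTIPLE):
    """Spatial size of the highly compressed Stage C latents (assumed: round up)."""
    return (
        math.ceil(height / resolution_multiple),
        math.ceil(width / resolution_multiple),
    )


def stage_b_latent_size(stage_c_size, latent_dim_scale=LATENT_DIM_SCALE):
    """Spatial size of the VQGAN latents produced by Stage B (assumed: truncate)."""
    h, w = stage_c_size
    return int(h * latent_dim_scale), int(w * latent_dim_scale)


if __name__ == "__main__":
    c = stage_c_latent_size(1024, 1024)
    b = stage_b_latent_size(c)
    print("Stage C latents:", c)  # (24, 24)
    print("Stage B latents:", b)  # (256, 256)
    print("Spatial compression:", round(1024 / c[0], 2))  # 42.67
```

Under these assumptions a 1024x1024 image corresponds to a 24x24 grid in Stage C, which is where the roughly 42x spatial compression comes from.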

Würstchen v2 comes to Diffusers

After the initial paper release, we have improved numerous things in the architecture, training and sampling, making Würstchen competitive with current state-of-the-art models in many ways. We are excited to release this new version together with Diffusers. Here is a list of the improvements.

We are releasing 3 checkpoints for the text-conditional image generation model (Stage C): v2-base, v2-aesthetic, and v2-interpolated.

We recommend using v2-interpolated, as it has a nice touch of both photorealism and aesthetics. Use v2-base for finetunings as it does not have a style bias and use v2-aesthetic for very artistic generations. A comparison can be seen here:

Text-to-Image Generation

For the sake of usability, Würstchen can be used with a single pipeline. This pipeline can be used as follows:

import torch
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS

pipe = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.float16).to("cuda")

caption = "Anthropomorphic cat dressed as a fire fighter"
images = pipe(
    caption,
    width=1024,
    height=1536,
    prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
    prior_guidance_scale=4.0,
    num_images_per_prompt=2,
).images
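
The call above returns a list of PIL images, one per requested image. A small helper along the following lines could be used to write them to disk. The function name save_images, the output directory, and the file naming scheme are hypothetical and not part of Diffusers; the only thing the helper relies on is that each image has a save(path) method, as PIL images do.

```python
from pathlib import Path


def save_images(images, out_dir="outputs", prefix="wuerstchen"):
    """Save each image as <out_dir>/<prefix>_<index>.png and return the paths."""
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    paths = []
    for index, image in enumerate(images):
        path = out_path / f"{prefix}_{index:02d}.png"
        image.save(path)  # PIL infers the PNG format from the file extension
        paths.append(path)
    return paths
```

With the pipeline call above, save_images(images) would write two files, one for each image requested through num_images_per_prompt.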

For explanation purposes, we can also initialize the two main pipelines of Würstchen individually. Würstchen consists of 3 stages: Stage C, Stage B, Stage A. Each has a different job, and they only work together. When generating text-conditional images, Stage C first generates the latents in a very compressed latent space. This is what happens in the prior_pipeline. Afterwards, the generated latents are passed to Stage B, which decompresses them into the bigger latent space of a VQGAN. These latents can then be decoded by Stage A, which is a VQGAN, into pixel space. Stage B and Stage A are both encapsulated in the decoder_pipeline. For more details, take a look at the paper.

import torch
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS

device = "cuda"
dtype = torch.float16
num_images_per_prompt = 2

prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
    "warp-ai/wuerstchen-prior", torch_dtype=dtype
).to(device)
decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=dtype
).to(device)

caption = "Anthropomorphic cat dressed as a fire fighter"
negative_prompt = ""

prior_output = prior_pipeline(
    prompt=caption,
    height=1024,
    width=1536,
    timesteps=DEFAULT_STAGE_C_TIMESTEPS,
    negative_prompt=negative_prompt,
    guidance_scale=4.0,
    num_images_per_prompt=num_images_per_prompt,
)
decoder_output = decoder_pipeline(
    image_embeddings=prior_output.image_embeddings,
    prompt=caption,
    negative_prompt=negative_prompt,
    guidance_scale=0.0,
    output_type="pil",
).images[0]
decoder_output

Speed-Up Inference

You can use the torch.compile function to gain a speed-up of about 2-3x:

prior_pipeline.prior = torch.compile(prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="reduce-overhead", fullgraph=True)
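
Whether compilation pays off depends on the hardware, the PyTorch version, and how many images you generate, because the first calls include the compilation itself. A minimal, library-agnostic timing helper such as the one below can be used to compare a pipeline before and after compiling. The name time_call and its parameters are hypothetical and not part of Diffusers or PyTorch.

```python
import time


def time_call(fn, warmup=1, runs=3):
    """Return the mean wall-clock seconds of fn() over `runs` calls.

    The `warmup` calls run first and are discarded, so one-off costs such as
    compilation are excluded from the measurement.
    """
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) / runs


if __name__ == "__main__":
    # Stand-in workload; swap in e.g. `lambda: prior_pipeline(prompt=caption)`.
    mean_seconds = time_call(lambda: sum(range(100_000)))
    print(f"mean: {mean_seconds:.6f} s")
```

When timing GPU work, it is worth calling torch.cuda.synchronize() at the end of the timed function so that asynchronously launched kernels are included in the measurement.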

Limitations

The original codebase, as well as experimental ideas, can be found at dome272/Wuerstchen.

WuerstchenCombinedPipeline

class diffusers.WuerstchenCombinedPipeline

(
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    decoder: WuerstchenDiffNeXt,
    scheduler: DDPMWuerstchenScheduler,
    vqgan: PaellaVQModel,
    prior_tokenizer: CLIPTokenizer,
    prior_text_encoder: CLIPTextModel,
    prior_prior: WuerstchenPrior,
    prior_scheduler: DDPMWuerstchenScheduler,
)

Combined pipeline for text-to-image generation using Wuerstchen.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

__call__

(
    prompt: typing.Union[str, typing.List[str], NoneType] = None,
    height: int = 512,
    width: int = 512,
    prior_num_inference_steps: int = 60,
    prior_timesteps: typing.Optional[typing.List[float]] = None,
    prior_guidance_scale: float = 4.0,
    num_inference_steps: int = 12,
    decoder_timesteps: typing.Optional[typing.List[float]] = None,
    decoder_guidance_scale: float = 0.0,
    negative_prompt: typing.Union[str, typing.List[str], NoneType] = None,
    prompt_embeds: typing.Optional[torch.Tensor] = None,
    negative_prompt_embeds: typing.Optional[torch.Tensor] = None,
    num_images_per_prompt: int = 1,
    generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None,
    latents: typing.Optional[torch.Tensor] = None,
    output_type: typing.Optional[str] = 'pil',
    return_dict: bool = True,
    prior_callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None,
    prior_callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'],
    callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None,
    callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'],
    **kwargs
)

Function invoked when calling the pipeline for generation.

Examples:

import torch
from diffusers import WuerstchenCombinedPipeline

pipe = WuerstchenCombinedPipeline.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.float16).to("cuda")
prompt = "an image of a shiba inu, donning a spacesuit and helmet"
images = pipe(prompt=prompt).images

enable_model_cpu_offload

( gpu_id: typing.Optional[int] = None, device: typing.Union[torch.device, str] = None )

Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to enable_sequential_cpu_offload, this method moves one whole model at a time to the GPU when its forward method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with enable_sequential_cpu_offload, but performance is much better due to the iterative execution of the unet.

enable_sequential_cpu_offload

( gpu_id: typing.Optional[int] = None, device: typing.Union[torch.device, str] = None )

Offloads all models (unet, text_encoder, vae, and safety checker state dicts) to CPU using 🤗 Accelerate, significantly reducing memory usage. Models are moved to a torch.device('meta') and loaded on a GPU only when their specific submodule’s forward method is called. Offloading happens on a submodule basis. Memory savings are higher than using enable_model_cpu_offload, but performance is lower.
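
The trade-off between the two offloading methods can be wrapped in a small switch, sketched below. The helper apply_offload and its mode names are hypothetical; it only calls the two pipeline methods documented above, and it assumes the pipeline has been loaded without an explicit .to("cuda"), which is the usual pattern when offloading is enabled.

```python
def apply_offload(pipe, mode="model"):
    """Enable one of the two CPU offloading strategies on a pipeline.

    mode="model":      whole models are moved to the GPU one at a time
                       (smaller memory savings, better speed).
    mode="sequential": submodules are moved on demand
                       (larger memory savings, lower speed).
    """
    if mode == "model":
        pipe.enable_model_cpu_offload()
    elif mode == "sequential":
        pipe.enable_sequential_cpu_offload()
    else:
        raise ValueError(f"unknown offload mode: {mode!r}")
    return pipe
```

A reasonable default is to start with "model" and switch to "sequential" only if the pipeline still runs out of GPU memory.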

WuerstchenPriorPipeline

class diffusers.WuerstchenPriorPipeline

(
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    prior: WuerstchenPrior,
    scheduler: DDPMWuerstchenScheduler,
    latent_mean: float = 42.0,
    latent_std: float = 1.0,
    resolution_multiple: float = 42.67,
)

Pipeline for generating the image prior for Wuerstchen.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

The pipeline also inherits the following loading methods:

__call__

(
    prompt: typing.Union[str, typing.List[str], NoneType] = None,
    height: int = 1024,
    width: int = 1024,
    num_inference_steps: int = 60,
    timesteps: typing.List[float] = None,
    guidance_scale: float = 8.0,
    negative_prompt: typing.Union[str, typing.List[str], NoneType] = None,
    prompt_embeds: typing.Optional[torch.Tensor] = None,
    negative_prompt_embeds: typing.Optional[torch.Tensor] = None,
    num_images_per_prompt: typing.Optional[int] = 1,
    generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None,
    latents: typing.Optional[torch.Tensor] = None,
    output_type: typing.Optional[str] = 'pt',
    return_dict: bool = True,
    callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None,
    callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'],
    **kwargs
)

Function invoked when calling the pipeline for generation.

Examples:

import torch
from diffusers import WuerstchenPriorPipeline

prior_pipe = WuerstchenPriorPipeline.from_pretrained(
    "warp-ai/wuerstchen-prior", torch_dtype=torch.float16
).to("cuda")

prompt = "an image of a shiba inu, donning a spacesuit and helmet"
prior_output = prior_pipe(prompt)

WuerstchenPriorPipelineOutput

class diffusers.pipelines.wuerstchen.pipeline_wuerstchen_prior.WuerstchenPriorPipelineOutput

( image_embeddings: typing.Union[torch.Tensor, numpy.ndarray] )

Output class for WuerstchenPriorPipeline.

WuerstchenDecoderPipeline

class diffusers.WuerstchenDecoderPipeline

(
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    decoder: WuerstchenDiffNeXt,
    scheduler: DDPMWuerstchenScheduler,
    vqgan: PaellaVQModel,
    latent_dim_scale: float = 10.67,
)

Pipeline for generating images from the Wuerstchen model.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

__call__

(
    image_embeddings: typing.Union[torch.Tensor, typing.List[torch.Tensor]],
    prompt: typing.Union[str, typing.List[str]] = None,
    num_inference_steps: int = 12,
    timesteps: typing.Optional[typing.List[float]] = None,
    guidance_scale: float = 0.0,
    negative_prompt: typing.Union[str, typing.List[str], NoneType] = None,
    num_images_per_prompt: int = 1,
    generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None,
    latents: typing.Optional[torch.Tensor] = None,
    output_type: typing.Optional[str] = 'pil',
    return_dict: bool = True,
    callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None,
    callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'],
    **kwargs
)

Function invoked when calling the pipeline for generation.

Examples:

import torch
from diffusers import WuerstchenPriorPipeline, WuerstchenDecoderPipeline

prior_pipe = WuerstchenPriorPipeline.from_pretrained(
    "warp-ai/wuerstchen-prior", torch_dtype=torch.float16
).to("cuda")
gen_pipe = WuerstchenDecoderPipeline.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=torch.float16
).to("cuda")

prompt = "an image of a shiba inu, donning a spacesuit and helmet"
prior_output = prior_pipe(prompt)
images = gen_pipe(prior_output.image_embeddings, prompt=prompt).images

Citation

  @misc{pernias2023wuerstchen,
        title={Wuerstchen: An Efficient Architecture for Large-Scale Text-to-Image Diffusion Models},
        author={Pablo Pernias and Dominic Rampas and Mats L. Richter and Christopher J. Pal and Marc Aubreville},
        year={2023},
        eprint={2306.00637},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
  }
