Stable unCLIP

Stable unCLIP checkpoints are fine-tuned from Stable Diffusion 2.1 checkpoints to condition on CLIP image embeddings. Stable unCLIP still conditions on text embeddings. Given the two separate conditionings, Stable unCLIP can be used for text-guided image variation. When combined with an unCLIP prior, it can also be used for full text-to-image generation.

The abstract from the paper is:

Contrastive models like CLIP have been shown to learn robust representations of images that capture both semantics and style. To leverage these representations for image generation, we propose a two-stage model: a prior that generates a CLIP image embedding given a text caption, and a decoder that generates an image conditioned on the image embedding. We show that explicitly generating image representations improves image diversity with minimal loss in photorealism and caption similarity. Our decoders conditioned on image representations can also produce variations of an image that preserve both its semantics and style, while varying the non-essential details absent from the image representation. Moreover, the joint embedding space of CLIP enables language-guided image manipulations in a zero-shot fashion. We use diffusion models for the decoder and experiment with both autoregressive and diffusion models for the prior, finding that the latter are computationally more efficient and produce higher-quality samples.

Tips

Stable unCLIP takes noise_level as an input during inference, which determines how much noise is added to the image embeddings. A higher noise_level increases variation in the final un-noised images. By default, no additional noise is added to the image embeddings (noise_level = 0).
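
For example, a minimal sketch of varying noise_level with the image-variation pipeline that is introduced below; the checkpoint and image URL are the ones used later on this page, and the value 250 is illustrative:

import torch
from diffusers import StableUnCLIPImg2ImgPipeline
from diffusers.utils import load_image

pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16"
)
pipe = pipe.to("cuda")

init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png"
)

# noise_level=0 (the default) keeps the variation close to the input image embedding
image_low_noise = pipe(init_image, noise_level=0).images[0]

# a higher noise_level adds more noise to the image embedding and loosens the variation
image_high_noise = pipe(init_image, noise_level=250).images[0]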

Text-to-Image Generation

Stable unCLIP can be leveraged for text-to-image generation by pipelining it with the prior model of KakaoBrain’s open source DALL-E 2 replication Karlo:

import torch
from diffusers import UnCLIPScheduler, DDPMScheduler, StableUnCLIPPipeline
from diffusers.models import PriorTransformer
from transformers import CLIPTokenizer, CLIPTextModelWithProjection

prior_model_id = "kakaobrain/karlo-v1-alpha"
data_type = torch.float16
prior = PriorTransformer.from_pretrained(prior_model_id, subfolder="prior", torch_dtype=data_type)

prior_text_model_id = "openai/clip-vit-large-patch14"
prior_tokenizer = CLIPTokenizer.from_pretrained(prior_text_model_id)
prior_text_model = CLIPTextModelWithProjection.from_pretrained(prior_text_model_id, torch_dtype=data_type)
prior_scheduler = UnCLIPScheduler.from_pretrained(prior_model_id, subfolder="prior_scheduler")
prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)

stable_unclip_model_id = "stabilityai/stable-diffusion-2-1-unclip-small"

pipe = StableUnCLIPPipeline.from_pretrained(
    stable_unclip_model_id,
    torch_dtype=data_type,
    variant="fp16",
    prior_tokenizer=prior_tokenizer,
    prior_text_encoder=prior_text_model,
    prior=prior,
    prior_scheduler=prior_scheduler,
)

pipe = pipe.to("cuda")
wave_prompt = "dramatic wave, the Oceans roar, Strong wave spiral across the oceans as the waves unfurl into roaring crests; perfect wave form; perfect wave shape; dramatic wave shape; wave shape unbelievable; wave; wave shape spectacular"

image = pipe(prompt=wave_prompt).images[0]
image

For text-to-image generation we use stabilityai/stable-diffusion-2-1-unclip-small, as it was trained on CLIP ViT-L/14 embeddings, the same as the Karlo model prior. stabilityai/stable-diffusion-2-1-unclip was trained on OpenCLIP ViT-H, so we don’t recommend using it here.

Text-guided Image-to-Image Variation

from diffusers import StableUnCLIPImg2ImgPipeline
from diffusers.utils import load_image
import torch

pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16"
)
pipe = pipe.to("cuda")

url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png"
init_image = load_image(url)

images = pipe(init_image).images
images[0].save("variation_image.png")

Optionally, you can also pass a prompt to pipe, such as:

prompt = "A fantasy landscape, trending on artstation"

image = pipe(init_image, prompt=prompt).images[0]
image

Make sure to check out the Schedulers guide to learn how to explore the tradeoff between scheduler speed and quality, and see the reuse components across pipelines section to learn how to efficiently load the same components into multiple pipelines.

StableUnCLIPPipeline

class diffusers.StableUnCLIPPipeline

( prior_tokenizer: CLIPTokenizer prior_text_encoder: CLIPTextModelWithProjection prior: PriorTransformer prior_scheduler: KarrasDiffusionSchedulers image_normalizer: StableUnCLIPImageNormalizer image_noising_scheduler: KarrasDiffusionSchedulers tokenizer: CLIPTokenizer text_encoder: CLIPTextModelWithProjection unet: UNet2DConditionModel scheduler: KarrasDiffusionSchedulers vae: AutoencoderKL )

Pipeline for text-to-image generation using stable unCLIP.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).

The pipeline also inherits the following loading methods:

__call__

( prompt: Union = None height: Optional = None width: Optional = None num_inference_steps: int = 20 guidance_scale: float = 10.0 negative_prompt: Union = None num_images_per_prompt: Optional = 1 eta: float = 0.0 generator: Optional = None latents: Optional = None prompt_embeds: Optional = None negative_prompt_embeds: Optional = None output_type: Optional = 'pil' return_dict: bool = True callback: Optional = None callback_steps: int = 1 cross_attention_kwargs: Optional = None noise_level: int = 0 prior_num_inference_steps: int = 25 prior_guidance_scale: float = 4.0 prior_latents: Optional = None clip_skip: Optional = None ) → ImagePipelineOutput or tuple

Returns

ImagePipelineOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is a list with the generated images.

The call function to the pipeline for generation.

Examples:

import torch
from diffusers import StableUnCLIPPipeline

pipe = StableUnCLIPPipeline.from_pretrained(
    "fusing/stable-unclip-2-1-l", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
images = pipe(prompt).images
images[0].save("astronaut_horse.png")

enable_attention_slicing

( slice_size: Union = 'auto' )

Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. For more than one attention head, the computation is performed sequentially over each head. This is useful to save some memory in exchange for a small speed decrease.

⚠️ Don’t enable attention slicing if you’re already using scaled_dot_product_attention (SDPA) from PyTorch 2.0 or xFormers. These attention computations are already very memory efficient so you won’t need to enable this function. If you enable attention slicing with SDPA or xFormers, it can lead to serious slow downs!

Examples:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
)

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_attention_slicing()
image = pipe(prompt).images[0]

disable_attention_slicing

Disable sliced attention computation. If enable_attention_slicing was previously called, attention is computed in one step.

enable_vae_slicing

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.

disable_vae_slicing

Disable sliced VAE decoding. If enable_vae_slicing was previously enabled, this method will go back to computing decoding in one step.
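
A minimal sketch of how the VAE slicing toggles might be used, reusing the pipe and prompt from the attention slicing example above; the batch size is illustrative:

pipe.enable_vae_slicing()           # decode latents slice-by-slice to cap peak memory
images = pipe([prompt] * 4).images  # a larger batch now decodes without a VAE memory spike
pipe.disable_vae_slicing()          # restore single-step VAE decoding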

enable_xformers_memory_efficient_attention

( attention_op: Optional = None )

Enable memory efficient attention from xFormers. When this option is enabled, you should observe lower GPU memory usage and a potential speed up during inference. Speed up during training is not guaranteed.

⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes precedent.

Examples:

import torch
from diffusers import DiffusionPipeline
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)

# the VAE's attention shapes are not supported by the flash-attention op, so fall back to the default op for it
pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)

disable_xformers_memory_efficient_attention

( )

Disable memory efficient attention from xFormers.

encode_prompt

( prompt device num_images_per_prompt do_classifier_free_guidance negative_prompt = None prompt_embeds: Optional = None negative_prompt_embeds: Optional = None lora_scale: Optional = None clip_skip: Optional = None )

Encodes the prompt into text encoder hidden states.
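
A minimal sketch of how encode_prompt might be used to precompute embeddings and pass them back through prompt_embeds and negative_prompt_embeds. This assumes the method returns a (prompt_embeds, negative_prompt_embeds) pair, as in the other Stable Diffusion pipelines, and reuses the pipe object from the examples above; the negative prompt is illustrative:

prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    prompt="a photo of an astronaut riding a horse on mars",
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="blurry, low quality",
)

# reuse the precomputed embeddings across calls without re-encoding the text
images = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds).images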

noise_image_embeddings

( image_embeds: Tensor noise_level: int noise: Optional = None generator: Optional = None )

Add noise to the image embeddings. The amount of noise is controlled by a noise_level input. A higher noise_level increases the variance in the final un-noised images.

The noise is applied in two ways:

  1. A noise schedule is applied directly to the embeddings.
  2. A vector of sinusoidal time embeddings is appended to the output.

In both cases, the amount of noise is controlled by the same noise_level.

The embeddings are normalized before the noise is applied and un-normalized after the noise is applied.
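
A minimal sketch of calling noise_image_embeddings directly, reusing the pipe from the text-to-image example above. The stand-in embedding and its 768-dimensional size (the CLIP ViT-L/14 width) are assumptions; in practice the embedding comes from the prior or an image encoder:

import torch

# stand-in CLIP image embedding (batch of 1)
image_embeds = torch.randn(1, 768, dtype=pipe.dtype, device=pipe.device)

noised_embeds = pipe.noise_image_embeddings(image_embeds=image_embeds, noise_level=100)
# the sinusoidal embedding of noise_level is appended, so the output's
# last dimension is larger than the input's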

StableUnCLIPImg2ImgPipeline

class diffusers.StableUnCLIPImg2ImgPipeline

( feature_extractor: CLIPImageProcessor image_encoder: CLIPVisionModelWithProjection image_normalizer: StableUnCLIPImageNormalizer image_noising_scheduler: KarrasDiffusionSchedulers tokenizer: CLIPTokenizer text_encoder: CLIPTextModel unet: UNet2DConditionModel scheduler: KarrasDiffusionSchedulers vae: AutoencoderKL )

Pipeline for text-guided image-to-image generation using stable unCLIP.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).

The pipeline also inherits the following loading methods:

__call__

( image: Union = None prompt: Union = None height: Optional = None width: Optional = None num_inference_steps: int = 20 guidance_scale: float = 10 negative_prompt: Union = None num_images_per_prompt: Optional = 1 eta: float = 0.0 generator: Optional = None latents: Optional = None prompt_embeds: Optional = None negative_prompt_embeds: Optional = None output_type: Optional = 'pil' return_dict: bool = True callback: Optional = None callback_steps: int = 1 cross_attention_kwargs: Optional = None noise_level: int = 0 image_embeds: Optional = None clip_skip: Optional = None ) → ImagePipelineOutput or tuple

Returns

ImagePipelineOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is a list with the generated images.

The call function to the pipeline for generation.

Examples:

import requests
import torch
from PIL import Image
from io import BytesIO

from diffusers import StableUnCLIPImg2ImgPipeline

pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-unclip-small", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((768, 512))

prompt = "A fantasy landscape, trending on artstation"

images = pipe(init_image, prompt).images
images[0].save("fantasy_landscape.png")
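
The __call__ signature also accepts precomputed image_embeds in place of a raw image. A minimal sketch of computing them with the pipeline's own feature extractor and image encoder, continuing from the example above and assuming the image argument can be omitted when image_embeds is supplied:

# compute the CLIP image embedding once and reuse it for several generations
pixel_values = pipe.feature_extractor(images=init_image, return_tensors="pt").pixel_values
pixel_values = pixel_values.to(device="cuda", dtype=torch.float16)
image_embeds = pipe.image_encoder(pixel_values).image_embeds

images = pipe(prompt=prompt, image_embeds=image_embeds).images
images[0].save("fantasy_landscape_from_embeds.png")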

enable_attention_slicing

( slice_size: Union = 'auto' )

Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. For more than one attention head, the computation is performed sequentially over each head. This is useful to save some memory in exchange for a small speed decrease.

⚠️ Don’t enable attention slicing if you’re already using scaled_dot_product_attention (SDPA) from PyTorch 2.0 or xFormers. These attention computations are already very memory efficient so you won’t need to enable this function. If you enable attention slicing with SDPA or xFormers, it can lead to serious slow downs!

Examples:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
)

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_attention_slicing()
image = pipe(prompt).images[0]

disable_attention_slicing

Disable sliced attention computation. If enable_attention_slicing was previously called, attention is computed in one step.

enable_vae_slicing

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.

disable_vae_slicing

Disable sliced VAE decoding. If enable_vae_slicing was previously enabled, this method will go back to computing decoding in one step.

enable_xformers_memory_efficient_attention

( attention_op: Optional = None )

Enable memory efficient attention from xFormers. When this option is enabled, you should observe lower GPU memory usage and a potential speed up during inference. Speed up during training is not guaranteed.

⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes precedent.

Examples:

import torch
from diffusers import DiffusionPipeline
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)

# the VAE's attention shapes are not supported by the flash-attention op, so fall back to the default op for it
pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)

disable_xformers_memory_efficient_attention

( )

Disable memory efficient attention from xFormers.

encode_prompt

( prompt device num_images_per_prompt do_classifier_free_guidance negative_prompt = None prompt_embeds: Optional = None negative_prompt_embeds: Optional = None lora_scale: Optional = None clip_skip: Optional = None )

Encodes the prompt into text encoder hidden states.

noise_image_embeddings

( image_embeds: Tensor noise_level: int noise: Optional = None generator: Optional = None )

Add noise to the image embeddings. The amount of noise is controlled by a noise_level input. A higher noise_level increases the variance in the final un-noised images.

The noise is applied in two ways:

  1. A noise schedule is applied directly to the embeddings.
  2. A vector of sinusoidal time embeddings is appended to the output.

In both cases, the amount of noise is controlled by the same noise_level.

The embeddings are normalized before the noise is applied and un-normalized after the noise is applied.

ImagePipelineOutput

class diffusers.ImagePipelineOutput

( images: Union )

Parameters

images (List[PIL.Image.Image] or np.ndarray): List of denoised PIL images of length batch_size or a NumPy array of shape (batch_size, height, width, num_channels).

Output class for image pipelines.
