Kandinsky 2.1 (original)

Kandinsky 2.1 is created by Arseniy Shakhmatov, Anton Razzhigaev, Aleksandr Nikolich, Vladimir Arkhipkin, Igor Pavlov, Andrey Kuznetsov, and Denis Dimitrov.

The description from its GitHub page is:

Kandinsky 2.1 inherits best practices from DALL-E 2 and Latent Diffusion while introducing some new ideas. It uses the CLIP model as a text and image encoder, and a diffusion image prior (mapping) between the latent spaces of CLIP modalities. This approach increases the visual performance of the model and unveils new horizons in blending images and text-guided image manipulation.

The original codebase can be found at ai-forever/Kandinsky-2.

Check out the Kandinsky Community organization on the Hub for the official model checkpoints for tasks like text-to-image, image-to-image, and inpainting.

Make sure to check out the Schedulers guide to learn how to explore the tradeoff between scheduler speed and quality, and see the reuse components across pipelines section to learn how to efficiently load the same components into multiple pipelines.
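Because the Kandinsky text-to-image and image-to-image pipelines share the same components, one can be built from the other instead of loading everything twice. The snippet below is a minimal sketch of both ideas; the DDPMScheduler swap is only an illustration of changing schedulers, not a recommendation for this checkpoint.

```python
import torch
from diffusers import DDPMScheduler, KandinskyImg2ImgPipeline, KandinskyPipeline

pipe = KandinskyPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
)

# Reuse the already-loaded components in another Kandinsky pipeline
# instead of downloading and loading them a second time.
img2img_pipe = KandinskyImg2ImgPipeline(**pipe.components)

# Swap the scheduler of the text-to-image pipeline by rebuilding it
# from the current scheduler's config (illustration only).
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
```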

KandinskyPriorPipeline

class diffusers.KandinskyPriorPipeline

( prior: PriorTransformer image_encoder: CLIPVisionModelWithProjection text_encoder: CLIPTextModelWithProjection tokenizer: CLIPTokenizer scheduler: UnCLIPScheduler image_processor: CLIPImageProcessor )

Pipeline for generating the image prior for Kandinsky.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

__call__

( prompt: typing.Union[str, typing.List[str]] negative_prompt: typing.Union[str, typing.List[str], NoneType] = None num_images_per_prompt: int = 1 num_inference_steps: int = 25 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None guidance_scale: float = 4.0 output_type: typing.Optional[str] = 'pt' return_dict: bool = True ) → KandinskyPriorPipelineOutput or tuple

Returns: KandinskyPriorPipelineOutput or tuple

Function invoked when calling the pipeline for generation.

Examples:

```python
from diffusers import KandinskyPipeline, KandinskyPriorPipeline
import torch

pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior")
pipe_prior.to("cuda")

prompt = "red cat, 4k photo"
out = pipe_prior(prompt)
image_emb = out.image_embeds
negative_image_emb = out.negative_image_embeds

pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1")
pipe.to("cuda")

image = pipe(
    prompt,
    image_embeds=image_emb,
    negative_image_embeds=negative_image_emb,
    height=768,
    width=768,
    num_inference_steps=100,
).images

image[0].save("cat.png")
```

interpolate

( images_and_prompts: typing.List[typing.Union[str, PIL.Image.Image, torch.Tensor]] weights: typing.List[float] num_images_per_prompt: int = 1 num_inference_steps: int = 25 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None negative_prior_prompt: typing.Optional[str] = None negative_prompt: str = '' guidance_scale: float = 4.0 device = None ) → KandinskyPriorPipelineOutput or tuple

Returns: KandinskyPriorPipelineOutput or tuple

Function invoked when using the prior pipeline for interpolation.

Examples:

```python
from diffusers import KandinskyPriorPipeline, KandinskyPipeline
from diffusers.utils import load_image
import PIL

import torch
from torchvision import transforms

pipe_prior = KandinskyPriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
)
pipe_prior.to("cuda")

img1 = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/kandinsky/cat.png"
)

img2 = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/kandinsky/starry_night.jpeg"
)

images_texts = ["a cat", img1, img2]
weights = [0.3, 0.3, 0.4]
image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights)

pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
pipe.to("cuda")

image = pipe(
    "",
    image_embeds=image_emb,
    negative_image_embeds=zero_image_emb,
    height=768,
    width=768,
    num_inference_steps=150,
).images[0]

image.save("starry_cat.png")
```

KandinskyPipeline

class diffusers.KandinskyPipeline

( text_encoder: MultilingualCLIP tokenizer: XLMRobertaTokenizer unet: UNet2DConditionModel scheduler: typing.Union[diffusers.schedulers.scheduling_ddim.DDIMScheduler, diffusers.schedulers.scheduling_ddpm.DDPMScheduler] movq: VQModel )

Pipeline for text-to-image generation using Kandinsky

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

__call__

( prompt: typing.Union[str, typing.List[str]] image_embeds: typing.Union[torch.Tensor, typing.List[torch.Tensor]] negative_image_embeds: typing.Union[torch.Tensor, typing.List[torch.Tensor]] negative_prompt: typing.Union[str, typing.List[str], NoneType] = None height: int = 512 width: int = 512 num_inference_steps: int = 100 guidance_scale: float = 4.0 num_images_per_prompt: int = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'pil' callback: typing.Optional[typing.Callable[[int, int, torch.Tensor], NoneType]] = None callback_steps: int = 1 return_dict: bool = True ) → ImagePipelineOutput or tuple

Function invoked when calling the pipeline for generation.

Examples:

```python
from diffusers import KandinskyPipeline, KandinskyPriorPipeline
import torch

pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior")
pipe_prior.to("cuda")

prompt = "red cat, 4k photo"
out = pipe_prior(prompt)
image_emb = out.image_embeds
negative_image_emb = out.negative_image_embeds

pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1")
pipe.to("cuda")

image = pipe(
    prompt,
    image_embeds=image_emb,
    negative_image_embeds=negative_image_emb,
    height=768,
    width=768,
    num_inference_steps=100,
).images

image[0].save("cat.png")
```

KandinskyCombinedPipeline

class diffusers.KandinskyCombinedPipeline

( text_encoder: MultilingualCLIP tokenizer: XLMRobertaTokenizer unet: UNet2DConditionModel scheduler: typing.Union[diffusers.schedulers.scheduling_ddim.DDIMScheduler, diffusers.schedulers.scheduling_ddpm.DDPMScheduler] movq: VQModel prior_prior: PriorTransformer prior_image_encoder: CLIPVisionModelWithProjection prior_text_encoder: CLIPTextModelWithProjection prior_tokenizer: CLIPTokenizer prior_scheduler: UnCLIPScheduler prior_image_processor: CLIPImageProcessor )

Combined Pipeline for text-to-image generation using Kandinsky

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

__call__

( prompt: typing.Union[str, typing.List[str]] negative_prompt: typing.Union[str, typing.List[str], NoneType] = None num_inference_steps: int = 100 guidance_scale: float = 4.0 num_images_per_prompt: int = 1 height: int = 512 width: int = 512 prior_guidance_scale: float = 4.0 prior_num_inference_steps: int = 25 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'pil' callback: typing.Optional[typing.Callable[[int, int, torch.Tensor], NoneType]] = None callback_steps: int = 1 return_dict: bool = True ) → ImagePipelineOutput or tuple

Function invoked when calling the pipeline for generation.

Examples:

```python
from diffusers import AutoPipelineForText2Image
import torch

pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k"

image = pipe(prompt=prompt, num_inference_steps=25).images[0]
```
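The prior_guidance_scale and prior_num_inference_steps arguments in the signature above control the prior stage separately from the decoder stage. A minimal sketch, reusing pipe and prompt from the example:

```python
# Tune the prior stage (image-embedding generation) separately from the decoder.
image = pipe(
    prompt=prompt,
    num_inference_steps=100,
    prior_guidance_scale=4.0,
    prior_num_inference_steps=25,
).images[0]
```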

enable_sequential_cpu_offload

( gpu_id: typing.Optional[int] = None device: typing.Union[torch.device, str] = 'cuda' )

Offloads all models (unet, text_encoder, vae, and safety checker state dicts) to CPU using 🤗 Accelerate, significantly reducing memory usage. Models are moved to a torch.device('meta') and loaded on a GPU only when their specific submodule’s forward method is called. Offloading happens on a submodule basis. Memory savings are higher than using enable_model_cpu_offload, but performance is lower.
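A minimal sketch of enabling sequential offload on the combined pipeline, following the loading pattern from the example above; use it in place of enable_model_cpu_offload when memory matters more than speed.

```python
from diffusers import AutoPipelineForText2Image
import torch

pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
)
# Offload every submodule to the CPU; each one is moved to the GPU only while
# its forward pass runs. Slower than enable_model_cpu_offload, but leaner on memory.
pipe.enable_sequential_cpu_offload()

image = pipe("red cat, 4k photo", num_inference_steps=25).images[0]
```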

KandinskyImg2ImgPipeline

class diffusers.KandinskyImg2ImgPipeline

( text_encoder: MultilingualCLIP movq: VQModel tokenizer: XLMRobertaTokenizer unet: UNet2DConditionModel scheduler: DDIMScheduler )

Pipeline for image-to-image generation using Kandinsky

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

__call__

( prompt: typing.Union[str, typing.List[str]] image: typing.Union[torch.Tensor, PIL.Image.Image, typing.List[torch.Tensor], typing.List[PIL.Image.Image]] image_embeds: Tensor negative_image_embeds: Tensor negative_prompt: typing.Union[str, typing.List[str], NoneType] = None height: int = 512 width: int = 512 num_inference_steps: int = 100 strength: float = 0.3 guidance_scale: float = 7.0 num_images_per_prompt: int = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None output_type: typing.Optional[str] = 'pil' callback: typing.Optional[typing.Callable[[int, int, torch.Tensor], NoneType]] = None callback_steps: int = 1 return_dict: bool = True ) → ImagePipelineOutput or tuple

Function invoked when calling the pipeline for generation.

Examples:

```python
from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline
from diffusers.utils import load_image
import torch

pipe_prior = KandinskyPriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
)
pipe_prior.to("cuda")

prompt = "A red cartoon frog, 4k"
image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)

pipe = KandinskyImg2ImgPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
)
pipe.to("cuda")

init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/kandinsky/frog.png"
)

image = pipe(
    prompt,
    image=init_image,
    image_embeds=image_emb,
    negative_image_embeds=zero_image_emb,
    height=768,
    width=768,
    num_inference_steps=100,
    strength=0.2,
).images

image[0].save("red_frog.png")
```

KandinskyImg2ImgCombinedPipeline

class diffusers.KandinskyImg2ImgCombinedPipeline

( text_encoder: MultilingualCLIP tokenizer: XLMRobertaTokenizer unet: UNet2DConditionModel scheduler: typing.Union[diffusers.schedulers.scheduling_ddim.DDIMScheduler, diffusers.schedulers.scheduling_ddpm.DDPMScheduler] movq: VQModel prior_prior: PriorTransformer prior_image_encoder: CLIPVisionModelWithProjection prior_text_encoder: CLIPTextModelWithProjection prior_tokenizer: CLIPTokenizer prior_scheduler: UnCLIPScheduler prior_image_processor: CLIPImageProcessor )

Combined Pipeline for image-to-image generation using Kandinsky

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

__call__

( prompt: typing.Union[str, typing.List[str]] image: typing.Union[torch.Tensor, PIL.Image.Image, typing.List[torch.Tensor], typing.List[PIL.Image.Image]] negative_prompt: typing.Union[str, typing.List[str], NoneType] = None num_inference_steps: int = 100 guidance_scale: float = 4.0 num_images_per_prompt: int = 1 strength: float = 0.3 height: int = 512 width: int = 512 prior_guidance_scale: float = 4.0 prior_num_inference_steps: int = 25 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'pil' callback: typing.Optional[typing.Callable[[int, int, torch.Tensor], NoneType]] = None callback_steps: int = 1 return_dict: bool = True ) → ImagePipelineOutput or tuple

Function invoked when calling the pipeline for generation.

Examples:

```python
from diffusers import AutoPipelineForImage2Image
import torch
import requests
from io import BytesIO
from PIL import Image

pipe = AutoPipelineForImage2Image.from_pretrained(
    "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

prompt = "A fantasy landscape, Cinematic lighting"
negative_prompt = "low quality, bad quality"

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

response = requests.get(url)
original_image = Image.open(BytesIO(response.content)).convert("RGB")
original_image.thumbnail((768, 768))

image = pipe(prompt=prompt, image=original_image, num_inference_steps=25).images[0]
```
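The strength argument in the signature above controls how much noise is added to the init image before denoising; lower values stay closer to the input. A minimal sketch, reusing pipe, prompt, and original_image from the example:

```python
# Lower strength preserves more of the original sketch;
# higher strength gives the prompt more freedom.
faithful = pipe(prompt=prompt, image=original_image, strength=0.2, num_inference_steps=25).images[0]
free = pipe(prompt=prompt, image=original_image, strength=0.8, num_inference_steps=25).images[0]
```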

enable_sequential_cpu_offload

( gpu_id: typing.Optional[int] = None device: typing.Union[torch.device, str] = 'cuda' )

Offloads all models to CPU using 🤗 Accelerate, significantly reducing memory usage. When called, the state dicts of unet, text_encoder, vae, and the safety checker are saved to CPU, and the models are moved to torch.device('meta') and loaded on a GPU only when their specific submodule's forward method is called. Offloading happens on a submodule basis. Memory savings are higher than with enable_model_cpu_offload, but performance is lower.

KandinskyInpaintPipeline

class diffusers.KandinskyInpaintPipeline

( text_encoder: MultilingualCLIP movq: VQModel tokenizer: XLMRobertaTokenizer unet: UNet2DConditionModel scheduler: DDIMScheduler )

Pipeline for text-guided image inpainting using Kandinsky 2.1.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

__call__

( prompt: typing.Union[str, typing.List[str]] image: typing.Union[torch.Tensor, PIL.Image.Image] mask_image: typing.Union[torch.Tensor, PIL.Image.Image, numpy.ndarray] image_embeds: Tensor negative_image_embeds: Tensor negative_prompt: typing.Union[str, typing.List[str], NoneType] = None height: int = 512 width: int = 512 num_inference_steps: int = 100 guidance_scale: float = 4.0 num_images_per_prompt: int = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'pil' callback: typing.Optional[typing.Callable[[int, int, torch.Tensor], NoneType]] = None callback_steps: int = 1 return_dict: bool = True ) → ImagePipelineOutput or tuple

Function invoked when calling the pipeline for generation.

Examples:

```python
from diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline
from diffusers.utils import load_image
import torch
import numpy as np

pipe_prior = KandinskyPriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
)
pipe_prior.to("cuda")

prompt = "a hat"
image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)

pipe = KandinskyInpaintPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16
)
pipe.to("cuda")

init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/kandinsky/cat.png"
)

mask = np.zeros((768, 768), dtype=np.float32)
mask[:250, 250:-250] = 1

out = pipe(
    prompt,
    image=init_image,
    mask_image=mask,
    image_embeds=image_emb,
    negative_image_embeds=zero_image_emb,
    height=768,
    width=768,
    num_inference_steps=50,
)

image = out.images[0]
image.save("cat_with_hat.png")
```

KandinskyInpaintCombinedPipeline

class diffusers.KandinskyInpaintCombinedPipeline

( text_encoder: MultilingualCLIP tokenizer: XLMRobertaTokenizer unet: UNet2DConditionModel scheduler: typing.Union[diffusers.schedulers.scheduling_ddim.DDIMScheduler, diffusers.schedulers.scheduling_ddpm.DDPMScheduler] movq: VQModel prior_prior: PriorTransformer prior_image_encoder: CLIPVisionModelWithProjection prior_text_encoder: CLIPTextModelWithProjection prior_tokenizer: CLIPTokenizer prior_scheduler: UnCLIPScheduler prior_image_processor: CLIPImageProcessor )

Combined Pipeline for text-guided image inpainting using Kandinsky

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

__call__

( prompt: typing.Union[str, typing.List[str]] image: typing.Union[torch.Tensor, PIL.Image.Image, typing.List[torch.Tensor], typing.List[PIL.Image.Image]] mask_image: typing.Union[torch.Tensor, PIL.Image.Image, typing.List[torch.Tensor], typing.List[PIL.Image.Image]] negative_prompt: typing.Union[str, typing.List[str], NoneType] = None num_inference_steps: int = 100 guidance_scale: float = 4.0 num_images_per_prompt: int = 1 height: int = 512 width: int = 512 prior_guidance_scale: float = 4.0 prior_num_inference_steps: int = 25 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'pil' callback: typing.Optional[typing.Callable[[int, int, torch.Tensor], NoneType]] = None callback_steps: int = 1 return_dict: bool = True ) → ImagePipelineOutput or tuple

Function invoked when calling the pipeline for generation.

Examples:

```python
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image
import torch
import numpy as np

pipe = AutoPipelineForInpainting.from_pretrained(
    "kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

prompt = "A fantasy landscape, Cinematic lighting"
negative_prompt = "low quality, bad quality"

original_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/kandinsky/cat.png"
)

mask = np.zeros((768, 768), dtype=np.float32)
mask[:250, 250:-250] = 1

image = pipe(prompt=prompt, image=original_image, mask_image=mask, num_inference_steps=25).images[0]
```

enable_sequential_cpu_offload

( gpu_id: typing.Optional[int] = None device: typing.Union[torch.device, str] = 'cuda' )

Offloads all models to CPU using 🤗 Accelerate, significantly reducing memory usage. When called, the state dicts of unet, text_encoder, vae, and the safety checker are saved to CPU, and the models are moved to torch.device('meta') and loaded on a GPU only when their specific submodule's forward method is called. Offloading happens on a submodule basis. Memory savings are higher than with enable_model_cpu_offload, but performance is lower.
