Kandinsky 5.0 Image

Kandinsky 5.0 is a family of diffusion models for Video & Image generation.

Kandinsky 5.0 Image Lite is a lightweight image generation model (6B parameters).

The model introduces several key innovations:

The original codebase can be found at kandinskylab/Kandinsky-5.

Check out the Kandinsky Lab organization on the Hub for the official model checkpoints for text-to-video generation, including pretrained, SFT, no-CFG, and distilled variants.

Available Models

Kandinsky 5.0 Image Lite:

| model_id | Description | Use Cases |
|---|---|---|
| kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers | 6B image Supervised Fine-Tuned model | Highest generation quality |
| kandinskylab/Kandinsky-5.0-I2I-Lite-sft-Diffusers | 6B image editing Supervised Fine-Tuned model | Highest generation quality |
| kandinskylab/Kandinsky-5.0-T2I-Lite-pretrain-Diffusers | 6B image Base pretrained model | Research and fine-tuning |
| kandinskylab/Kandinsky-5.0-I2I-Lite-pretrain-Diffusers | 6B image editing Base pretrained model | Research and fine-tuning |
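The table above can be encoded as a small lookup so that scripts select a checkpoint by task and variant. This is a hypothetical helper: `KANDINSKY5_IMAGE_LITE` and `pick_model_id` are illustrative names, not part of diffusers. Following the table, the SFT checkpoints target generation quality and the pretrained ones target research and fine-tuning.

```python
# Hypothetical helper: the checkpoint table keyed by (task, variant).
KANDINSKY5_IMAGE_LITE = {
    ("t2i", "sft"): "kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers",
    ("i2i", "sft"): "kandinskylab/Kandinsky-5.0-I2I-Lite-sft-Diffusers",
    ("t2i", "pretrain"): "kandinskylab/Kandinsky-5.0-T2I-Lite-pretrain-Diffusers",
    ("i2i", "pretrain"): "kandinskylab/Kandinsky-5.0-I2I-Lite-pretrain-Diffusers",
}


def pick_model_id(task: str, variant: str = "sft") -> str:
    """Return the checkpoint for a task ("t2i" or "i2i") and variant ("sft" or "pretrain")."""
    try:
        return KANDINSKY5_IMAGE_LITE[(task.lower(), variant.lower())]
    except KeyError:
        raise ValueError(
            f"no Kandinsky 5.0 Image Lite checkpoint for task={task!r}, variant={variant!r}"
        ) from None


print(pick_model_id("i2i"))  # kandinskylab/Kandinsky-5.0-I2I-Lite-sft-Diffusers
```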

Usage Examples

Basic Text-to-Image Generation

import torch
from diffusers import Kandinsky5T2IPipeline

model_id = "kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers"
pipe = Kandinsky5T2IPipeline.from_pretrained(model_id)
# Move every component to the GPU and cast it to bfloat16
_ = pipe.to(device="cuda", dtype=torch.bfloat16)

prompt = "A fluffy, expressive cat wearing a bright red hat with a soft, slightly textured fabric. The hat should look cozy and well-fitted on the cat’s head. On the front of the hat, add clean, bold white text that reads “SWEET”, clearly visible and neatly centered. Ensure the overall lighting highlights the hat’s color and the cat’s fur details."

output = pipe(
    prompt=prompt,
    negative_prompt="",
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=3.5,
).image[0]  # first generated image (PIL, since output_type defaults to "pil")
output.save("cat_sweet_hat.png")

Basic Image-to-Image Generation

import torch
from diffusers import Kandinsky5I2IPipeline
from diffusers.utils import load_image

model_id = "kandinskylab/Kandinsky-5.0-I2I-Lite-sft-Diffusers"
pipe = Kandinsky5I2IPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Keep idle components on the CPU and move each one to the GPU only while it runs.
# This replaces an explicit pipe.to("cuda") and lowers peak GPU memory (requires accelerate).
pipe.enable_model_cpu_offload()

image = load_image(
    "https://huggingface.co/kandinsky-community/kandinsky-3/resolve/main/assets/title.jpg?download=true"
)

prompt = "Change the background from a winter night scene to a bright summer day. Place the character on a sandy beach with clear blue sky, soft sunlight, and gentle waves in the distance. Replace the winter clothing with a light short-sleeved T-shirt (in soft pastel colors) and casual shorts. Ensure the character’s fur reflects warm daylight instead of cold winter tones. Add small beach details such as seashells, footprints in the sand, and a few scattered beach toys nearby. Keep the oranges in the scene, but place them naturally on the sand."
negative_prompt = ""

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    guidance_scale=3.5,
).image[0]

Kandinsky5T2IPipeline

class diffusers.Kandinsky5T2IPipeline

(
    transformer: Kandinsky5Transformer3DModel,
    vae: AutoencoderKL,
    text_encoder: Qwen2_5_VLForConditionalGeneration,
    tokenizer: Qwen2VLProcessor,
    text_encoder_2: CLIPTextModel,
    tokenizer_2: CLIPTokenizer,
    scheduler: FlowMatchEulerDiscreteScheduler,
)

Pipeline for text-to-image generation using Kandinsky 5.0.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).
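The pipeline's scheduler is FlowMatchEulerDiscreteScheduler. As a rough sketch of what a flow-matching Euler sampler does (not the scheduler's source code), assume the rectified-flow convention x_sigma = (1 - sigma) * x0 + sigma * noise, a model that predicts the velocity v = noise - x0, and uniformly spaced sigmas from 1 to 0 remapped by the timestep shift sigma' = shift * sigma / (1 + (shift - 1) * sigma). Each step is then x <- x + (sigma_next - sigma) * v. The `shift` value and the toy "model" (an oracle that knows x0) are illustrative assumptions.

```python
def shifted_sigmas(num_steps: int, shift: float = 1.0) -> list[float]:
    """Uniform sigmas from 1 to 0, remapped with a flow-matching timestep shift."""
    sigmas = [1.0 - i / num_steps for i in range(num_steps + 1)]  # 1.0 ... 0.0
    return [shift * s / (1.0 + (shift - 1.0) * s) for s in sigmas]


def euler_flow_sample(noise, velocity_fn, num_steps=50, shift=1.0):
    """Integrate dx/dsigma = v(x, sigma) from sigma=1 (pure noise) down to sigma=0."""
    sigmas = shifted_sigmas(num_steps, shift)
    x = list(noise)
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        v = velocity_fn(x, sigma)
        x = [xi + (sigma_next - sigma) * vi for xi, vi in zip(x, v)]
    return x


# Toy "model": the exact velocity v = noise - x0 for a known target x0.
x0 = [0.5, -1.0, 2.0]
noise = [1.2, 0.3, -0.7]


def oracle(x, sigma):
    return [n - d for n, d in zip(noise, x0)]


sample = euler_flow_sample(noise, oracle, num_steps=50, shift=3.0)
print([round(s, 6) for s in sample])  # recovers x0 up to float rounding
```

Because the step sizes telescope to sigma_final - sigma_initial = -1, an exact velocity carries the noise to x0 regardless of the shift; with a learned model the shift changes where the steps are concentrated.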

__call__

(
    prompt: typing.Union[str, typing.List[str]] = None,
    negative_prompt: typing.Union[str, typing.List[str], NoneType] = None,
    height: int = 1024,
    width: int = 1024,
    num_inference_steps: int = 50,
    guidance_scale: float = 3.5,
    num_images_per_prompt: typing.Optional[int] = 1,
    generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None,
    latents: typing.Optional[torch.Tensor] = None,
    prompt_embeds_qwen: typing.Optional[torch.Tensor] = None,
    prompt_embeds_clip: typing.Optional[torch.Tensor] = None,
    negative_prompt_embeds_qwen: typing.Optional[torch.Tensor] = None,
    negative_prompt_embeds_clip: typing.Optional[torch.Tensor] = None,
    prompt_cu_seqlens: typing.Optional[torch.Tensor] = None,
    negative_prompt_cu_seqlens: typing.Optional[torch.Tensor] = None,
    output_type: typing.Optional[str] = 'pil',
    return_dict: bool = True,
    callback_on_step_end: typing.Union[typing.Callable[[int, int, typing.Dict], NoneType], diffusers.callbacks.PipelineCallback, diffusers.callbacks.MultiPipelineCallbacks, NoneType] = None,
    callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'],
    max_sequence_length: int = 512,
) → KandinskyImagePipelineOutput or tuple

Returns

KandinskyImagePipelineOutput or tuple

If return_dict is True, a KandinskyImagePipelineOutput is returned; otherwise, a tuple is returned whose first element is the list of generated images.

The call function to the pipeline for text-to-image generation.
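`guidance_scale` (default 3.5) controls classifier-free guidance. A minimal sketch, assuming the pipeline uses the standard formulation in which the unconditional (negative-prompt) prediction is extrapolated toward the conditional one: guided = uncond + guidance_scale * (cond - uncond). With guidance_scale = 1 the unconditional branch has no effect. The function name `apply_cfg` is hypothetical, and plain lists stand in for the model's prediction tensors.

```python
def apply_cfg(cond, uncond, guidance_scale=3.5):
    """Classifier-free guidance: push the prediction away from the unconditional one."""
    if len(cond) != len(uncond):
        raise ValueError("cond and uncond predictions must have the same length")
    return [u + guidance_scale * (c - u) for c, u in zip(cond, uncond)]


print(apply_cfg([1.0, 2.0], [0.0, 1.0]))                      # [3.5, 4.5]
print(apply_cfg([1.0, 2.0], [0.0, 1.0], guidance_scale=1.0))  # [1.0, 2.0]
```

In the standard formulation each denoising step evaluates the model twice (once per branch), which is presumably why the signature accepts separate negative_prompt_embeds_qwen and negative_prompt_embeds_clip.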

Examples:

import torch
from diffusers import Kandinsky5T2IPipeline

model_id = "kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers"
pipe = Kandinsky5T2IPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

prompt = "A cat and a dog baking a cake together in a kitchen."

output = pipe(
    prompt=prompt,
    negative_prompt="",
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=3.5,
).image[0]
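The `callback_on_step_end` argument, in most diffusers pipelines, is called as `callback(pipe, step_index, timestep, callback_kwargs)` after each denoising step, where `callback_kwargs` holds the tensors named in `callback_on_step_end_tensor_inputs` (by default only `latents`), and it must return that dictionary, optionally with replaced tensors. The sketch below assumes this pipeline follows the same convention. `StepLogger` is a hypothetical callback that only records progress; it is called directly with placeholder values so it runs without downloading the model.

```python
class StepLogger:
    """Hypothetical callback that records each step and passes the tensors through unchanged."""

    def __init__(self):
        self.steps = []

    def __call__(self, pipe, step_index, timestep, callback_kwargs):
        latents = callback_kwargs.get("latents")
        self.steps.append((step_index, timestep, latents is not None))
        # Returning the dict (possibly with modified entries) is part of the convention.
        return callback_kwargs


logger = StepLogger()
# With a loaded pipeline it would be passed as:
#   pipe(prompt=prompt, callback_on_step_end=logger,
#        callback_on_step_end_tensor_inputs=["latents"])
# Here it is called directly with placeholder values instead of real tensors.
out = logger(None, 0, 999.0, {"latents": object()})
print(logger.steps)  # [(0, 999.0, True)]
```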

check_inputs

(
    prompt,
    negative_prompt,
    height,
    width,
    prompt_embeds_qwen = None,
    prompt_embeds_clip = None,
    negative_prompt_embeds_qwen = None,
    negative_prompt_embeds_clip = None,
    prompt_cu_seqlens = None,
    negative_prompt_cu_seqlens = None,
    callback_on_step_end_tensor_inputs = None,
    max_sequence_length = None,
)

Validate input parameters for the pipeline.
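A minimal sketch of the kind of checks `check_inputs` performs, assuming rules that are common across diffusers pipelines: exactly one of `prompt` or precomputed embeddings, the Qwen and CLIP embeddings and `prompt_cu_seqlens` supplied together, callback tensor names restricted to an allowed set, and `height`/`width` divisible by 16. The divisor, the allowed set, and the function name `validate_inputs` are assumptions, not read from the pipeline source.

```python
ALLOWED_CALLBACK_TENSORS = {"latents"}  # assumption: the real pipeline may allow more names


def validate_inputs(prompt=None, height=1024, width=1024,
                    prompt_embeds_qwen=None, prompt_embeds_clip=None,
                    prompt_cu_seqlens=None, callback_on_step_end_tensor_inputs=None,
                    multiple_of=16):
    """Hypothetical validator; raises on the first problem found, returns None otherwise."""
    if height % multiple_of or width % multiple_of:
        raise ValueError(f"height and width must be divisible by {multiple_of}, got {height}x{width}")
    embeds = (prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens)
    has_embeds = any(e is not None for e in embeds)
    if prompt is not None and has_embeds:
        raise ValueError("pass either `prompt` or precomputed prompt embeddings, not both")
    if prompt is None and not has_embeds:
        raise ValueError("provide `prompt` or precomputed prompt embeddings")
    if has_embeds and any(e is None for e in embeds):
        raise ValueError("prompt_embeds_qwen, prompt_embeds_clip and prompt_cu_seqlens must be given together")
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise TypeError(f"`prompt` must be a str or list, got {type(prompt).__name__}")
    extra = set(callback_on_step_end_tensor_inputs or []) - ALLOWED_CALLBACK_TENSORS
    if extra:
        raise ValueError(f"unsupported callback tensor inputs: {sorted(extra)}")
```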

encode_prompt

(
    prompt: typing.Union[str, typing.List[str]],
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 512,
    device: typing.Optional[torch.device] = None,
    dtype: typing.Optional[torch.dtype] = None,
) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Returns

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Encodes a single prompt (positive or negative) into text encoder hidden states.

This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text representations for image generation.
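The `prompt_cu_seqlens` and `negative_prompt_cu_seqlens` arguments suggest that the Qwen2.5-VL embeddings of a batch are handled as packed variable-length sequences. Assuming they follow the cumulative-sequence-length convention of variable-length attention kernels (such as FlashAttention's varlen interface), they can be derived from per-prompt token counts as a running sum starting at 0, so that prompt i occupies the packed slice cu_seqlens[i]:cu_seqlens[i+1]. This illustrates the convention only; the function names are hypothetical and plain lists stand in for tensors.

```python
from itertools import accumulate


def cu_seqlens_from_mask(attention_mask):
    """Cumulative sequence lengths [0, l0, l0 + l1, ...] from a padded 0/1 attention mask."""
    lengths = [sum(row) for row in attention_mask]
    return [0, *accumulate(lengths)]


def unpack(packed, cu_seqlens):
    """Split a packed (concatenated) sequence back into per-prompt pieces."""
    return [packed[start:end] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]


mask = [
    [1, 1, 1, 0, 0],  # prompt 0: 3 real tokens, 2 padding
    [1, 1, 1, 1, 1],  # prompt 1: 5 real tokens
]
cu = cu_seqlens_from_mask(mask)
print(cu)                            # [0, 3, 8]
print(unpack(list("abcDEFGH"), cu))  # [['a', 'b', 'c'], ['D', 'E', 'F', 'G', 'H']]
```

Packing avoids spending attention compute on padding tokens; varlen kernels typically expect these offsets as an int32 tensor.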

prepare_latents

(
    batch_size: int,
    num_channels_latents: int = 16,
    height: int = 1024,
    width: int = 1024,
    dtype: typing.Optional[torch.dtype] = None,
    device: typing.Optional[torch.device] = None,
    generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None,
    latents: typing.Optional[torch.Tensor] = None,
) → torch.Tensor

Returns

torch.Tensor

Prepared latent tensor.

Prepare initial latent variables for text-to-image generation.

This method creates the random noise latents from which denoising starts.
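A sketch of what preparing the initial noise involves, under stated assumptions: the VAE downsamples each spatial dimension by a factor of 8 (so a 1024 x 1024 image maps to a 128 x 128 latent grid with `num_channels_latents` = 16 channels), and reproducibility comes either from one shared generator for the batch or from one generator per sample, mirroring the common diffusers rule that a list of generators must have exactly `batch_size` entries. Python's `random.Random` stands in for `torch.Generator`; the function name is hypothetical, and the pipeline's actual latent layout is not assumed.

```python
import math
import random

VAE_SCALE_FACTOR = 8  # assumption: spatial downsampling factor of the VAE


def prepare_noise(batch_size, num_channels_latents=16, height=1024, width=1024, seeds=None):
    """Return one flat list of N(0, 1) samples per batch item and the latent (height, width)."""
    if height % VAE_SCALE_FACTOR or width % VAE_SCALE_FACTOR:
        raise ValueError(f"height and width must be multiples of {VAE_SCALE_FACTOR}")
    grid = (height // VAE_SCALE_FACTOR, width // VAE_SCALE_FACTOR)

    if isinstance(seeds, list):
        # One generator per sample: each item is reproducible on its own.
        if len(seeds) != batch_size:
            raise ValueError(f"got {len(seeds)} seeds for a batch of {batch_size}")
        rngs = [random.Random(s) for s in seeds]
    else:
        # A single stream (seeded, or unseeded if seeds is None) shared by the whole batch.
        shared = random.Random(seeds)
        rngs = [shared] * batch_size

    numel = num_channels_latents * math.prod(grid)
    noise = [[rng.gauss(0.0, 1.0) for _ in range(numel)] for rng in rngs]
    return noise, grid
```

With the defaults, each sample has 16 x 128 x 128 = 262,144 latent values under the factor-8 assumption.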

Kandinsky5I2IPipeline

class diffusers.Kandinsky5I2IPipeline

(
    transformer: Kandinsky5Transformer3DModel,
    vae: AutoencoderKL,
    text_encoder: Qwen2_5_VLForConditionalGeneration,
    tokenizer: Qwen2VLProcessor,
    text_encoder_2: CLIPTextModel,
    tokenizer_2: CLIPTokenizer,
    scheduler: FlowMatchEulerDiscreteScheduler,
)

Pipeline for image-to-image generation using Kandinsky 5.0.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).

__call__

(
    image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]],
    prompt: typing.Union[str, typing.List[str]] = None,
    negative_prompt: typing.Union[str, typing.List[str], NoneType] = None,
    height: typing.Optional[int] = None,
    width: typing.Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 3.5,
    num_images_per_prompt: typing.Optional[int] = 1,
    generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None,
    latents: typing.Optional[torch.Tensor] = None,
    prompt_embeds_qwen: typing.Optional[torch.Tensor] = None,
    prompt_embeds_clip: typing.Optional[torch.Tensor] = None,
    negative_prompt_embeds_qwen: typing.Optional[torch.Tensor] = None,
    negative_prompt_embeds_clip: typing.Optional[torch.Tensor] = None,
    prompt_cu_seqlens: typing.Optional[torch.Tensor] = None,
    negative_prompt_cu_seqlens: typing.Optional[torch.Tensor] = None,
    output_type: typing.Optional[str] = 'pil',
    return_dict: bool = True,
    callback_on_step_end: typing.Union[typing.Callable[[int, int, typing.Dict], NoneType], diffusers.callbacks.PipelineCallback, diffusers.callbacks.MultiPipelineCallbacks, NoneType] = None,
    callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'],
    max_sequence_length: int = 1024,
) → KandinskyImagePipelineOutput or tuple

Returns

KandinskyImagePipelineOutput or tuple

If return_dict is True, a KandinskyImagePipelineOutput is returned; otherwise, a tuple is returned whose first element is the list of generated images.

The call function to the pipeline for image-to-image generation.
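In the image-to-image pipeline, `height` and `width` default to None, which suggests the output size is derived from the input image when they are not given. A hypothetical sketch of such a rule: explicit values take precedence; otherwise the input's size is used, and each side is rounded down to a multiple of 16. The rounding constraint is an assumption, and the actual pipeline may instead resize to a fixed set of resolutions; `resolve_size` is an illustrative name.

```python
def resolve_size(image_size, height=None, width=None, multiple_of=16):
    """Pick (height, width) for image-to-image: explicit values win, else the input's size."""
    in_width, in_height = image_size  # PIL's Image.size is (width, height)
    height = in_height if height is None else height
    width = in_width if width is None else width
    # Round down so the latent grid divides evenly (assumed constraint).
    height = (height // multiple_of) * multiple_of
    width = (width // multiple_of) * multiple_of
    if height == 0 or width == 0:
        raise ValueError(f"image too small for a multiple of {multiple_of}: {image_size}")
    return height, width


print(resolve_size((1000, 700)))                          # (688, 992)
print(resolve_size((1000, 700), height=1024, width=1024)) # (1024, 1024)
```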

Examples:

import torch
from diffusers import Kandinsky5I2IPipeline
from diffusers.utils import load_image

model_id = "kandinskylab/Kandinsky-5.0-I2I-Lite-sft-Diffusers"
pipe = Kandinsky5I2IPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

# The input image to edit (required by the image-to-image pipeline)
image = load_image(
    "https://huggingface.co/kandinsky-community/kandinsky-3/resolve/main/assets/title.jpg?download=true"
)
prompt = "A cat and a dog baking a cake together in a kitchen."

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt="",
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=3.5,
).image[0]

check_inputs

(
    prompt,
    negative_prompt,
    image,
    height,
    width,
    prompt_embeds_qwen = None,
    prompt_embeds_clip = None,
    negative_prompt_embeds_qwen = None,
    negative_prompt_embeds_clip = None,
    prompt_cu_seqlens = None,
    negative_prompt_cu_seqlens = None,
    callback_on_step_end_tensor_inputs = None,
    max_sequence_length = None,
)

Validate input parameters for the pipeline.

encode_prompt

(
    prompt: typing.Union[str, typing.List[str]],
    image: Tensor,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 1024,
    device: typing.Optional[torch.device] = None,
    dtype: typing.Optional[torch.dtype] = None,
) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Returns

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Encodes a single prompt (positive or negative) into text encoder hidden states.

This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text representations for image generation.

prepare_latents

(
    image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]],
    batch_size: int,
    num_channels_latents: int = 16,
    height: int = 1024,
    width: int = 1024,
    dtype: typing.Optional[torch.dtype] = None,
    device: typing.Optional[torch.device] = None,
    generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None,
    latents: typing.Optional[torch.Tensor] = None,
) → torch.Tensor

Returns

torch.Tensor

Prepared latent tensor with the encoded image.

Prepare initial latent variables for image-to-image generation.

This method creates random noise latents together with the encoded input image.

Citation

author = {Alexander Belykh and Alexander Varlamov and Alexey Letunovskiy and Anastasia Aliaskina and Anastasia Maltseva and Anastasiia Kargapoltseva and Andrey Shutkin and Anna Averchenkova and Anna Dmitrienko and Bulat Akhmatov and Denis Dimitrov and Denis Koposov and Denis Parkhomenko and Dmitrii and Ilya Vasiliev and Ivan Kirillov and Julia Agafonova and Kirill Chernyshev and Kormilitsyn Semen and Lev Novitskiy and Maria Kovaleva and Mikhail Mamaev and Mikhailov and Nikita Kiselev and Nikita Osterov and Nikolai Gerasimenko and Nikolai Vaulin and Olga Kim and Olga Vdovchenko and Polina Gavrilova and Polina Mikhailova and Tatiana Nikulina and Viacheslav Vasilev and Vladimir Arkhipkin and Vladimir Korviakov and Vladimir Polovnikov and Yury Kolabushin},
title = {Kandinsky 5.0: A family of diffusion models for Video & Image generation},
howpublished = {\url{https:
year = 2025

}
