# QwenImage


Qwen-Image from the Qwen team is an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing. Experiments show strong general capabilities in both image generation and editing, with exceptional performance in text rendering, especially for Chinese.

Qwen-Image comes in the following variants:

| model type | model id |
|---|---|
| Qwen-Image | [`Qwen/Qwen-Image`](https://huggingface.co/Qwen/Qwen-Image) |
| Qwen-Image-Edit | [`Qwen/Qwen-Image-Edit`](https://huggingface.co/Qwen/Qwen-Image-Edit) |
| Qwen-Image-Edit Plus | [`Qwen/Qwen-Image-Edit-2509`](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) |
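
Each variant is served by a dedicated pipeline class, documented in the API reference below. A minimal loading sketch (class/checkpoint pairings taken from the examples on this page):

```py
import torch
from diffusers import DiffusionPipeline, QwenImageEditPipeline, QwenImageEditPlusPipeline

# Text-to-image (resolves to QwenImagePipeline).
pipe_t2i = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)

# Single-image editing.
pipe_edit = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)

# Multi-image editing ("Edit Plus").
pipe_edit_plus = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16)
```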

> [!TIP]
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.

## LoRA for faster inference

Use a LoRA from lightx2v/Qwen-Image-Lightning to speed up inference by reducing the number of steps. Refer to the code snippet below:


```py
import math

import torch
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler

ckpt_id = "Qwen/Qwen-Image"

# Scheduler configuration used with the Lightning (few-step) LoRA.
scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": math.log(3),
    "invert_sigmas": False,
    "max_image_seq_len": 8192,
    "max_shift": math.log(3),
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": None,
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id, scheduler=scheduler, torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
    "lightx2v/Qwen-Image-Lightning", weight_name="Qwen-Image-Lightning-8steps-V1.0.safetensors"
)

prompt = "a tiny astronaut hatching from an egg on the moon, Ultra HD, 4K, cinematic composition."
negative_prompt = " "

# Few-step inference: 8 steps, classifier-free guidance disabled (true_cfg_scale=1.0).
image = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=1024,
    height=1024,
    num_inference_steps=8,
    true_cfg_scale=1.0,
    generator=torch.manual_seed(0),
).images[0]
image.save("qwen_fewsteps.png")
```

The `guidance_scale` parameter in the pipeline exists to support future guidance-distilled models; passing it currently has no effect. To enable classifier-free guidance, pass `true_cfg_scale` together with a `negative_prompt`; even an empty negative prompt like `" "` enables the classifier-free guidance computation.
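
As an illustration, a minimal sketch of toggling classifier-free guidance with a standard `Qwen/Qwen-Image` pipeline (no Lightning LoRA):

```py
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
prompt = "a tiny astronaut hatching from an egg on the moon"

# CFG on: true_cfg_scale > 1.0 together with a negative prompt (even " ") enables the extra negative pass.
image_cfg = pipe(prompt=prompt, negative_prompt=" ", true_cfg_scale=4.0, num_inference_steps=50).images[0]

# CFG off: true_cfg_scale=1.0 skips the negative pass; guidance_scale has no effect here.
image_no_cfg = pipe(prompt=prompt, true_cfg_scale=1.0, num_inference_steps=50).images[0]
```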

## Multi-image reference with QwenImageEditPlusPipeline

With `QwenImageEditPlusPipeline`, you can provide multiple input images as references.

```py
import torch
from diffusers import QwenImageEditPlusPipeline
from diffusers.utils import load_image

pipe = QwenImageEditPlusPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
).to("cuda")

image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
image = pipe(
    image=[image_1, image_2],
    prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
    num_inference_steps=50,
).images[0]
```

## QwenImagePipeline

class diffusers.QwenImagePipeline


( scheduler: FlowMatchEulerDiscreteScheduler vae: AutoencoderKLQwenImage text_encoder: Qwen2_5_VLForConditionalGeneration tokenizer: Qwen2Tokenizer transformer: QwenImageTransformer2DModel )

Parameters

The QwenImage pipeline for text-to-image generation.

`__call__`


( prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str]] = None true_cfg_scale: float = 4.0 height: typing.Optional[int] = None width: typing.Optional[int] = None num_inference_steps: int = 50 sigmas: typing.Optional[typing.List[float]] = None guidance_scale: typing.Optional[float] = None num_images_per_prompt: int = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_embeds_mask: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds_mask: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 512 ) → ~pipelines.qwenimage.QwenImagePipelineOutput or tuple

Parameters

Returns

~pipelines.qwenimage.QwenImagePipelineOutput or tuple

~pipelines.qwenimage.QwenImagePipelineOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is a list with the generated images.

Function invoked when calling the pipeline for generation.

Examples:

```py
import torch
from diffusers import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "A cat holding a sign that says hello world"

image = pipe(prompt, num_inference_steps=50).images[0]
image.save("qwenimage.png")
```
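
As noted in the Returns section above, `return_dict=False` yields a plain tuple whose first element is the list of generated images. Continuing from the example:

```py
# Default: a QwenImagePipelineOutput with an `.images` attribute.
output = pipe(prompt, num_inference_steps=50)
image = output.images[0]

# return_dict=False: a plain tuple; the first element is the list of images.
images = pipe(prompt, num_inference_steps=50, return_dict=False)[0]
image = images[0]
```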

disable_vae_slicing

Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to computing decoding in one step.

disable_vae_tiling

Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to computing decoding in one step.

enable_vae_slicing

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.

enable_vae_tiling

Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images.
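
These memory-saving options are plain methods on the pipeline object; a brief sketch, reusing the `pipe` from the example above:

```py
# Lower peak memory during VAE decoding at a small speed cost.
pipe.enable_vae_slicing()  # decode batched latents one slice at a time
pipe.enable_vae_tiling()   # decode large latents tile by tile

image = pipe("A cat holding a sign that says hello world", num_inference_steps=50).images[0]

# Revert to single-pass decoding.
pipe.disable_vae_slicing()
pipe.disable_vae_tiling()
```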

encode_prompt


( prompt: typing.Union[str, typing.List[str]] device: typing.Optional[torch.device] = None num_images_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None prompt_embeds_mask: typing.Optional[torch.Tensor] = None max_sequence_length: int = 1024 )

Parameters
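
Prompt embeddings can also be precomputed with `encode_prompt` and passed to `__call__` through `prompt_embeds` and `prompt_embeds_mask`. A hedged sketch, assuming `encode_prompt` returns that `(prompt_embeds, prompt_embeds_mask)` pair (verify against your installed diffusers version):

```py
import torch
from diffusers import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")

# Encode once, reuse across several generations (return value assumed to be the embeds/mask pair).
prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(prompt="A cat holding a sign that says hello world")

image = pipe(
    prompt_embeds=prompt_embeds,
    prompt_embeds_mask=prompt_embeds_mask,
    num_inference_steps=50,
).images[0]
image.save("qwenimage_from_embeds.png")
```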

## QwenImageImg2ImgPipeline

class diffusers.QwenImageImg2ImgPipeline


( scheduler: FlowMatchEulerDiscreteScheduler vae: AutoencoderKLQwenImage text_encoder: Qwen2_5_VLForConditionalGeneration tokenizer: Qwen2Tokenizer transformer: QwenImageTransformer2DModel )

Parameters

The QwenImage pipeline for image-to-image generation.

`__call__`


( prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str]] = None true_cfg_scale: float = 4.0 image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] = None height: typing.Optional[int] = None width: typing.Optional[int] = None strength: float = 0.6 num_inference_steps: int = 50 sigmas: typing.Optional[typing.List[float]] = None guidance_scale: typing.Optional[float] = None num_images_per_prompt: int = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_embeds_mask: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds_mask: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 512 ) → ~pipelines.qwenimage.QwenImagePipelineOutput or tuple

Parameters

Returns

~pipelines.qwenimage.QwenImagePipelineOutput or tuple

~pipelines.qwenimage.QwenImagePipelineOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is a list with the generated images.

Function invoked when calling the pipeline for generation.

Examples:

```py
import torch
from diffusers import QwenImageImg2ImgPipeline
from diffusers.utils import load_image

pipe = QwenImageImg2ImgPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
init_image = load_image(url).resize((1024, 1024))
prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney"
image = pipe(prompt=prompt, negative_prompt=" ", image=init_image, strength=0.95).images[0]
image.save("qwenimage_img2img.png")
```

disable_vae_slicing

Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to computing decoding in one step.

disable_vae_tiling

Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to computing decoding in one step.

enable_vae_slicing

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.

enable_vae_tiling

Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images.

encode_prompt


( prompt: typing.Union[str, typing.List[str]] device: typing.Optional[torch.device] = None num_images_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None prompt_embeds_mask: typing.Optional[torch.Tensor] = None max_sequence_length: int = 1024 )

Parameters

## QwenImageInpaintPipeline

class diffusers.QwenImageInpaintPipeline


( scheduler: FlowMatchEulerDiscreteScheduler vae: AutoencoderKLQwenImage text_encoder: Qwen2_5_VLForConditionalGeneration tokenizer: Qwen2Tokenizer transformer: QwenImageTransformer2DModel )

Parameters

The QwenImage pipeline for image inpainting.

`__call__`


( prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str]] = None true_cfg_scale: float = 4.0 image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] = None mask_image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] = None masked_image_latents: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] = None height: typing.Optional[int] = None width: typing.Optional[int] = None padding_mask_crop: typing.Optional[int] = None strength: float = 0.6 num_inference_steps: int = 50 sigmas: typing.Optional[typing.List[float]] = None guidance_scale: typing.Optional[float] = None num_images_per_prompt: int = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_embeds_mask: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds_mask: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 512 ) → ~pipelines.qwenimage.QwenImagePipelineOutput or tuple

Parameters

Returns

~pipelines.qwenimage.QwenImagePipelineOutput or tuple

~pipelines.qwenimage.QwenImagePipelineOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is a list with the generated images.

Function invoked when calling the pipeline for generation.

Examples:

```py
import torch
from diffusers import QwenImageInpaintPipeline
from diffusers.utils import load_image

pipe = QwenImageInpaintPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
source = load_image(img_url)
mask = load_image(mask_url)
image = pipe(prompt=prompt, negative_prompt=" ", image=source, mask_image=mask, strength=0.85).images[0]
image.save("qwenimage_inpainting.png")
```

disable_vae_slicing

Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to computing decoding in one step.

disable_vae_tiling

Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to computing decoding in one step.

enable_vae_slicing

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.

enable_vae_tiling

Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images.

encode_prompt


( prompt: typing.Union[str, typing.List[str]] device: typing.Optional[torch.device] = None num_images_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None prompt_embeds_mask: typing.Optional[torch.Tensor] = None max_sequence_length: int = 1024 )

Parameters

## QwenImageEditPipeline

class diffusers.QwenImageEditPipeline


( scheduler: FlowMatchEulerDiscreteScheduler vae: AutoencoderKLQwenImage text_encoder: Qwen2_5_VLForConditionalGeneration tokenizer: Qwen2Tokenizer processor: Qwen2VLProcessor transformer: QwenImageTransformer2DModel )

Parameters

The Qwen-Image-Edit pipeline for image editing.

`__call__`


( image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor], NoneType] = None prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str]] = None true_cfg_scale: float = 4.0 height: typing.Optional[int] = None width: typing.Optional[int] = None num_inference_steps: int = 50 sigmas: typing.Optional[typing.List[float]] = None guidance_scale: typing.Optional[float] = None num_images_per_prompt: int = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_embeds_mask: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds_mask: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 512 ) → ~pipelines.qwenimage.QwenImagePipelineOutput or tuple

Parameters

Returns

~pipelines.qwenimage.QwenImagePipelineOutput or tuple

~pipelines.qwenimage.QwenImagePipelineOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is a list with the generated images.

Function invoked when calling the pipeline for generation.

Examples:

```py
import torch
from diffusers import QwenImageEditPipeline
from diffusers.utils import load_image

pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
pipe.to("cuda")
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
).convert("RGB")
prompt = "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"

image = pipe(image, prompt, num_inference_steps=50).images[0]
image.save("qwenimage_edit.png")
```

disable_vae_slicing

Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to computing decoding in one step.

disable_vae_tiling

Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to computing decoding in one step.

enable_vae_slicing

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.

enable_vae_tiling

Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images.

encode_prompt


( prompt: typing.Union[str, typing.List[str]] image: typing.Optional[torch.Tensor] = None device: typing.Optional[torch.device] = None num_images_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None prompt_embeds_mask: typing.Optional[torch.Tensor] = None max_sequence_length: int = 1024 )

Parameters

## QwenImageEditInpaintPipeline

class diffusers.QwenImageEditInpaintPipeline


( scheduler: FlowMatchEulerDiscreteScheduler vae: AutoencoderKLQwenImage text_encoder: Qwen2_5_VLForConditionalGeneration tokenizer: Qwen2Tokenizer processor: Qwen2VLProcessor transformer: QwenImageTransformer2DModel )

Parameters

The Qwen-Image-Edit pipeline for image editing.

`__call__`


( image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor], NoneType] = None prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str]] = None mask_image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] = None masked_image_latents: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] = None true_cfg_scale: float = 4.0 height: typing.Optional[int] = None width: typing.Optional[int] = None padding_mask_crop: typing.Optional[int] = None strength: float = 0.6 num_inference_steps: int = 50 sigmas: typing.Optional[typing.List[float]] = None guidance_scale: typing.Optional[float] = None num_images_per_prompt: int = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_embeds_mask: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds_mask: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 512 ) → ~pipelines.qwenimage.QwenImagePipelineOutput or tuple

Parameters

Returns

~pipelines.qwenimage.QwenImagePipelineOutput or tuple

~pipelines.qwenimage.QwenImagePipelineOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is a list with the generated images.

Function invoked when calling the pipeline for generation.

Examples:

```py
import torch
from diffusers import QwenImageEditInpaintPipeline
from diffusers.utils import load_image

pipe = QwenImageEditInpaintPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"

img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
source = load_image(img_url)
mask = load_image(mask_url)
image = pipe(
    prompt=prompt, negative_prompt=" ", image=source, mask_image=mask, strength=1.0, num_inference_steps=50
).images[0]
image.save("qwenimage_inpainting.png")
```

disable_vae_slicing

Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to computing decoding in one step.

disable_vae_tiling

Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to computing decoding in one step.

enable_vae_slicing

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.

enable_vae_tiling

Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images.

encode_prompt


( prompt: typing.Union[str, typing.List[str]] image: typing.Optional[torch.Tensor] = None device: typing.Optional[torch.device] = None num_images_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None prompt_embeds_mask: typing.Optional[torch.Tensor] = None max_sequence_length: int = 1024 )

Parameters

## QwenImageControlNetPipeline

class diffusers.QwenImageControlNetPipeline


( scheduler: FlowMatchEulerDiscreteScheduler vae: AutoencoderKLQwenImage text_encoder: Qwen2_5_VLForConditionalGeneration tokenizer: Qwen2Tokenizer transformer: QwenImageTransformer2DModel controlnet: typing.Union[diffusers.models.controlnets.controlnet_qwenimage.QwenImageControlNetModel, diffusers.models.controlnets.controlnet_qwenimage.QwenImageMultiControlNetModel] )

Parameters

The QwenImage pipeline for ControlNet-guided text-to-image generation.

`__call__`


( prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str]] = None true_cfg_scale: float = 4.0 height: typing.Optional[int] = None width: typing.Optional[int] = None num_inference_steps: int = 50 sigmas: typing.Optional[typing.List[float]] = None guidance_scale: typing.Optional[float] = None control_guidance_start: typing.Union[float, typing.List[float]] = 0.0 control_guidance_end: typing.Union[float, typing.List[float]] = 1.0 control_image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] = None controlnet_conditioning_scale: typing.Union[float, typing.List[float]] = 1.0 num_images_per_prompt: int = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_embeds_mask: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds_mask: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 512 ) → ~pipelines.qwenimage.QwenImagePipelineOutput or tuple

Parameters

Returns

~pipelines.qwenimage.QwenImagePipelineOutput or tuple

~pipelines.qwenimage.QwenImagePipelineOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is a list with the generated images.

Function invoked when calling the pipeline for generation.

Examples:

```py
import torch
from diffusers.utils import load_image
from diffusers import QwenImageControlNetModel, QwenImageMultiControlNetModel, QwenImageControlNetPipeline

# Single ControlNet (union model).
controlnet = QwenImageControlNetModel.from_pretrained(
    "InstantX/Qwen-Image-ControlNet-Union", torch_dtype=torch.bfloat16
)
pipe = QwenImageControlNetPipeline.from_pretrained(
    "Qwen/Qwen-Image", controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
prompt = "Aesthetics art, traditional asian pagoda, elaborate golden accents, sky blue and white color palette, swirling cloud pattern, digital illustration, east asian architecture, ornamental rooftop, intricate detailing on building, cultural representation."
negative_prompt = " "
control_image = load_image(
    "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Union/resolve/main/conds/canny.png"
)

image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    control_image=control_image,
    controlnet_conditioning_scale=1.0,
    num_inference_steps=30,
    true_cfg_scale=4.0,
).images[0]
image.save("qwenimage_cn_union.png")

# Multiple ControlNets (here the same union model applied twice).
controlnet = QwenImageControlNetModel.from_pretrained(
    "InstantX/Qwen-Image-ControlNet-Union", torch_dtype=torch.bfloat16
)
controlnet = QwenImageMultiControlNetModel([controlnet])
pipe = QwenImageControlNetPipeline.from_pretrained(
    "Qwen/Qwen-Image", controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    control_image=[control_image, control_image],
    controlnet_conditioning_scale=[0.5, 0.5],
    num_inference_steps=30,
    true_cfg_scale=4.0,
).images[0]
image.save("qwenimage_cn_union_multi.png")
```

disable_vae_slicing

Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to computing decoding in one step.

disable_vae_tiling

Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to computing decoding in one step.

enable_vae_slicing

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.

enable_vae_tiling

Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images.

encode_prompt


( prompt: typing.Union[str, typing.List[str]] device: typing.Optional[torch.device] = None num_images_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None prompt_embeds_mask: typing.Optional[torch.Tensor] = None max_sequence_length: int = 1024 )

Parameters

## QwenImageEditPlusPipeline

class diffusers.QwenImageEditPlusPipeline


( scheduler: FlowMatchEulerDiscreteScheduler vae: AutoencoderKLQwenImage text_encoder: Qwen2_5_VLForConditionalGeneration tokenizer: Qwen2Tokenizer processor: Qwen2VLProcessor transformer: QwenImageTransformer2DModel )

Parameters

The Qwen-Image-Edit Plus pipeline for image editing with one or more reference images.

`__call__`


( image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor], NoneType] = None prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str]] = None true_cfg_scale: float = 4.0 height: typing.Optional[int] = None width: typing.Optional[int] = None num_inference_steps: int = 50 sigmas: typing.Optional[typing.List[float]] = None guidance_scale: typing.Optional[float] = None num_images_per_prompt: int = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_embeds_mask: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds_mask: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 512 ) → ~pipelines.qwenimage.QwenImagePipelineOutput or tuple

Parameters

Returns

~pipelines.qwenimage.QwenImagePipelineOutput or tuple

~pipelines.qwenimage.QwenImagePipelineOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is a list with the generated images.

Function invoked when calling the pipeline for generation.

Examples:

```py
import torch
from diffusers import QwenImageEditPlusPipeline
from diffusers.utils import load_image

pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16)
pipe.to("cuda")
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
).convert("RGB")
prompt = "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"

image = pipe(image, prompt, num_inference_steps=50).images[0]
image.save("qwenimage_edit_plus.png")
```

encode_prompt


( prompt: typing.Union[str, typing.List[str]] image: typing.Optional[torch.Tensor] = None device: typing.Optional[torch.device] = None num_images_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None prompt_embeds_mask: typing.Optional[torch.Tensor] = None max_sequence_length: int = 1024 )

Parameters

## QwenImagePipelineOutput

class diffusers.pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput


( images: typing.Union[typing.List[PIL.Image.Image], numpy.ndarray] )

Parameters

Output class for QwenImage pipelines.
