ControlNetUnion

class diffusers.StableDiffusionXLControlNetUnionImg2ImgPipeline


( vae: AutoencoderKL text_encoder: CLIPTextModel text_encoder_2: CLIPTextModelWithProjection tokenizer: CLIPTokenizer tokenizer_2: CLIPTokenizer unet: UNet2DConditionModel controlnet: ControlNetUnionModel scheduler: KarrasDiffusionSchedulers requires_aesthetics_score: bool = False force_zeros_for_empty_prompt: bool = True add_watermarker: typing.Optional[bool] = None feature_extractor: CLIPImageProcessor = None image_encoder: CLIPVisionModelWithProjection = None )


Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
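
A minimal sketch of those inherited methods (the checkpoint identifiers are the ones used in the example further down this page; the local save path is illustrative):

import torch
from diffusers import StableDiffusionXLControlNetUnionImg2ImgPipeline, ControlNetUnionModel

# from_pretrained, save_pretrained, and to() are inherited from DiffusionPipeline.
controlnet = ControlNetUnionModel.from_pretrained(
    "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetUnionImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.to("cuda")  # run on a particular device
pipe.save_pretrained("./sdxl-controlnet-union-img2img")  # save all components locally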

__call__


( prompt: typing.Union[str, typing.List[str]] = None prompt_2: typing.Union[str, typing.List[str], NoneType] = None image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] = None control_image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] = None height: typing.Optional[int] = None width: typing.Optional[int] = None strength: float = 0.8 num_inference_steps: int = 50 guidance_scale: float = 5.0 negative_prompt: typing.Union[str, typing.List[str], NoneType] = None negative_prompt_2: typing.Union[str, typing.List[str], NoneType] = None num_images_per_prompt: typing.Optional[int] = 1 eta: float = 0.0 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None pooled_prompt_embeds: typing.Optional[torch.Tensor] = None negative_pooled_prompt_embeds: typing.Optional[torch.Tensor] = None ip_adapter_image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor], NoneType] = None ip_adapter_image_embeds: typing.Optional[typing.List[torch.Tensor]] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True cross_attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None controlnet_conditioning_scale: typing.Union[float, typing.List[float]] = 0.8 guess_mode: bool = False control_guidance_start: typing.Union[float, typing.List[float]] = 0.0 control_guidance_end: typing.Union[float, typing.List[float]] = 1.0 control_mode: typing.Union[int, typing.List[int], NoneType] = None original_size: typing.Tuple[int, int] = None crops_coords_top_left: typing.Tuple[int, int] = (0, 0) target_size: typing.Tuple[int, int] = None negative_original_size: typing.Optional[typing.Tuple[int, int]] = None negative_crops_coords_top_left: typing.Tuple[int, int] = (0, 0) negative_target_size: typing.Optional[typing.Tuple[int, int]] = None aesthetic_score: float = 6.0 negative_aesthetic_score: float = 2.5 clip_skip: typing.Optional[int] = None callback_on_step_end: typing.Union[typing.Callable[[int, int, typing.Dict], NoneType], diffusers.callbacks.PipelineCallback, diffusers.callbacks.MultiPipelineCallbacks, NoneType] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] **kwargs ) → StableDiffusionPipelineOutput or tuple

Returns

StableDiffusionPipelineOutput if return_dict is True, otherwise a tuple containing the output images.
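
As a minimal sketch (pipe, image, and the control inputs are assumed to be set up as in the example below), the two return forms can be consumed like this:

# return_dict=True (the default) yields a StableDiffusionPipelineOutput with an .images list
result = pipe(prompt="A cat", image=image, control_image=[image], control_mode=[6])
first_image = result.images[0]

# return_dict=False yields a plain tuple whose first element is the list of images
result_tuple = pipe(
    prompt="A cat", image=image, control_image=[image], control_mode=[6], return_dict=False
)
first_image = result_tuple[0][0]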

Function invoked when calling the pipeline for generation.

Examples:

from diffusers import (
    StableDiffusionXLControlNetUnionImg2ImgPipeline,
    ControlNetUnionModel,
    AutoencoderKL,
)
from diffusers.utils import load_image
import torch
from PIL import Image
import numpy as np

prompt = "A cat"

image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)

controlnet = ControlNetUnionModel.from_pretrained(
    "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetUnionImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

height = image.height
width = image.width
ratio = np.sqrt(1024.0 * 1024.0 / (width * height))

# Resize so the total area is roughly 1024x1024 and both sides divide evenly
# into the 3x3 tile grid used below.
scale_image_factor = 3
base_factor = 16
factor = scale_image_factor * base_factor
W, H = int(width * ratio) // factor * factor, int(height * ratio) // factor * factor
image = image.resize((W, H))
target_width = W // scale_image_factor
target_height = H // scale_image_factor
images = []
crops_coords_list = [
    (0, 0),
    (0, width // 2),
    (height // 2, 0),
    (width // 2, height // 2),
    0,
    0,
    0,
    0,
    0,
]
# Split the resized image into a 3x3 grid of tiles and upscale each tile
# back to (W, H) before refining it with the pipeline.
for i in range(scale_image_factor):
    for j in range(scale_image_factor):
        left = j * target_width
        top = i * target_height
        right = left + target_width
        bottom = top + target_height
        cropped_image = image.crop((left, top, right, bottom))
        cropped_image = cropped_image.resize((W, H))
        images.append(cropped_image)

# Refine each upscaled tile with the img2img pipeline, then stitch the
# 3x3 grid of results back into one large image.
result_images = []
for sub_img, crops_coords in zip(images, crops_coords_list):
    new_width, new_height = W, H
    out = pipe(
        prompt=[prompt] * 1,
        image=sub_img,
        control_image=[sub_img],
        control_mode=[6],
        width=new_width,
        height=new_height,
        num_inference_steps=30,
        crops_coords_top_left=(W, H),
        target_size=(W, H),
        original_size=(W * 2, H * 2),
    )
    result_images.append(out.images[0])

new_im = Image.new("RGB", (new_width * scale_image_factor, new_height * scale_image_factor))
new_im.paste(result_images[0], (0, 0))
new_im.paste(result_images[1], (new_width, 0))
new_im.paste(result_images[2], (new_width * 2, 0))
new_im.paste(result_images[3], (0, new_height))
new_im.paste(result_images[4], (new_width, new_height))
new_im.paste(result_images[5], (new_width * 2, new_height))
new_im.paste(result_images[6], (0, new_height * 2))
new_im.paste(result_images[7], (new_width, new_height * 2))
new_im.paste(result_images[8], (new_width * 2, new_height * 2))

encode_prompt


( prompt: str prompt_2: typing.Optional[str] = None device: typing.Optional[torch.device] = None num_images_per_prompt: int = 1 do_classifier_free_guidance: bool = True negative_prompt: typing.Optional[str] = None negative_prompt_2: typing.Optional[str] = None prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None pooled_prompt_embeds: typing.Optional[torch.Tensor] = None negative_pooled_prompt_embeds: typing.Optional[torch.Tensor] = None lora_scale: typing.Optional[float] = None clip_skip: typing.Optional[int] = None )


Encodes the prompt into text encoder hidden states.
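
The returned embeddings can be passed straight back into __call__. A minimal sketch, assuming pipe, image, and the control inputs are prepared as in the example above (the argument names follow the signatures on this page):

# encode_prompt returns (prompt_embeds, negative_prompt_embeds,
# pooled_prompt_embeds, negative_pooled_prompt_embeds) for SDXL pipelines.
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
    prompt="A cat",
    negative_prompt="blurry, low quality",
    device=pipe.device,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)

# Reuse the precomputed embeddings instead of passing raw prompt strings.
out = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    image=image,
    control_image=[image],
    control_mode=[6],
)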