Inpainting

The Stable Diffusion model can also be applied to inpainting, which lets you edit specific parts of an image by providing a mask and a text prompt.
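
As a rough illustration of what the mask means: white regions are repainted from the prompt and black regions are kept. The sketch below assumes that the pipeline converts the mask to grayscale, scales it to [0, 1], and binarizes it at 0.5. `binarize_mask` is a hypothetical helper for illustration, not a Diffusers API.

```python
import numpy as np


def binarize_mask(mask_uint8: np.ndarray) -> np.ndarray:
    """Hypothetical helper: map an 8-bit grayscale mask to {0.0, 1.0}.

    1.0 (white) marks pixels to repaint, 0.0 (black) marks pixels to keep.
    """
    mask = mask_uint8.astype(np.float32) / 255.0  # scale to [0, 1]
    return (mask >= 0.5).astype(np.float32)  # threshold at 0.5 (assumed)


# A 4x4 mask whose right half is mostly white: those pixels get repainted.
mask = np.array(
    [
        [0, 0, 200, 255],
        [0, 10, 255, 255],
        [0, 0, 128, 255],
        [0, 0, 127, 255],
    ],
    dtype=np.uint8,
)
binary = binarize_mask(mask)
print(binary)
print("pixels to repaint:", int(binary.sum()))  # pixels to repaint: 7
```

Gray values just below the threshold (127 here) count as "keep", so soft or anti-aliased mask edges become hard edges.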

Tips

It is recommended to use this pipeline with checkpoints that have been specifically fine-tuned for inpainting, such as runwayml/stable-diffusion-inpainting. Default text-to-image Stable Diffusion checkpoints, such as stable-diffusion-v1-5/stable-diffusion-v1-5, are also compatible, but they might produce lower-quality results.
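
The two kinds of checkpoint use the mask differently. This NumPy sketch assumes the behavior of the Diffusers implementation as I understand it. An inpainting-specific UNet takes 9 input channels: 4 noisy latent channels, 1 downsampled mask channel, and 4 channels of masked-image latents. A plain 4-channel text-to-image UNet is instead steered by pasting the original image latents, re-noised to the current noise level, back outside the mask after each step. The function names and random arrays are hypothetical stand-ins.

```python
import numpy as np


def unet_input_for_inpainting_checkpoint(latents, mask, masked_image_latents):
    """Hypothetical: build the 9-channel input of an inpainting UNet.

    latents:              (B, 4, h, w) noisy latents
    mask:                 (B, 1, h, w) mask at latent resolution, 1 = repaint
    masked_image_latents: (B, 4, h, w) VAE latents of the image with the hole blanked out
    """
    return np.concatenate([latents, mask, masked_image_latents], axis=1)


def blend_for_text_to_image_checkpoint(latents, mask, noised_original_latents):
    """Hypothetical: after a step with a 4-channel UNet, keep the re-noised
    original latents where mask == 0 and the new latents where mask == 1."""
    return (1 - mask) * noised_original_latents + mask * latents


rng = np.random.default_rng(0)
latents = rng.standard_normal((1, 4, 8, 8))
noised_original = rng.standard_normal((1, 4, 8, 8))        # stand-in
masked_image_latents = rng.standard_normal((1, 4, 8, 8))   # stand-in
mask = np.zeros((1, 1, 8, 8))
mask[..., 4:] = 1.0  # repaint the right half (mask broadcasts over channels)

x9 = unet_input_for_inpainting_checkpoint(latents, mask, masked_image_latents)
blended = blend_for_text_to_image_checkpoint(latents, mask, noised_original)
print(x9.shape)  # (1, 9, 8, 8)
```

The fine-tuned checkpoint sees the mask and the surrounding image directly at every step, which is one plausible reason it tends to blend edits more cleanly than the paste-back approach.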

Make sure to check out the Stable Diffusion Tips section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!

If you’re interested in using one of the official checkpoints for a task, explore the CompVis, Runway, and Stability AI Hub organizations!

StableDiffusionInpaintPipeline

class diffusers.StableDiffusionInpaintPipeline

( vae: typing.Union[diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL, diffusers.models.autoencoders.autoencoder_asym_kl.AsymmetricAutoencoderKL] text_encoder: CLIPTextModel tokenizer: CLIPTokenizer unet: UNet2DConditionModel scheduler: KarrasDiffusionSchedulers safety_checker: StableDiffusionSafetyChecker feature_extractor: CLIPImageProcessor image_encoder: CLIPVisionModelWithProjection = None requires_safety_checker: bool = True )

Pipeline for text-guided image inpainting using Stable Diffusion.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).

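The pipeline also inherits the following loading methods: load_textual_inversion() for loading textual inversion embeddings, load_lora_weights() and save_lora_weights() for loading and saving LoRA weights, from_single_file() for loading .ckpt files, and load_ip_adapter() for loading IP-Adapters.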

__call__

( prompt: typing.Union[str, typing.List[str]] = None image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] = None mask_image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] = None masked_image_latents: Tensor = None height: typing.Optional[int] = None width: typing.Optional[int] = None padding_mask_crop: typing.Optional[int] = None strength: float = 1.0 num_inference_steps: int = 50 timesteps: typing.List[int] = None sigmas: typing.List[float] = None guidance_scale: float = 7.5 negative_prompt: typing.Union[str, typing.List[str], NoneType] = None num_images_per_prompt: typing.Optional[int] = 1 eta: float = 0.0 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None ip_adapter_image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor], NoneType] = None ip_adapter_image_embeds: typing.Optional[typing.List[torch.Tensor]] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True cross_attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None clip_skip: int = None callback_on_step_end: typing.Union[typing.Callable[[int, int, typing.Dict], NoneType], diffusers.callbacks.PipelineCallback, diffusers.callbacks.MultiPipelineCallbacks, NoneType] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] **kwargs ) → StableDiffusionPipelineOutput or tuple

Returns

If return_dict is True, StableDiffusionPipelineOutput is returned, otherwise a tuple is returned where the first element is a list with the generated images and the second element is a list of bools indicating whether the corresponding generated image contains “not-safe-for-work” (nsfw) content.

The call function to the pipeline for generation.
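
The strength argument (default 1.0) controls how much of the denoising schedule is run on the noised input image. A value of 1.0 runs every step and essentially ignores the original content under the mask, while lower values keep more of it. A minimal sketch of the step arithmetic follows. It assumes the pipeline keeps only the last int(num_inference_steps * strength) timesteps of a first-order scheduler. `steps_actually_run` is a hypothetical helper.

```python
def steps_actually_run(num_inference_steps: int, strength: float) -> int:
    """Hypothetical helper: number of denoising steps left after applying `strength`."""
    if not 0.0 <= strength <= 1.0:
        raise ValueError("strength must be in [0, 1]")
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    # The schedule is truncated to timesteps[t_start:], skipping the noisiest steps.
    return num_inference_steps - t_start


for strength in (1.0, 0.75, 0.5):
    print(strength, steps_actually_run(50, strength))
```

The truncation uses int(), so the product is rounded down: with 50 steps and strength 0.75, 37 steps run, not 38.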

Examples:

import PIL.Image
import requests
import torch
from io import BytesIO

from diffusers import StableDiffusionInpaintPipeline


def download_image(url):
    response = requests.get(url)
    response.raise_for_status()
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")


img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

# White pixels in the mask are repainted; black pixels are kept.
init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-inpainting", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]

enable_attention_slicing

( slice_size: typing.Union[int, str, NoneType] = 'auto' )

Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor into slices to compute attention in several steps. For more than one attention head, the computation is performed sequentially over each head. This saves some memory in exchange for a small decrease in speed.
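
A NumPy sketch of the idea behind sliced attention follows. It assumes the inputs are already reshaped so that the leading axis is (batch × heads), and it illustrates the technique rather than reproducing the Diffusers processor code. Only one slice's (slice_size, tokens, tokens) score matrix exists at a time, and the result matches unsliced attention.

```python
import numpy as np


def attention(q, k, v):
    """Scaled dot-product attention for inputs shaped (heads, tokens, dim)."""
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


def sliced_attention(q, k, v, slice_size):
    """Compute attention `slice_size` heads at a time to bound peak memory."""
    out = np.empty_like(q)  # value dim equals query dim in this sketch
    for start in range(0, q.shape[0], slice_size):
        end = start + slice_size
        out[start:end] = attention(q[start:end], k[start:end], v[start:end])
    return out


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 16, 32)) for _ in range(3))
full = attention(q, k, v)
sliced = sliced_attention(q, k, v, slice_size=2)
print(np.allclose(full, sliced))  # True
```

The loop trades parallelism for memory, which is why slicing is slower and why it gives no benefit on top of kernels that are already memory efficient.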

⚠️ Don’t enable attention slicing if you’re already using scaled_dot_product_attention (SDPA) from PyTorch 2.0 or xFormers. These attention implementations are already very memory efficient, so you don’t need to enable this function. Enabling attention slicing together with SDPA or xFormers can lead to serious slowdowns!

Examples:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
)
pipe = pipe.to("cuda")  # float16 weights are meant for GPU inference

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_attention_slicing()
image = pipe(prompt).images[0]

disable_attention_slicing

( )

Disable sliced attention computation. If enable_attention_slicing was previously called, attention is computed in one step.

enable_xformers_memory_efficient_attention

( attention_op: typing.Optional[typing.Callable] = None )

Enable memory efficient attention from xFormers. When this option is enabled, you should observe lower GPU memory usage and a potential speed-up during inference. A speed-up during training is not guaranteed.

⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes precedence.
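
xFormers' kernels are not reproduced here. The NumPy sketch below shows the general chunked, online-softmax idea that memory-efficient attention kernels are built on, assuming single-head 2-D inputs. Keys and values are processed in chunks while a running maximum, denominator, and weighted sum are kept, so the full (queries × keys) score matrix is never materialized.

```python
import numpy as np


def reference_attention(q, k, v):
    """Plain softmax(q k^T / sqrt(d)) v for 2-D inputs."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ v


def chunked_attention(q, k, v, chunk=4):
    """Attention over key/value chunks with a running (online) softmax."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    m = np.full((q.shape[0], 1), -np.inf)      # running row-wise max of scores
    denom = np.zeros((q.shape[0], 1))          # running softmax denominator
    acc = np.zeros((q.shape[0], v.shape[-1]))  # running weighted sum of values
    for start in range(0, k.shape[0], chunk):
        s = q @ k[start:start + chunk].T * scale
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        correction = np.exp(m - m_new)         # rescale earlier partial sums
        p = np.exp(s - m_new)
        denom = denom * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ v[start:start + chunk]
        m = m_new
    return acc / denom


rng = np.random.default_rng(0)
q = rng.standard_normal((10, 16))
k = rng.standard_normal((13, 16))
v = rng.standard_normal((13, 8))
print(np.allclose(chunked_attention(q, k, v), reference_attention(q, k, v)))  # True
```

Unlike slicing over heads, this reduces memory inside a single head, which is one reason the two options are not meant to be combined.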

Examples:

import torch
from diffusers import DiffusionPipeline
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)

# Workaround: the Flash Attention op does not accept the VAE's attention shape,
# so let xFormers choose the default op for the VAE.
pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)

disable_xformers_memory_efficient_attention

( )

Disable memory efficient attention from xFormers.

load_textual_inversion

( pretrained_model_name_or_path: typing.Union[str, typing.List[str], typing.Dict[str, torch.Tensor], typing.List[typing.Dict[str, torch.Tensor]]] token: typing.Union[typing.List[str], str, NoneType] = None tokenizer: typing.Optional[ForwardRef('PreTrainedTokenizer')] = None text_encoder: typing.Optional[ForwardRef('PreTrainedModel')] = None **kwargs )

Load Textual Inversion embeddings into the text encoder of StableDiffusionPipeline (both 🤗 Diffusers and Automatic1111 formats are supported).
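
Conceptually, loading a Textual Inversion embedding registers a new token in the tokenizer and adds a row for it to the text encoder's token-embedding table. The other rows are left untouched. The toy sketch below mimics that with a dict and a NumPy matrix. The vocabulary, the vector, and `add_textual_inversion_token` are all hypothetical stand-ins, not the Diffusers implementation.

```python
import numpy as np

# Toy stand-ins (hypothetical): a tiny vocabulary and its token-embedding table.
vocab = {"a": 0, "photo": 1, "of": 2}
embeddings = np.random.default_rng(0).standard_normal((len(vocab), 4))


def add_textual_inversion_token(vocab, embeddings, token, vector):
    """Return a new vocab/table pair where `token` maps to `vector`."""
    if token in vocab:
        raise ValueError(f"{token!r} is already in the vocabulary")
    new_vocab = {**vocab, token: len(vocab)}
    new_embeddings = np.vstack([embeddings, vector[None, :]])
    return new_vocab, new_embeddings


learned = np.ones(4)  # stand-in for a vector trained with Textual Inversion
vocab2, emb2 = add_textual_inversion_token(vocab, embeddings, "<cat-toy>", learned)
ids = [vocab2[w] for w in ["a", "photo", "of", "<cat-toy>"]]
print(ids)  # [0, 1, 2, 3]
```

Because only a new row is added, the token has an effect only when it appears in the prompt.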

Example:

To load a Textual Inversion embedding vector in 🤗 Diffusers format:

from diffusers import StableDiffusionPipeline
import torch

model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

pipe.load_textual_inversion("sd-concepts-library/cat-toy")

# The learned concept is triggered by its token, <cat-toy>.
prompt = "A <cat-toy> backpack"

image = pipe(prompt, num_inference_steps=50).images[0]
image.save("cat-backpack.png")

To load a Textual Inversion embedding vector in Automatic1111 format, make sure to download the vector first (for example, from Civitai) and then load it locally:

from diffusers import StableDiffusionPipeline
import torch

model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")

prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."

image = pipe(prompt, num_inference_steps=50).images[0]
image.save("character.png")

load_lora_weights

( pretrained_model_name_or_path_or_dict: typing.Union[str, typing.Dict[str, torch.Tensor]] adapter_name = None hotswap: bool = False **kwargs )

Load LoRA weights specified in pretrained_model_name_or_path_or_dict into self.unet and self.text_encoder.

All kwargs are forwarded to self.lora_state_dict.

See lora_state_dict() for more details on how the state dict is loaded.

See load_lora_into_unet() for more details on how the state dict is loaded into self.unet.

See load_lora_into_text_encoder() for more details on how the state dict is loaded into self.text_encoder.
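
LoRA weights are low-rank updates to existing layers. For a linear layer with frozen weight W, a LoRA adds scale · (up @ down), where down and up share a small inner rank. The NumPy sketch below shows that merging the update into W gives the same output as running the low-rank branch alongside the base layer. The names and shapes are assumptions for illustration, and `scale` stands in for a LoRA strength multiplier.

```python
import numpy as np


def apply_lora(weight, lora_down, lora_up, scale=1.0):
    """Return W + scale * (up @ down), the low-rank update LoRA adds to a linear layer.

    weight:    (out_features, in_features) frozen base weight
    lora_down: (rank, in_features)
    lora_up:   (out_features, rank)
    """
    return weight + scale * (lora_up @ lora_down)


rng = np.random.default_rng(0)
out_f, in_f, rank = 6, 8, 2
W = rng.standard_normal((out_f, in_f))
down = rng.standard_normal((rank, in_f))
up = rng.standard_normal((out_f, rank))
x = rng.standard_normal(in_f)

W_merged = apply_lora(W, down, up, scale=0.5)
# Same result without merging: base output plus the scaled low-rank branch.
y_unmerged = W @ x + 0.5 * (up @ (down @ x))
print(np.allclose(W_merged @ x, y_unmerged))  # True
```

Only the small down and up matrices need to be stored, which is why LoRA files are much smaller than full checkpoints.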

save_lora_weights

( save_directory: typing.Union[str, os.PathLike] unet_lora_layers: typing.Dict[str, typing.Union[torch.nn.modules.module.Module, torch.Tensor]] = None text_encoder_lora_layers: typing.Dict[str, torch.nn.modules.module.Module] = None is_main_process: bool = True weight_name: str = None save_function: typing.Callable = None safe_serialization: bool = True )

Save the LoRA parameters corresponding to the UNet and text encoder.

encode_prompt

( prompt device num_images_per_prompt do_classifier_free_guidance negative_prompt = None prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None lora_scale: typing.Optional[float] = None clip_skip: typing.Optional[int] = None )

Encodes the prompt into text encoder hidden states.
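
When do_classifier_free_guidance is set, the prompt embeddings are paired with negative (unconditional) embeddings. Both are run through the UNet, and the two noise predictions are combined using guidance_scale. The NumPy sketch below shows that combination. It assumes one batched pass over [negative, positive] followed by a split, and it assumes SD v1-style (77, 768) embedding shapes. `fake_unet` is a hypothetical stand-in for the real UNet.

```python
import numpy as np


def classifier_free_guidance(noise_uncond, noise_text, guidance_scale):
    """Push the prediction away from the unconditional one, toward the prompt."""
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)


def fake_unet(latents, embeds):
    """Hypothetical stand-in for the UNet: any function of (latents, embeddings)."""
    return latents * embeds.mean(axis=(1, 2))[:, None, None, None]


rng = np.random.default_rng(0)
latents = rng.standard_normal((1, 4, 8, 8))
prompt_embeds = rng.standard_normal((1, 77, 768))           # conditional embeddings
negative_prompt_embeds = rng.standard_normal((1, 77, 768))  # from negative_prompt or ""

# One batched forward pass over [negative, positive], then split the result.
batch_latents = np.concatenate([latents, latents])
batch_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
noise_uncond, noise_text = np.split(fake_unet(batch_latents, batch_embeds), 2)
guided = classifier_free_guidance(noise_uncond, noise_text, guidance_scale=7.5)
```

With guidance_scale = 1 the formula reduces to the text-conditioned prediction alone, so guidance only has an effect above 1.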

get_guidance_scale_embedding

( w: Tensor embedding_dim: int = 512 dtype: dtype = torch.float32 ) → torch.Tensor

Returns

Embedding vectors with shape (len(w), embedding_dim).

See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
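
The linked VDM reference uses a sinusoidal embedding: sine features followed by cosine features over geometrically spaced frequencies. The pure-Python sketch below produces vectors of the documented shape (len(w), embedding_dim). The ×1000 rescaling of w and the zero-padding for odd dimensions are assumptions about the Diffusers version, labeled in the code.

```python
import math


def guidance_scale_embedding(w, embedding_dim=512):
    """Sinusoidal embedding of guidance scales `w` (a list of floats).

    Returns one list of length `embedding_dim` per entry of `w`:
    sin features followed by cos features, zero-padded if embedding_dim is odd.
    """
    half_dim = embedding_dim // 2
    log_step = math.log(10000.0) / (half_dim - 1)
    freqs = [math.exp(-log_step * i) for i in range(half_dim)]
    out = []
    for scale in w:
        angles = [scale * 1000.0 * f for f in freqs]  # assumed x1000 rescaling
        row = [math.sin(a) for a in angles] + [math.cos(a) for a in angles]
        if embedding_dim % 2 == 1:
            row.append(0.0)  # assumed zero-padding for odd dimensions
        out.append(row)
    return out


emb = guidance_scale_embedding([1.0, 7.5], embedding_dim=8)
print(len(emb), len(emb[0]))  # 2 8
```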

StableDiffusionPipelineOutput

class diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput

( images: typing.Union[typing.List[PIL.Image.Image], numpy.ndarray] nsfw_content_detected: typing.Optional[typing.List[bool]] )

Output class for Stable Diffusion pipelines.

FlaxStableDiffusionInpaintPipeline

class diffusers.FlaxStableDiffusionInpaintPipeline

( vae: FlaxAutoencoderKL text_encoder: FlaxCLIPTextModel tokenizer: CLIPTokenizer unet: FlaxUNet2DConditionModel scheduler: typing.Union[diffusers.schedulers.scheduling_ddim_flax.FlaxDDIMScheduler, diffusers.schedulers.scheduling_pndm_flax.FlaxPNDMScheduler, diffusers.schedulers.scheduling_lms_discrete_flax.FlaxLMSDiscreteScheduler, diffusers.schedulers.scheduling_dpmsolver_multistep_flax.FlaxDPMSolverMultistepScheduler] safety_checker: FlaxStableDiffusionSafetyChecker feature_extractor: CLIPImageProcessor dtype: dtype = <class 'jax.numpy.float32'> )

Flax-based pipeline for text-guided image inpainting using Stable Diffusion.

🧪 This is an experimental feature!

This model inherits from FlaxDiffusionPipeline. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).

__call__

( prompt_ids: Array mask: Array masked_image: Array params: typing.Union[typing.Dict, flax.core.frozen_dict.FrozenDict] prng_seed: Array num_inference_steps: int = 50 height: typing.Optional[int] = None width: typing.Optional[int] = None guidance_scale: typing.Union[float, jax.Array] = 7.5 latents: Array = None neg_prompt_ids: Array = None return_dict: bool = True jit: bool = False ) → FlaxStableDiffusionPipelineOutput or tuple

Returns

If return_dict is True, FlaxStableDiffusionPipelineOutput is returned, otherwise a tuple is returned where the first element is a list with the generated images and the second element is a list of bools indicating whether the corresponding generated image contains “not-safe-for-work” (nsfw) content.

Function invoked when calling the pipeline for generation.

Examples:

import jax
import numpy as np
import PIL.Image
import requests
from io import BytesIO
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionInpaintPipeline


def download_image(url):
    response = requests.get(url)
    response.raise_for_status()
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")


img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))

pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained(
    "xvjiarui/stable-diffusion-2-inpainting"
)

prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

# One sample per device.
num_samples = jax.device_count()
prompt = num_samples * [prompt]
init_image = num_samples * [init_image]
mask_image = num_samples * [mask_image]
prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(
    prompt, init_image, mask_image
)

# Replicate the parameters and shard the inputs and RNG keys across devices.
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)
processed_masked_images = shard(processed_masked_images)
processed_masks = shard(processed_masks)

images = pipeline(
    prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True
).images
# Collapse the (devices, per_device, H, W, C) output into one image per sample.
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
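
The shard and reshape steps in the Flax example only move a batch axis around. The NumPy sketch below mimics them. `shard_like` is a hypothetical stand-in that assumes flax's shard reshapes the leading axis to (num_devices, -1, ...), and num_devices = 4 stands in for jax.device_count().

```python
import numpy as np


def shard_like(x, num_devices):
    """Split the leading batch axis into (num_devices, per_device, ...)."""
    return x.reshape((num_devices, -1) + x.shape[1:])


num_devices = 4  # assumption: stands in for jax.device_count()
batch = np.zeros((num_devices, 64, 64, 3), dtype=np.float32)  # one image per device
sharded = shard_like(batch, num_devices)
print(sharded.shape)  # (4, 1, 64, 64, 3)

# Folding the per-device axis back gives one image per sample.
images = sharded.reshape((num_devices,) + sharded.shape[-3:])
print(images.shape)  # (4, 64, 64, 3)
```

Because the batch size must divide evenly across devices, the example sets num_samples to the device count.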

FlaxStableDiffusionPipelineOutput

class diffusers.pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput

( images: ndarray nsfw_content_detected: typing.List[bool] )

Output class for Flax-based Stable Diffusion pipelines.

replace

( **updates )

Returns a new object replacing the specified fields with new values.
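
Flax output objects are immutable, so fields are changed by building an updated copy. The stdlib sketch below shows the same copy-on-update behavior with a frozen dataclass. It assumes, based on my understanding of flax.struct.dataclass, that replace delegates to dataclasses.replace. `ToyPipelineOutput` is a hypothetical stand-in with the same two fields.

```python
import dataclasses
from typing import List

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyPipelineOutput:
    """Hypothetical stand-in with the fields of FlaxStableDiffusionPipelineOutput."""
    images: np.ndarray
    nsfw_content_detected: List[bool]


out = ToyPipelineOutput(images=np.zeros((2, 8, 8, 3)), nsfw_content_detected=[False, True])
# Frozen dataclasses cannot be mutated in place; replace() returns a modified copy.
cleaned = dataclasses.replace(out, nsfw_content_detected=[False, False])
print(out.nsfw_content_detected, cleaned.nsfw_content_detected)  # [False, True] [False, False]
print(cleaned.images is out.images)  # True: untouched fields are shared, not copied
```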
