DiffEdit

DiffEdit: Diffusion-based semantic image editing with mask guidance is by Guillaume Couairon, Jakob Verbeek, Holger Schwenk, and Matthieu Cord.

The abstract from the paper is:

Image generation has recently seen tremendous advances, with diffusion models allowing to synthesize convincing images for a large variety of text prompts. In this article, we propose DiffEdit, a method to take advantage of text-conditioned diffusion models for the task of semantic image editing, where the goal is to edit an image based on a text query. Semantic image editing is an extension of image generation, with the additional constraint that the generated image should be as similar as possible to a given input image. Current editing methods based on diffusion models usually require to provide a mask, making the task much easier by treating it as a conditional inpainting task. In contrast, our main contribution is able to automatically generate a mask highlighting regions of the input image that need to be edited, by contrasting predictions of a diffusion model conditioned on different text prompts. Moreover, we rely on latent inference to preserve content in those regions of interest and show excellent synergies with mask-based diffusion. DiffEdit achieves state-of-the-art editing performance on ImageNet. In addition, we evaluate semantic image editing in more challenging settings, using images from the COCO dataset as well as text-based generated images.
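The key idea, contrasting the model's noise predictions under the source and target prompts to localize the edit, can be sketched as follows. This is an illustrative simplification, not the pipeline's actual implementation; it assumes a prepared `unet` and `scheduler` (with `set_timesteps` already called), VAE-encoded `image_latents`, and precomputed prompt embeddings, and the function name is hypothetical:

```py
import torch


def diffedit_mask_sketch(unet, scheduler, image_latents, source_embeds, target_embeds, n_maps=10):
    """Illustrative sketch of DiffEdit mask generation (not the library code)."""
    t = scheduler.timesteps[len(scheduler.timesteps) // 2]  # mid-strength noising step
    diffs = []
    for _ in range(n_maps):
        noise = torch.randn_like(image_latents)
        noisy = scheduler.add_noise(image_latents, noise, t)
        # Noise estimates conditioned on each prompt.
        eps_source = unet(noisy, t, encoder_hidden_states=source_embeds).sample
        eps_target = unet(noisy, t, encoder_hidden_states=target_embeds).sample
        diffs.append((eps_target - eps_source).abs().mean(dim=1))  # average over latent channels
    diff = torch.stack(diffs).mean(dim=0)
    # Binarize: regions where the predictions disagree become the edit mask.
    return (diff / diff.max() > 0.5).float()
```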

The original codebase can be found at Xiang-cd/DiffEdit-stable-diffusion, and you can try it out in this demo.

This pipeline was contributed by clarencechen. ❤️

Tips

- The pipeline can generate masks that can be fed into other inpainting pipelines.
- To generate an edited image, both a mask (specified manually or produced with generate_mask()) and a set of partially inverted latents (produced with invert()) must be passed when calling the pipeline, as the examples below show.
- generate_mask() exposes two prompt arguments, source_prompt and target_prompt. source_prompt should describe the content of the input image and target_prompt should describe the desired edit; for a "cat" → "dog" edit, set source_prompt to "cat" and target_prompt to "dog".
- When generating partially inverted latents with invert(), pass a caption describing the overall image as the prompt to help guide the inverse sampling process; in most cases the source concept is descriptive enough.
- When calling the pipeline to generate the final edited image, assign the target concept to prompt; assigning the source concept to negative_prompt can further steer the edit.

StableDiffusionDiffEditPipeline

class diffusers.StableDiffusionDiffEditPipeline


( vae: AutoencoderKL text_encoder: CLIPTextModel tokenizer: CLIPTokenizer unet: UNet2DConditionModel scheduler: KarrasDiffusionSchedulers safety_checker: StableDiffusionSafetyChecker feature_extractor: CLIPImageProcessor inverse_scheduler: DDIMInverseScheduler requires_safety_checker: bool = True )


This is an experimental feature!

Pipeline for text-guided image inpainting using Stable Diffusion and DiffEdit.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).

The pipeline also inherits the following loading and saving methods:

- load_textual_inversion() for loading textual inversion embeddings
- load_lora_weights() for loading LoRA weights
- save_lora_weights() for saving LoRA weights
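As a quick illustration of the generic DiffusionPipeline methods (from_pretrained(), save_pretrained(), and to() are standard APIs; the local directory path here is hypothetical):

```py
import torch

from diffusers import StableDiffusionDiffEditPipeline

pipeline = StableDiffusionDiffEditPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
)
pipeline.save_pretrained("./diffedit-local")  # hypothetical local directory
pipeline = StableDiffusionDiffEditPipeline.from_pretrained("./diffedit-local").to("cuda")
```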

generate_mask


( image: Union = None target_prompt: Union = None target_negative_prompt: Union = None target_prompt_embeds: Optional = None target_negative_prompt_embeds: Optional = None source_prompt: Union = None source_negative_prompt: Union = None source_prompt_embeds: Optional = None source_negative_prompt_embeds: Optional = None num_maps_per_mask: Optional = 10 mask_encode_strength: Optional = 0.5 mask_thresholding_ratio: Optional = 3.0 num_inference_steps: int = 50 guidance_scale: float = 7.5 generator: Union = None output_type: Optional = 'np' cross_attention_kwargs: Optional = None ) → List[PIL.Image.Image] or np.array


Returns

List[PIL.Image.Image] or np.array

When returning a List[PIL.Image.Image], the list consists of a batch of single-channel binary images with dimensions (height // self.vae_scale_factor, width // self.vae_scale_factor). If it is an np.array, the shape is (batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor).

Generate a latent mask given a mask prompt, a target prompt, and an image.

```py
import torch
from io import BytesIO

import PIL
import requests

from diffusers import DDIMInverseScheduler, DDIMScheduler, StableDiffusionDiffEditPipeline


def download_image(url):
    response = requests.get(url)
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")


img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"
init_image = download_image(img_url).resize((768, 768))

pipeline = StableDiffusionDiffEditPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
pipeline.enable_model_cpu_offload()

mask_prompt = "A bowl of fruits"  # describes the source image content
prompt = "A bowl of pears"  # describes the desired edit

# Contrast noise predictions for the two prompts to get a latent-space edit mask.
mask_image = pipeline.generate_mask(image=init_image, source_prompt=mask_prompt, target_prompt=prompt)
# Partially invert the image so content outside the mask is preserved.
image_latents = pipeline.invert(image=init_image, prompt=mask_prompt).latents
# Denoise with the target prompt, editing only the masked region.
image = pipeline(prompt=prompt, mask_image=mask_image, image_latents=image_latents).images[0]
```
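The returned mask is at latent resolution (height // self.vae_scale_factor, width // self.vae_scale_factor), so to reuse it with a pixel-space inpainting pipeline it has to be upscaled first. A minimal sketch, assuming the default output_type="np", binary values in {0, 1}, and a 768x768 input with a vae_scale_factor of 8:

```py
import numpy as np
import PIL

# mask_image has shape (batch_size, 96, 96) for a 768x768 input under these assumptions.
mask_array = (mask_image[0] * 255).astype(np.uint8)
mask_pil = PIL.Image.fromarray(mask_array).resize((768, 768), resample=PIL.Image.NEAREST)
```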

invert


( prompt: Union = None image: Union = None num_inference_steps: int = 50 inpaint_strength: float = 0.8 guidance_scale: float = 7.5 negative_prompt: Union = None generator: Union = None prompt_embeds: Optional = None negative_prompt_embeds: Optional = None decode_latents: bool = False output_type: Optional = 'pil' return_dict: bool = True callback: Optional = None callback_steps: Optional = 1 cross_attention_kwargs: Optional = None lambda_auto_corr: float = 20.0 lambda_kl: float = 20.0 num_reg_steps: int = 0 num_auto_corr_rolls: int = 5 )


Generate inverted latents given a prompt and image.

```py
import torch
from io import BytesIO

import PIL
import requests

from diffusers import DDIMInverseScheduler, DDIMScheduler, StableDiffusionDiffEditPipeline


def download_image(url):
    response = requests.get(url)
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")


img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"
init_image = download_image(img_url).resize((768, 768))

pipeline = StableDiffusionDiffEditPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
pipeline.enable_model_cpu_offload()

prompt = "A bowl of fruits"  # a caption describing the overall input image

inverted_latents = pipeline.invert(image=init_image, prompt=prompt).latents
```
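invert() also exposes optional regularization of the inverted latents, which is off by default since num_reg_steps=0. A hedged sketch of enabling it; the values below simply echo the signature defaults and are illustrative, not tuned recommendations:

```py
inverted_latents = pipeline.invert(
    image=init_image,
    prompt=prompt,
    num_reg_steps=5,  # enable auto-correlation and KL regularization
    lambda_auto_corr=20.0,
    lambda_kl=20.0,
    num_auto_corr_rolls=5,
).latents
```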

__call__


( prompt: Union = None mask_image: Union = None image_latents: Union = None inpaint_strength: Optional = 0.8 num_inference_steps: int = 50 guidance_scale: float = 7.5 negative_prompt: Union = None num_images_per_prompt: Optional = 1 eta: float = 0.0 generator: Union = None latents: Optional = None prompt_embeds: Optional = None negative_prompt_embeds: Optional = None output_type: Optional = 'pil' return_dict: bool = True callback: Optional = None callback_steps: int = 1 cross_attention_kwargs: Optional = None clip_skip: int = None ) → StableDiffusionPipelineOutput or tuple

Returns

StableDiffusionPipelineOutput or tuple

If return_dict is True, StableDiffusionPipelineOutput is returned, otherwise a tuple is returned where the first element is a list with the generated images and the second element is a list of bools indicating whether the corresponding generated image contains “not-safe-for-work” (nsfw) content.

The call function to the pipeline for generation.

```py
import torch
from io import BytesIO

import PIL
import requests

from diffusers import DDIMInverseScheduler, DDIMScheduler, StableDiffusionDiffEditPipeline


def download_image(url):
    response = requests.get(url)
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")


img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"
init_image = download_image(img_url).resize((768, 768))

pipeline = StableDiffusionDiffEditPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
pipeline.enable_model_cpu_offload()

mask_prompt = "A bowl of fruits"  # describes the source image content
prompt = "A bowl of pears"  # describes the desired edit

mask_image = pipeline.generate_mask(image=init_image, source_prompt=mask_prompt, target_prompt=prompt)
image_latents = pipeline.invert(image=init_image, prompt=mask_prompt).latents
image = pipeline(prompt=prompt, mask_image=mask_image, image_latents=image_latents).images[0]
```

encode_prompt


( prompt device num_images_per_prompt do_classifier_free_guidance negative_prompt = None prompt_embeds: Optional = None negative_prompt_embeds: Optional = None lora_scale: Optional = None clip_skip: Optional = None )


Encodes the prompt into text encoder hidden states.
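A minimal sketch of reusing precomputed embeddings across calls; it assumes encode_prompt() returns a (prompt_embeds, negative_prompt_embeds) tuple, as in recent diffusers releases, and that pipeline, mask_image, and image_latents exist as in the examples above:

```py
import torch

prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
    prompt="A bowl of pears",
    device=torch.device("cuda"),
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="low quality",
)

# Reuse the embeddings instead of re-encoding the prompt on every call.
image = pipeline(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    mask_image=mask_image,
    image_latents=image_latents,
).images[0]
```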

StableDiffusionPipelineOutput

class diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput


( images: Union nsfw_content_detected: Optional )

Parameters

- images (List[PIL.Image.Image] or np.ndarray) — List of denoised PIL images of length batch_size, or a NumPy array of shape (batch_size, height, width, num_channels).
- nsfw_content_detected (List[bool]) — List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content, or None if safety checking could not be performed.

Output class for Stable Diffusion pipelines.
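For example, continuing from the pipeline call in the examples above:

```py
output = pipeline(prompt=prompt, mask_image=mask_image, image_latents=image_latents)
image = output.images[0]  # a PIL image (or array entry, depending on output_type)
if output.nsfw_content_detected is not None:
    print("Flagged as NSFW:", output.nsfw_content_detected[0])
```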
