[Core] add: controlnet support for SDXL by sayakpaul · Pull Request #4038 · huggingface/diffusers (original) (raw)

This PR adds support for ControlNets with SDXL. The two primary components being added to this PR:

However, these seems to be something weird going on here.

I first started training on a small subset of dataset (the circles dataset) with the following command:

export MODEL_DIR="stabilityai/stable-diffusion-xl-base-0.9" export OUTPUT_DIR="controlnet-sdxl-circles"

wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png

wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png

accelerate launch train_controlnet_sdxl.py
--pretrained_model_name_or_path=$MODEL_DIR
--output_dir=$OUTPUT_DIR
--dataset_name=fusing/fill50k
--mixed_precision="fp16"
--resolution=1024
--learning_rate=5e-5
--max_train_samples=500
--max_train_steps=1000
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png"
--validation_prompt "red circle with blue background" "cyan circle with brown floral background"
--validation_steps=25
--train_batch_size=1
--gradient_accumulation_steps=4
--report_to="wandb"
--seed=42
--push_to_hub

The trained checkpoints seem to only generate black images: https://huggingface.co/fusing/controlnet-sdxl-circles-fixed (only visible to the diffusers team members).

To further debug this, I tried:

from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel from diffusers.utils import load_image import torch

base_ckpt_id = "stabilityai/stable-diffusion-xl-base-0.9" controlnet_ckpt_id = "controlnet-sdxl-circles-fixed"

controlnet = ControlNetModel.from_pretrained( controlnet_ckpt_id, subfolder="checkpoint-500/controlnet", torch_dtype=torch.float16 ).to("cuda")

pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( base_ckpt_id, controlnet=controlnet, torch_dtype=torch.float16 ).to("cuda")

cond_image = load_image( "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png" ) prompt = "red circle with blue background"

image = pipeline(prompt, image=cond_image).images[0] image.save("controlnet@ckpt-500.png")

This doesn't generate the expected results (which is expected since the number of training steps is quite low) but doesn't generate all black images either.

@patrickvonplaten @williamberman could you take a deeper look here?

TODOs