Allow specifying denoising_start and denoising_end as integers representing the discrete timesteps, fixing the XL ensemble not working for many schedulers by AmericanPresidentJimmyCarter · Pull Request #4115 · huggingface/diffusers

> As far as I can see, the following:
>
> ```python
> image = pipe(prompt=prompt, output_type="latent" if use_refiner else "pil",
>     denoising_end=200, num_inference_steps=100, generator=random_generator).images[0]
> image = refiner(prompt=prompt, image=image[None, :], denoising_start=200,
>     num_inference_steps=100).images[0]
> ```
>
> is exactly the same as currently doing:
>
> ```python
> num_inference_steps = 100
> high_noise_frac = 0.8
>
> image = pipe_high_noise(prompt=prompt, num_inference_steps=num_inference_steps,
>     denoising_end=high_noise_frac, output_type="latent").images
> image = pipe_low_noise(prompt=prompt, num_inference_steps=num_inference_steps,
>     denoising_start=high_noise_frac, image=image).images[0]
> ```

No, it is absolutely not the same thing. You can observe this by printing the timesteps the pipeline selects, immediately before the denoising loop:

```python
print("timesteps", timesteps)  # added line
with self.progress_bar(total=num_inference_steps) as progress_bar:
```

Do this in both `StableDiffusionXLPipeline` and `StableDiffusionXLImg2ImgPipeline`.
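
To make the comparison concrete: the two behaviors boil down to two different ways of slicing the timestep schedule. A simplified sketch of the difference (hypothetical helper functions, not the actual pipeline code):

```python
import torch

def split_at_timestep(timesteps: torch.Tensor, cutoff: int):
    # Discrete cutoff (this PR): the hand-off point is a training timestep,
    # independent of how the inference steps happen to be spaced.
    return timesteps[timesteps >= cutoff], timesteps[timesteps < cutoff]

def split_at_fraction(timesteps: torch.Tensor, frac: float):
    # Fractional cutoff (current behavior): the hand-off point is a fraction
    # of the inference step *indices*, which ignores the schedule's spacing.
    split = int(len(timesteps) * frac)
    return timesteps[:split], timesteps[split:]
```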

Here is the code for you to verify:

```python
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
from diffusers.schedulers import DPMSolverMultistepScheduler
import torch

scheduler = DPMSolverMultistepScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    prediction_type="epsilon",
    num_train_timesteps=1000,
    trained_betas=None,
    thresholding=False,
    algorithm_type="dpmsolver++",
    solver_type="midpoint",
    lower_order_final=True,
    use_karras_sigmas=True,
)

hf_cache_dir = '/path/to/cache'

pipe = StableDiffusionXLPipeline.from_pretrained(
    f"{hf_cache_dir}/stable-diffusion-xl-base-0.9",
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload(0)
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    f"{hf_cache_dir}/stable-diffusion-xl-refiner-0.9",
    use_safetensors=True,
    variant="fp16",
    torch_dtype=torch.bfloat16,
)
refiner.enable_model_cpu_offload(0)
refiner.enable_vae_slicing()
refiner.enable_vae_tiling()

# Use the same Karras-sigma DPM-Solver++ schedule in both pipelines.
pipe.register_modules(scheduler=scheduler)
refiner.register_modules(scheduler=scheduler)

with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
    use_refiner = True

    prompt = "photo of an astronaut standing in a jungle, muted colors, detailed, 8k, realistic, helmet reflections"
    n_prompt = "blurry, ugly, low quality, jpeg compression, drawing, sketch, painting, cartoon, face"

    # Baseline: base model alone for all 100 steps.
    random_generator = torch.Generator()
    random_generator.manual_seed(42)
    image = pipe(prompt=prompt, negative_prompt=n_prompt, output_type="pil",
        num_inference_steps=100, generator=random_generator, guidance_scale=6.5).images[0]
    image.save('spaceman100_1.png')

    # Ensemble, handing off at discrete timestep 200 (this PR).
    random_generator = torch.Generator()
    random_generator.manual_seed(42)
    image = pipe(prompt=prompt, negative_prompt=n_prompt, output_type="latent" if use_refiner else "pil",
        denoising_end=200, num_inference_steps=100, generator=random_generator, guidance_scale=6.5).images[0]
    image = refiner(prompt=prompt, image=image[None, :], denoising_start=200, num_inference_steps=100,
        guidance_scale=6.5).images[0]
    image.save('spaceman100_2.png')

    # Baseline again, for comparison with the fraction-based run.
    random_generator = torch.Generator()
    random_generator.manual_seed(42)
    image = pipe(prompt=prompt, negative_prompt=n_prompt, output_type="pil",
        num_inference_steps=100, generator=random_generator, guidance_scale=6.5).images[0]
    image.save('spaceman100_1_other.png')

    # Ensemble, handing off at a fraction of the step indices (current behavior).
    random_generator = torch.Generator()
    random_generator.manual_seed(42)
    high_noise_frac = 0.8
    image = pipe(prompt=prompt, negative_prompt=n_prompt, output_type="latent" if use_refiner else "pil",
        denoising_end=high_noise_frac, num_inference_steps=100, generator=random_generator, guidance_scale=6.5).images[0]
    image = refiner(prompt=prompt, image=image[None, :], denoising_start=high_noise_frac, num_inference_steps=100,
        guidance_scale=6.5).images[0]
    image.save('spaceman100_2_other.png')
```

Here is the output, showing drastically different timesteps for the fraction-based versus the discrete-timestep hand-off. This is the first part, where I pass discrete timesteps (denoising_end=200 and denoising_start=200):

```
Loading pipeline components...: 100%|█████████████| 7/7 [00:03<00:00,  2.04it/s]
Loading pipeline components...: 100%|█████████████| 5/5 [00:01<00:00,  2.58it/s]
  0%|                                                   | 0/100 [00:00<?, ?it/s]timesteps tensor([999, 992, 985, 978, 971, 963, 956, 948, 940, 933, 924, 916, 908, 899,
        891, 882, 873, 864, 854, 845, 835, 825, 815, 805, 794, 783, 772, 761,
        749, 737, 725, 713, 700, 687, 674, 660, 646, 632, 618, 603, 587, 572,
        556, 540, 523, 506, 489, 472, 454, 436, 418, 400, 382, 364, 345, 327,
        309, 291, 273, 256, 239, 222, 205, 190, 174, 160, 146, 133, 120, 108,
         97,  87,  78,  69,  61,  53,  47,  41,  35,  31,  26,  23,  19,  16,
         14,  12,  10,   8,   7,   6,   4,   3,   2,   1,   0],
       device='cuda:0')
 95%|███████████████████████████████████████▉  | 95/100 [00:45<00:02,  2.09it/s]
  0%|                                                    | 0/63 [00:00<?, ?it/s]timesteps tensor([999, 992, 985, 978, 971, 963, 956, 948, 940, 933, 924, 916, 908, 899,
        891, 882, 873, 864, 854, 845, 835, 825, 815, 805, 794, 783, 772, 761,
        749, 737, 725, 713, 700, 687, 674, 660, 646, 632, 618, 603, 587, 572,
        556, 540, 523, 506, 489, 472, 454, 436, 418, 400, 382, 364, 345, 327,
        309, 291, 273, 256, 239, 222, 205], device='cuda:0')
100%|███████████████████████████████████████████| 63/63 [00:36<00:00,  1.70it/s]
  0%|                                                    | 0/32 [00:00<?, ?it/s]timesteps tensor([190, 174, 160, 146, 133, 120, 108,  97,  87,  78,  69,  61,  53,  47,
         41,  35,  31,  26,  23,  19,  16,  14,  12,  10,   8,   7,   6,   4,
          3,   2,   1,   0])
100%|███████████████████████████████████████████| 32/32 [00:22<00:00,  1.43it/s]
```

Here is the second part, where I use high_noise_frac = 0.8:

```
  0%|                                                   | 0/100 [00:00<?, ?it/s]timesteps tensor([999, 992, 985, 978, 971, 963, 956, 948, 940, 933, 924, 916, 908, 899,
        891, 882, 873, 864, 854, 845, 835, 825, 815, 805, 794, 783, 772, 761,
        749, 737, 725, 713, 700, 687, 674, 660, 646, 632, 618, 603, 587, 572,
        556, 540, 523, 506, 489, 472, 454, 436, 418, 400, 382, 364, 345, 327,
        309, 291, 273, 256, 239, 222, 205, 190, 174, 160, 146, 133, 120, 108,
         97,  87,  78,  69,  61,  53,  47,  41,  35,  31,  26,  23,  19,  16,
         14,  12,  10,   8,   7,   6,   4,   3,   2,   1,   0],
       device='cuda:0')
 95%|███████████████████████████████████████▉  | 95/100 [00:38<00:02,  2.48it/s]
  0%|                                                    | 0/80 [00:00<?, ?it/s]timesteps tensor([999, 992, 985, 978, 971, 963, 956, 948, 940, 933, 924, 916, 908, 899,
        891, 882, 873, 864, 854, 845, 835, 825, 815, 805, 794, 783, 772, 761,
        749, 737, 725, 713, 700, 687, 674, 660, 646, 632, 618, 603, 587, 572,
        556, 540, 523, 506, 489, 472, 454, 436, 418, 400, 382, 364, 345, 327,
        309, 291, 273, 256, 239, 222, 205, 190, 174, 160, 146, 133, 120, 108,
         97,  87,  78,  69,  61,  53,  47,  41,  35,  31], device='cuda:0')
100%|███████████████████████████████████████████| 80/80 [00:28<00:00,  2.77it/s]
  0%|                                                    | 0/20 [00:00<?, ?it/s]timesteps tensor([26, 23, 19, 16, 14, 12, 10,  8,  7,  6,  4,  3,  2,  1,  0],
       device='cuda:0')
 75%|████████████████████████████████▎          | 15/20 [00:06<00:02,  2.29it/s]
```

high_noise_frac selects totally different timesteps for the refiner UNet! With high_noise_frac, the refiner runs on [26, 23, 19, 16, 14, 12, 10, 8, 7, 6, 4, 3, 2, 1, 0], whereas with the discrete timestep cutoff of 200 it runs on [190, 174, 160, 146, 133, 120, 108, 97, 87, 78, 69, 61, 53, 47, 41, 35, 31, 26, 23, 19, 16, 14, 12, 10, 8, 7, 6, 4, 3, 2, 1, 0]. With the fraction, the refiner does so little work that you might as well not switch to it at all!
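
You can see where the fraction lands on this Karras schedule by slicing it directly. A minimal sketch, using the same scheduler config as the repro script above (the printed values are from my run and may vary slightly across diffusers versions):

```python
import torch
from diffusers.schedulers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    algorithm_type="dpmsolver++",
    solver_type="midpoint",
    use_karras_sigmas=True,
)
scheduler.set_timesteps(100)
timesteps = scheduler.timesteps  # ~95 unique timesteps on my run

# Fraction-based hand-off: cut at 80% of the requested step indices.
split = int(0.8 * 100)
print(timesteps[split:])  # [26, 23, ..., 0] -- only the last sliver of the schedule

# Discrete hand-off (this PR): cut at training timestep 200.
print(timesteps[timesteps < 200])  # [190, 174, ..., 0] -- 32 steps
```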

What you are recommending instead (a fraction) reflects a misunderstanding of both how Stable Diffusion XL's refiner was trained and how an ensemble of expert denoisers works.
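
The fraction only appears to work when the inference timesteps are spaced roughly uniformly, because then the step-index fraction happens to land near the corresponding training timestep. A quick sketch of that coincidence (a hypothetical uniform schedule, not diffusers code):

```python
import torch

# On a uniformly spaced 100-step schedule, index fraction 0.8 lands near
# training timestep 200, so the two notions of "where to hand off" agree
# and the problem stays hidden.
timesteps = torch.linspace(999, 0, steps=100).round().long()
split = int(len(timesteps) * 0.8)
print(timesteps[split])  # tensor(192), close to the timestep-200 cutoff
```

With a Karras-sigma schedule, most of the steps are packed into the low-noise end, so the same 0.8 index fraction lands at timestep ~31 instead, which is why the ensemble breaks for those schedulers.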