Fix multistep dpmsolver for cosine schedule (suitable for deepfloyd-if) by LuChengTHU · Pull Request #3314 · huggingface/diffusers (original) (raw)

Here is the script for testing and comparing different settings:

from diffusers import DiffusionPipeline from diffusers.utils import pt_to_pil import torch from diffusers import DPMSolverMultistepScheduler, DDPMScheduler

from diffusers.pipelines.deepfloyd_if import fast27_timesteps, smart100_timesteps

stage 1

stage_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)

stage_1.enable_xformers_memory_efficient_attention() # remove line if torch.version >= 2.0.0

stage_1.enable_model_cpu_offload()

stage 2

stage_2 = DiffusionPipeline.from_pretrained( "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 )

stage_2.enable_xformers_memory_efficient_attention() # remove line if torch.version >= 2.0.0

stage_2.enable_model_cpu_offload()

stage 3

safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker, "watermarker": stage_1.watermarker} stage_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16)

stage_3.enable_xformers_memory_efficient_attention() # remove line if torch.version >= 2.0.0

stage_3.enable_model_cpu_offload()

def set_scheduler(stage, scheduler_name): if scheduler_name == 'ddpm': scheduler = DDPMScheduler.from_config(stage.scheduler.config) elif scheduler_name == 'dpm++': print(stage.scheduler.config) scheduler = DPMSolverMultistepScheduler.from_config(stage.scheduler.config) scheduler.is_predicting_variance = True stage.scheduler = scheduler return stage

seed = 0

prompt_list = [ 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"', "casual photo of a leaf maple syrup glass container sitting on a wooden table in a log cabin, high depth of field during golden hour as the sunlight shines through the windows, dusty air", 'capybara holding a neon sign with text that reads "capybara podcast", a professional photo of a capybara podcasting, capybara chimera animorph, transformer animal, anamorphic, 8k, 4k, 85 mm, f2.2, photography awards and hyperrealistic, highly detailed, f1.4 lens, 50mm photo, soft light, masterpiece, sharp focus, pretty, hasselblad', "high quality dslr photo, a photo product of a lemon inspired by natural and organic materials, wooden accents, intricately decorated with glowing vines of led lights, inspired by baroque luxury", "paper quilling, extremely detailed, paper quilling of a nordic mountain landscape, 8k rendering", 'letters made of candy on a plate that says "diet"', 'a photo of a violet baseball cap with yellow text: "deep floyd". 50mm lens, photo realism, cine lens. violet baseball cap says "deep floyd". reflections, render. yellow stitch text "deep floyd"', ]

steps_list = [5, 10, 15, 20, 25, 50, 100]

def generate(stage_1, prompt, steps_1): # text embeds prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt) generator = torch.manual_seed(seed) # stage 1 image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt", num_inference_steps=steps_1).images # pt_to_pil(image)[0].save("./if_stage_I.png") return image, prompt_embeds, negative_embeds, generator

def upscale(stage_2, stage_3, image, prompt, prompt_embeds, negative_embeds, generator, steps_2, steps_3): # stage 2 image = stage_2( image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt", num_inference_steps=steps_2 ).images # pt_to_pil(image)[0].save("./if_stage_II.png")

# stage 3
image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100, num_inference_steps=steps_3).images
# image[0].save("./if_stage_III.png")
return image

def compare_upscale(stage_1, stage_2, stage_3): import os base_dir = 'results_upscale' if not os.path.exists(base_dir): os.mkdir(base_dir)

ddpm_base_steps = 100

for idx, prompt in enumerate(prompt_list):
    image, prompt_embeds, negative_embeds, generator = generate(stage_1, prompt, ddpm_base_steps)
    for steps in steps_list:
        generator_copy = torch.Generator().set_state(generator.get_state())
        image_upscale = upscale(stage_2, stage_3, image, prompt, prompt_embeds, negative_embeds, generator_copy, steps, steps)
        image_upscale[0].save(os.path.join(base_dir, f'{idx}_ddpm_{steps}.png'))
stage_2 = set_scheduler(stage_2, 'dpm++')
stage_3 = set_scheduler(stage_3, 'dpm++')

for idx, prompt in enumerate(prompt_list):
    image, prompt_embeds, negative_embeds, generator = generate(stage_1, prompt, ddpm_base_steps)
    for steps in steps_list:
        generator_copy = torch.Generator().set_state(generator.get_state())
        image_upscale = upscale(stage_2, stage_3, image, prompt, prompt_embeds, negative_embeds, generator_copy, steps, steps)
        image_upscale[0].save(os.path.join(base_dir, f'{idx}_dpm++_{steps}.png'))

compare_upscale(stage_1, stage_2, stage_3)

def compare_base(stage_1, stage_2, stage_3): import os base_dir = 'results_base' if not os.path.exists(base_dir): os.mkdir(base_dir)

stage_2 = set_scheduler(stage_2, 'dpm++')
stage_3 = set_scheduler(stage_3, 'dpm++')
upscale_steps = 25

for idx, prompt in enumerate(prompt_list):
    for steps in steps_list:
        image, prompt_embeds, negative_embeds, generator = generate(stage_1, prompt, steps)
        generator_copy = torch.Generator().set_state(generator.get_state())
        image_upscale = upscale(stage_2, stage_3, image, prompt, prompt_embeds, negative_embeds, generator_copy, upscale_steps, upscale_steps)
        image_upscale[0].save(os.path.join(base_dir, f'{idx}_ddpm_{steps}.png'))

stage_1 = set_scheduler(stage_1, 'dpm++')
for idx, prompt in enumerate(prompt_list):
    for steps in steps_list:
        image, prompt_embeds, negative_embeds, generator = generate(stage_1, prompt, steps)
        generator_copy = torch.Generator().set_state(generator.get_state())
        image_upscale = upscale(stage_2, stage_3, image, prompt, prompt_embeds, negative_embeds, generator_copy, upscale_steps, upscale_steps)
        image_upscale[0].save(os.path.join(base_dir, f'{idx}_dpm++_{steps}.png'))

compare_base(stage_1, stage_2, stage_3)