Fix multistep dpmsolver for cosine schedule (suitable for deepfloyd-if) by LuChengTHU · Pull Request #3314 · huggingface/diffusers (original) (raw)
Here is the script for testing and comparing different settings:
import os

import torch
from diffusers import DiffusionPipeline
from diffusers import DPMSolverMultistepScheduler, DDPMScheduler
from diffusers.pipelines.deepfloyd_if import fast27_timesteps, smart100_timesteps
from diffusers.utils import pt_to_pil
# stage 1: base text-to-image pipeline, loaded in fp16.
stage_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
stage_1.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
stage_1.enable_model_cpu_offload()

# stage 2: loaded without a text encoder; prompt embeddings are passed in
# explicitly by the caller instead.
stage_2 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0",
    text_encoder=None,
    variant="fp16",
    torch_dtype=torch.float16,
)
stage_2.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
stage_2.enable_model_cpu_offload()

# stage 3: x4 upscaler that reuses the safety-related modules already loaded
# with stage 1.
safety_modules = {
    "feature_extractor": stage_1.feature_extractor,
    "safety_checker": stage_1.safety_checker,
    "watermarker": stage_1.watermarker,
}
stage_3 = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler",
    **safety_modules,
    torch_dtype=torch.float16,
)
stage_3.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
stage_3.enable_model_cpu_offload()
def set_scheduler(stage, scheduler_name):
    """Replace ``stage.scheduler`` with the scheduler selected by name.

    The replacement is built from the config of the scheduler currently
    attached to ``stage``.

    Args:
        stage: Pipeline object exposing a ``scheduler`` attribute whose value
            has a ``config``.
        scheduler_name: ``'ddpm'`` for ``DDPMScheduler`` or ``'dpm++'`` for
            ``DPMSolverMultistepScheduler``.

    Returns:
        The same ``stage`` object, with its ``scheduler`` attribute replaced.

    Raises:
        ValueError: If ``scheduler_name`` is not one of the supported names.
            ``stage`` is left untouched in that case.
    """
    if scheduler_name == 'ddpm':
        scheduler = DDPMScheduler.from_config(stage.scheduler.config)
    elif scheduler_name == 'dpm++':
        # Show the config the new scheduler is derived from.
        print(stage.scheduler.config)
        scheduler = DPMSolverMultistepScheduler.from_config(stage.scheduler.config)
        # NOTE(review): the flattened original does not show whether this flag
        # was set for every scheduler or only for DPM-Solver++; it is kept in
        # this branch -- confirm against the scheduler implementation.
        scheduler.is_predicting_variance = True
    else:
        # Previously an unknown name fell through and failed with an
        # UnboundLocalError on the assignment below.
        raise ValueError(
            f"unknown scheduler_name {scheduler_name!r}; expected 'ddpm' or 'dpm++'"
        )
    stage.scheduler = scheduler
    return stage
# Seed handed to torch.manual_seed() before every stage-1 run.
seed = 0

# Prompts rendered for every scheduler / step-count combination.
prompt_list = [
    'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"',
    "casual photo of a leaf maple syrup glass container sitting on a wooden table in a log cabin, high depth of field during golden hour as the sunlight shines through the windows, dusty air",
    'capybara holding a neon sign with text that reads "capybara podcast", a professional photo of a capybara podcasting, capybara chimera animorph, transformer animal, anamorphic, 8k, 4k, 85 mm, f2.2, photography awards and hyperrealistic, highly detailed, f1.4 lens, 50mm photo, soft light, masterpiece, sharp focus, pretty, hasselblad',
    "high quality dslr photo, a photo product of a lemon inspired by natural and organic materials, wooden accents, intricately decorated with glowing vines of led lights, inspired by baroque luxury",
    "paper quilling, extremely detailed, paper quilling of a nordic mountain landscape, 8k rendering",
    'letters made of candy on a plate that says "diet"',
    'a photo of a violet baseball cap with yellow text: "deep floyd". 50mm lens, photo realism, cine lens. violet baseball cap says "deep floyd". reflections, render. yellow stitch text "deep floyd"',
]

# Inference step counts swept by the comparison functions.
steps_list = [
    5,
    10,
    15,
    20,
    25,
    50,
    100,
]
def generate(stage_1, prompt, steps_1):
    """Run the stage-1 pipeline for a single prompt.

    The global RNG is re-seeded from the module-level ``seed`` before each
    run.

    Args:
        stage_1: Stage-1 pipeline; must provide ``encode_prompt`` and be
            callable.
        prompt: Text prompt to render.
        steps_1: Number of inference steps for stage 1.

    Returns:
        A tuple ``(image, prompt_embeds, negative_embeds, generator)`` where
        ``image`` is the ``.images`` attribute of the pipeline output
        (requested with ``output_type="pt"``) and ``generator`` is the object
        returned by ``torch.manual_seed``, in the state it has after the
        stage-1 run.
    """
    # text embeds
    prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)

    generator = torch.manual_seed(seed)

    # stage 1
    image = stage_1(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_embeds,
        generator=generator,
        output_type="pt",
        num_inference_steps=steps_1,
    ).images

    return image, prompt_embeds, negative_embeds, generator
def upscale(stage_2, stage_3, image, prompt, prompt_embeds, negative_embeds, generator, steps_2, steps_3,
            *, noise_level=100):
    """Run stage 2 and then stage 3 on a stage-1 image.

    Args:
        stage_2: Stage-2 pipeline, called with the precomputed embeddings.
        stage_3: Stage-3 pipeline, called with the raw text prompt.
        image: Output of stage 1, fed to stage 2.
        prompt: Text prompt, passed to stage 3 only.
        prompt_embeds: Prompt embeddings, passed to stage 2.
        negative_embeds: Negative prompt embeddings, passed to stage 2.
        generator: Random generator shared by both stages.
        steps_2: Number of inference steps for stage 2.
        steps_3: Number of inference steps for stage 3.
        noise_level: ``noise_level`` argument forwarded to stage 3. Defaults
            to 100, the value that was previously hard-coded.

    Returns:
        The ``.images`` attribute of the stage-3 output.
    """
    # stage 2
    image = stage_2(
        image=image,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_embeds,
        generator=generator,
        output_type="pt",
        num_inference_steps=steps_2,
    ).images

    # stage 3
    image = stage_3(
        prompt=prompt,
        image=image,
        generator=generator,
        noise_level=noise_level,
        num_inference_steps=steps_3,
    ).images

    return image
def _sweep_upscale_steps(stage_1, stage_2, stage_3, base_dir, tag, base_steps):
    """Save one upscaled image per (prompt, step count) as ``{idx}_{tag}_{steps}.png``."""
    for idx, prompt in enumerate(prompt_list):
        image, prompt_embeds, negative_embeds, generator = generate(stage_1, prompt, base_steps)
        for steps in steps_list:
            # Fresh generator restored from the post-stage-1 state, so each
            # step count is upscaled from the same starting RNG state.
            generator_copy = torch.Generator().set_state(generator.get_state())
            image_upscale = upscale(
                stage_2, stage_3, image, prompt, prompt_embeds, negative_embeds, generator_copy, steps, steps
            )
            image_upscale[0].save(os.path.join(base_dir, f'{idx}_{tag}_{steps}.png'))


def compare_upscale(stage_1, stage_2, stage_3):
    """Compare the stage-2/3 schedulers over a sweep of step counts.

    Stage 1 always runs for 100 steps. Stages 2 and 3 are first run with
    whatever schedulers they currently carry (files tagged ``ddpm``), then
    switched to ``'dpm++'`` via ``set_scheduler`` and run again (files tagged
    ``dpm++``). Images are written to ``results_upscale/``.

    Side effects: creates the output directory if needed, writes PNG files,
    and leaves ``stage_2.scheduler`` / ``stage_3.scheduler`` replaced, because
    ``set_scheduler`` modifies the pipeline objects in place.

    Args:
        stage_1: Stage-1 pipeline.
        stage_2: Stage-2 pipeline.
        stage_3: Stage-3 pipeline.
    """
    base_dir = 'results_upscale'
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(base_dir, exist_ok=True)

    ddpm_base_steps = 100

    # NOTE(review): the 'ddpm' tag assumes stages 2/3 still carry their
    # default schedulers when this function is called -- confirm.
    _sweep_upscale_steps(stage_1, stage_2, stage_3, base_dir, 'ddpm', ddpm_base_steps)

    stage_2 = set_scheduler(stage_2, 'dpm++')
    stage_3 = set_scheduler(stage_3, 'dpm++')

    _sweep_upscale_steps(stage_1, stage_2, stage_3, base_dir, 'dpm++', ddpm_base_steps)
# Run the stage-2/3 scheduler comparison. This call replaces the schedulers on
# stage_2 and stage_3 with 'dpm++' in place, so they stay switched afterwards.
compare_upscale(stage_1, stage_2, stage_3)
def _sweep_base_steps(stage_1, stage_2, stage_3, base_dir, tag, upscale_steps):
    """Save one image per (prompt, stage-1 step count) as ``{idx}_{tag}_{steps}.png``."""
    for idx, prompt in enumerate(prompt_list):
        for steps in steps_list:
            image, prompt_embeds, negative_embeds, generator = generate(stage_1, prompt, steps)
            # Upscale with a separate generator restored from the
            # post-stage-1 state rather than the generator returned above.
            generator_copy = torch.Generator().set_state(generator.get_state())
            image_upscale = upscale(
                stage_2, stage_3, image, prompt, prompt_embeds, negative_embeds,
                generator_copy, upscale_steps, upscale_steps
            )
            image_upscale[0].save(os.path.join(base_dir, f'{idx}_{tag}_{steps}.png'))


def compare_base(stage_1, stage_2, stage_3):
    """Compare the stage-1 scheduler over a sweep of step counts.

    Stages 2 and 3 are switched to ``'dpm++'`` and always run for 25 steps.
    Stage 1 is first swept with the scheduler it currently carries (files
    tagged ``ddpm``), then switched to ``'dpm++'`` and swept again (files
    tagged ``dpm++``). Images are written to ``results_base/``.

    Side effects: creates the output directory if needed, writes PNG files,
    and leaves the ``scheduler`` attribute of all three pipelines replaced,
    because ``set_scheduler`` modifies the pipeline objects in place.

    Args:
        stage_1: Stage-1 pipeline.
        stage_2: Stage-2 pipeline.
        stage_3: Stage-3 pipeline.
    """
    base_dir = 'results_base'
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(base_dir, exist_ok=True)

    stage_2 = set_scheduler(stage_2, 'dpm++')
    stage_3 = set_scheduler(stage_3, 'dpm++')

    upscale_steps = 25

    # NOTE(review): the 'ddpm' tag assumes stage 1 still carries its default
    # scheduler when this function is called -- confirm.
    _sweep_base_steps(stage_1, stage_2, stage_3, base_dir, 'ddpm', upscale_steps)

    stage_1 = set_scheduler(stage_1, 'dpm++')

    _sweep_base_steps(stage_1, stage_2, stage_3, base_dir, 'dpm++', upscale_steps)
# Run the stage-1 scheduler comparison. This call replaces the schedulers on
# all three pipelines with 'dpm++' in place.
compare_base(stage_1, stage_2, stage_3)