Add the SDE variant of DPM-Solver and DPM-Solver++ by LuChengTHU · Pull Request #3344 · huggingface/diffusers (original) (raw)
Here is an example script for comparing "dpmsolver++", "sde-dpmsolver++" and "ddpm":
from diffusers import DiffusionPipeline from diffusers.utils import pt_to_pil import torch from diffusers import DPMSolverMultistepScheduler, DDPMScheduler, DPMSolverSinglestepScheduler, DPMSolverSDEScheduler
from diffusers.pipelines.deepfloyd_if import fast27_timesteps, smart100_timesteps
stage 1
stage_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
stage_1.enable_xformers_memory_efficient_attention() # remove line if torch.version >= 2.0.0
stage_1.enable_model_cpu_offload()
stage 2
stage_2 = DiffusionPipeline.from_pretrained( "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 )
stage_2.enable_xformers_memory_efficient_attention() # remove line if torch.version >= 2.0.0
stage_2.enable_model_cpu_offload()
stage 3
safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker, "watermarker": stage_1.watermarker} stage_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16)
stage_3.enable_xformers_memory_efficient_attention() # remove line if torch.version >= 2.0.0
stage_3.enable_model_cpu_offload()
def set_scheduler(stage, scheduler_name): if scheduler_name == 'ddpm': scheduler = DDPMScheduler.from_config(stage.scheduler.config) elif scheduler_name == 'dpm++': scheduler = DPMSolverMultistepScheduler.from_config(stage.scheduler.config) scheduler.config.algorithm_type = 'dpmsolver++' elif scheduler_name == 'sde-dpm++': scheduler = DPMSolverMultistepScheduler.from_config(stage.scheduler.config) scheduler.config.algorithm_type = 'sde-dpmsolver++' stage.scheduler = scheduler return stage
seed = 0
prompt_list = [ 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"', "casual photo of a leaf maple syrup glass container sitting on a wooden table in a log cabin, high depth of field during golden hour as the sunlight shines through the windows, dusty air", 'capybara holding a neon sign with text that reads "capybara podcast", a professional photo of a capybara podcasting, capybara chimera animorph, transformer animal, anamorphic, 8k, 4k, 85 mm, f2.2, photography awards and hyperrealistic, highly detailed, f1.4 lens, 50mm photo, soft light, masterpiece, sharp focus, pretty, hasselblad', "high quality dslr photo, a photo product of a lemon inspired by natural and organic materials, wooden accents, intricately decorated with glowing vines of led lights, inspired by baroque luxury", ]
def generate(stage_1, prompt, steps_1): # text embeds prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt) generator = torch.manual_seed(seed) # stage 1 image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt", num_inference_steps=steps_1).images # pt_to_pil(image)[0].save("./if_stage_I.png") return image, prompt_embeds, negative_embeds, generator
def upscale(stage_2, stage_3, image, prompt, prompt_embeds, negative_embeds, generator, steps_2, steps_3): # stage 2 image = stage_2( image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt", num_inference_steps=steps_2 ).images # pt_to_pil(image)[0].save("./if_stage_II.png")
# stage 3
image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100, num_inference_steps=steps_3).images
# image[0].save("./if_stage_III.png")
return image
upscale_steps = 25 stage_2 = set_scheduler(stage_2, 'dpm++') stage_3 = set_scheduler(stage_3, 'dpm++')
import os base_dir = "sample_results" if not os.path.exists(base_dir): os.mkdir(base_dir)
for base_steps in [10, 15, 20, 25, 50, 100]: for idx, prompt in enumerate(prompt_list): stage_1 = set_scheduler(stage_1, 'sde-dpm++') image, prompt_embeds, negative_embeds, generator = generate(stage_1, prompt, base_steps) image_upscale = upscale(stage_2, stage_3, image, prompt, prompt_embeds, negative_embeds, generator, upscale_steps, upscale_steps) image_upscale[0].save(os.path.join(base_dir, f'./{idx}sde-dpm++{base_steps}.png'))
stage_1 = set_scheduler(stage_1, 'ddpm')
image, prompt_embeds, negative_embeds, generator = generate(stage_1, prompt, base_steps)
image_upscale = upscale(stage_2, stage_3, image, prompt, prompt_embeds, negative_embeds, generator, upscale_steps, upscale_steps)
image_upscale[0].save(os.path.join(base_dir, f'./{idx}_ddpm_{base_steps}.png'))
stage_1 = set_scheduler(stage_1, 'dpm++')
image, prompt_embeds, negative_embeds, generator = generate(stage_1, prompt, base_steps)
image_upscale = upscale(stage_2, stage_3, image, prompt, prompt_embeds, negative_embeds, generator, upscale_steps, upscale_steps)
image_upscale[0].save(os.path.join(base_dir, f'./{idx}_dpm++_{base_steps}.png'))