refactor rotary embedding 3: so it is not on cpu by yiyixuxu · Pull Request #9307 · huggingface/diffusers
@sayakpaul
Is this a reasonable script? I want to compare the performance against 0.30.1-patch
before we introduce the rotary embedding refactor.
```python
import gc
import time

import torch
import torch.utils.benchmark as benchmark

torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

import diffusers
from platform import python_version
from diffusers import DiffusionPipeline

print(diffusers.__version__)
print(torch.__version__)
print(python_version())


def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"


def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"


def flush():
    """Wipes off memory."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


# Load FLUX.1-dev without the text encoders; precomputed prompt embeddings
# are loaded from disk below.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,
    text_encoder_2=None,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

prompt_embeds = torch.load("flux_prompt_embeds.pt")
pooled_prompt_embeds = torch.load("flux_pooled_prompt_embeds.pt")


def run_inference(pipe):
    _ = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=5,
        guidance_scale=3.5,
        max_sequence_length=512,
        generator=torch.manual_seed(42),
        height=1024,
        width=1024,
    )


flush()

time = benchmark_fn(run_inference, pipe)
memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GiB.
print(f"Execution time: {time} sec")
print(f"Memory: {memory} GiB")
```
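(For reference: the two `.pt` files are dumped ahead of time, since the text encoders aren't loaded in the benchmark. A minimal sketch of how that could look, assuming `FluxPipeline.encode_prompt` and a placeholder prompt; the actual dump script isn't shown in this comment:)

```python
# Hypothetical one-off script to produce flux_prompt_embeds.pt /
# flux_pooled_prompt_embeds.pt. The prompt string is a placeholder.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "A cat holding a sign that says hello world"  # placeholder
with torch.no_grad():
    prompt_embeds, pooled_prompt_embeds, _ = pipe.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=512
    )

torch.save(prompt_embeds, "flux_prompt_embeds.pt")
torch.save(pooled_prompt_embeds, "flux_pooled_prompt_embeds.pt")
```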