[do not merge] testing rotary embedding + torch.compile by yiyixuxu · Pull Request #9321 · huggingface/diffusers
This is the script I use to generate the trace:
```python
import os
os.environ['TORCH_LOGS'] = 'graph_breaks, dynamo, recompiles'
os.environ['TORCHDYNAMO_VERBOSE'] = '1'

import torch

import logging
logging.basicConfig(level=logging.INFO)

from torch.profiler import profile, record_function, ProfilerActivity, tensorboard_trace_handler
import gc

torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

import diffusers
from platform import python_version
from diffusers import DiffusionPipeline

print(diffusers.__version__)
print(torch.__version__)
print(python_version())


def profiler_runner(fn, *args, **kwargs):
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        record_shapes=True,
        on_trace_ready=tensorboard_trace_handler("./yiyi_trace"),
    ) as prof:
        result = fn(*args, **kwargs)
    return result


pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,
    text_encoder_2=None,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

prompt_embeds = torch.load("flux_prompt_embeds.pt")
pooled_prompt_embeds = torch.load("flux_pooled_prompt_embeds.pt")


def run_inference(pipe):
    for i in range(5):
        with record_function(f"pipeline_run_number_{i}"):
            _ = pipe(
                prompt_embeds=prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                num_inference_steps=3,
                guidance_scale=3.5,
                max_sequence_length=512,
                generator=torch.manual_seed(42),
                height=1024,
                width=1024,
            )


_ = profiler_runner(run_inference, pipe)
```
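The trace written by `tensorboard_trace_handler` can be opened with the TensorBoard PyTorch profiler plugin. If a quick text summary is enough, a variant of `profiler_runner` (just a sketch, not part of the script above) can also print the top ops straight from the profiler object:

```python
def profiler_runner_with_summary(fn, *args, **kwargs):
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        record_shapes=True,
        on_trace_ready=tensorboard_trace_handler("./yiyi_trace"),
    ) as prof:
        result = fn(*args, **kwargs)
    # print the ops sorted by total CUDA time, in addition to the on-disk trace
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
    return result
```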
benchmark test
I'm using the testing script at the bottom of this section; some numbers below.
(I don't think putting the `arange` on device improved anything, but there is a difference from the 0.30.1-patch, so we should look into that.)
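For context, "doing `arange` on GPU" means creating the rotary-embedding frequency range directly on the target device instead of building it on CPU and copying it over inside the compiled region. A minimal sketch of the idea (illustrative only, not the actual diffusers implementation; the function names and signatures are made up for the example):

```python
import torch

def rope_freqs_cpu(dim: int, pos: torch.Tensor, theta: float = 10000.0):
    # arange is created on the default (CPU) device, so the result has to be
    # moved to pos.device afterwards -> an extra host-to-device copy per call
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    return torch.outer(pos.float(), freqs.to(pos.device))

def rope_freqs_on_device(dim: int, pos: torch.Tensor, theta: float = 10000.0):
    # arange is created directly on the same device as pos (a 1-D tensor of
    # positions on the GPU), so everything stays on-device
    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim)
    )
    return torch.outer(pos.float(), freqs)
```

Keeping the `arange` on the same device as the inputs avoids the host-to-device copy inside the compiled graph.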
| Configuration | Execution time (sec) | Memory (GiB) |
| --- | --- | --- |
| main before the latest commit (with graph break warning) | 2.287 | 22.805 |
| main with the latest commit (#9307) | 2.256 | 22.346 |
| this PR (do `arange` on GPU) | 2.269 | 22.348 |
| 0.30.1-patch (the original flux rotary embeds, before PR #9074) | 2.226 | 22.346 |
testing script
```python
import os
os.environ['TORCH_LOGS'] = 'graph_breaks'

import torch
import torch.utils.benchmark as benchmark
import gc
import time

torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

import diffusers
from platform import python_version
from diffusers import DiffusionPipeline

print(diffusers.__version__)
print(torch.__version__)
print(python_version())


def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"


def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"


def flush():
    """Wipes off memory."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,
    text_encoder_2=None,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

prompt_embeds = torch.load("flux_prompt_embeds.pt")
pooled_prompt_embeds = torch.load("flux_pooled_prompt_embeds.pt")


def run_inference(pipe):
    _ = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=5,
        guidance_scale=3.5,
        max_sequence_length=512,
        generator=torch.manual_seed(42),
        height=1024,
        width=1024,
    )


flush()

# warm-up runs so compilation is not included in the measurement
for _ in range(5):
    run_inference(pipe)

flush()

time = benchmark_fn(run_inference, pipe)
memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GiB
print(f" Execution time: {time} sec")
print(f" Memory: {memory} GiB")
```