[do not merge] testing rotary embedding + torch.compile by yiyixuxu · Pull Request #9321 · huggingface/diffusers

This is the script I use to generate the trace:

```python
import os
os.environ['TORCH_LOGS'] = 'graph_breaks, dynamo, recompiles'
os.environ['TORCHDYNAMO_VERBOSE'] = '1'

import torch

import logging
logging.basicConfig(level=logging.INFO)

from torch.profiler import profile, record_function, ProfilerActivity, tensorboard_trace_handler
import gc

torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

import diffusers
from platform import python_version
from diffusers import DiffusionPipeline

print(diffusers.__version__)
print(torch.__version__)
print(python_version())


def profiler_runner(fn, *args, **kwargs):
    # profile CPU + CUDA and write the trace for TensorBoard
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        record_shapes=True,
        on_trace_ready=tensorboard_trace_handler("./yiyi_trace"),
    ) as prof:
        result = fn(*args, **kwargs)
    return result


pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,
    text_encoder_2=None,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

# precomputed text-encoder outputs (the text encoders are not loaded above)
prompt_embeds = torch.load("flux_prompt_embeds.pt")
pooled_prompt_embeds = torch.load("flux_pooled_prompt_embeds.pt")


def run_inference(pipe):
    for i in range(5):
        with record_function(f"pipeline_run_number_{i}"):
            _ = pipe(
                prompt_embeds=prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                num_inference_steps=3,
                guidance_scale=3.5,
                max_sequence_length=512,
                generator=torch.manual_seed(42),
                height=1024,
                width=1024,
            )


_ = profiler_runner(run_inference, pipe)
```
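As a side note, besides the TensorBoard trace it can be handy to dump a quick text summary of the hottest ops from the same profiler object. A minimal sketch (the `profiler_runner_with_summary` name is hypothetical, not part of the script above):

```python
# Minimal sketch: same profiling setup as above, but also print a text summary
# of the top CUDA ops. Optional convenience, not used for the numbers below.
def profiler_runner_with_summary(fn, *args, **kwargs):
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        record_shapes=True,
        on_trace_ready=tensorboard_trace_handler("./yiyi_trace"),
    ) as prof:
        result = fn(*args, **kwargs)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
    return result
```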

benchmark test

I'm using the testing script below; some numbers follow.
(I don't think putting `arange` on device improved anything, but there is a difference from 0.30.1-patch, so that should be looked into.)
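For context, here is a rough sketch of what "arange on device" refers to in these comparisons. This is illustrative only (not the actual diffusers rotary-embedding code): the difference is between building the frequency indices on CPU and moving them to the GPU later, vs. creating them directly on the same device as the position ids.

```python
import torch

def rope_freqs_cpu_arange(dim, pos, theta=10000.0):
    # arange lives on CPU; under torch.compile the later device transfer can
    # show up as an extra copy or a graph break
    freqs = 1.0 / theta ** (torch.arange(0, dim, 2).float() / dim)
    return torch.outer(pos.float(), freqs.to(pos.device))

def rope_freqs_gpu_arange(dim, pos, theta=10000.0):
    # arange created directly on pos.device (the "arange on GPU" variant)
    freqs = 1.0 / theta ** (torch.arange(0, dim, 2, device=pos.device).float() / dim)
    return torch.outer(pos.float(), freqs)
```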

main before the latest commit (with graph break warning)

```
# Execution time: 2.287 sec
# Memory: 22.805 gib
```

main with latest commit (#9307)

```
# Execution time: 2.256 sec
# Memory: 22.346 gib
```

this PR (do `arange` on GPU)

```
# Execution time: 2.269 sec
# Memory: 22.348 gib
```

0.30.1-patch (using the original Flux rotary embeds, before PR #9074)

```
# Execution time: 2.226 sec
# Memory: 22.346 gib
```

testing script

```python
import os
os.environ['TORCH_LOGS'] = 'graph_breaks'

import torch
import torch.utils.benchmark as benchmark
import gc

import time

torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

import diffusers
from platform import python_version
from diffusers import DiffusionPipeline

print(diffusers.__version__)
print(torch.__version__)
print(python_version())


def benchmark_fn(f, *args, **kwargs):
    # mean wall-clock time (seconds) over blocked_autorange
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"


def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"


def flush():
    """Wipes off memory."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,
    text_encoder_2=None,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

# precomputed text-encoder outputs (the text encoders are not loaded above)
prompt_embeds = torch.load("flux_prompt_embeds.pt")
pooled_prompt_embeds = torch.load("flux_pooled_prompt_embeds.pt")


def run_inference(pipe):
    _ = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=5,
        guidance_scale=3.5,
        max_sequence_length=512,
        generator=torch.manual_seed(42),
        height=1024,
        width=1024,
    )


flush()

# warmup runs so torch.compile compilation/autotuning is excluded from the measurement
for _ in range(5):
    run_inference(pipe)

flush()

time = benchmark_fn(run_inference, pipe)
memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GBs.
print(f" Execution time: {time} sec")
print(f" Memory: {memory} gib")
```
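If more than the mean is wanted, the same `torch.utils.benchmark` Measurement object also exposes the median and interquartile range. A small hypothetical variant of `benchmark_fn` above (not used for the numbers reported here):

```python
# Sketch: report spread as well as the mean, using the same Timer setup.
def benchmark_fn_with_spread(f, *args, **kwargs):
    m = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    ).blocked_autorange()
    return f"mean {m.mean:.3f}s, median {m.median:.3f}s, iqr {m.iqr:.3f}s"
```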