Make group offloading compatible with torch.compile() by sayakpaul · Pull Request #11605 · huggingface/diffusers

What does this PR do?

On an H100 with Wan 14B, we get:

Compile: False offloading: False Latency: 4601.803 ms (median over 1 runs)

Compile: True offloading: False Latency: 3766.335 ms (median over 1 runs)

Compile: False offloading: True Latency: 4918.042 ms (median over 1 runs)

Compile: True offloading: True Latency: 4109.500 ms (median over 1 runs)

On an RTX 4090, we get:

Compile: False offloading: True Latency: 13658.121 ms (median over 1 runs)

Compile: True offloading: True Latency: 11583.754 ms (median over 1 runs)
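In relative terms, torch.compile() cuts latency by roughly 18% without offloading and roughly 16% with group offloading on the H100, and by roughly 15% with group offloading on the RTX 4090.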

Code:


from diffusers import AutoModel
import torch
torch.set_grad_enabled(False)
from torch.utils import benchmark
import argparse

torch._dynamo.config.cache_size_limit = 10000


def get_input_dict(**device_dtype_kwargs):
    # height: 480
    # width: 832
    # num_frames: 81
    # max_sequence_length: 512
    hidden_states = torch.randn(1, 16, 21, 60, 104, **device_dtype_kwargs)
    encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)
    timestep = torch.tensor([1.0], **device_dtype_kwargs)

    return {"hidden_states": hidden_states, "encoder_hidden_states": encoder_hidden_states, "timestep": timestep}


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--go", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_parser()
    transformer = AutoModel.from_pretrained(
        "Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    if not args.go:
        transformer.cuda()
    else:
        group_offload_kwargs = {
            "onload_device": "cuda",
            "offload_device": "cpu",
            "offload_type": "block_level",
            "num_blocks_per_group": 1,
            "use_stream": True,
            "non_blocking": True,
        }
        transformer.enable_group_offload(**group_offload_kwargs)
    if args.compile:
        transformer.compile()

input_kwargs = {"dtype": torch.bfloat16, "device": "cuda"} if not args.go else {"dtype": torch.bfloat16}
input_dict = get_input_dict(**input_kwargs)

for _ in range(4):
    _ = transformer(**input_dict)

latency_timer = benchmark.Timer(
    stmt="transformer(**input_dict)",
    setup="from __main__ import transformer, input_dict",
    num_threads=1,
    label="Go+Compilation inference latency",
)

latency_result = latency_timer.blocked_autorange(min_run_time=1)
latency_ms = latency_result.median * 1e3
print(f"Compile: {args.compile} offloading: {args.go}")
print(f"Latency: {latency_ms:.3f} ms (median over {len(latency_result.times)} runs)")

As one would expect, using CUDA streams to overlap compute with data transfer yields the best trade-off. Setting record_stream=True gives additional speedups.
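For completeness, a minimal sketch of that configuration, assuming enable_group_offload accepts record_stream alongside use_stream in the installed diffusers version:

import torch
from diffusers import AutoModel

transformer = AutoModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Same stream-based group offloading setup as in the benchmark script,
# with record_stream=True added on top (assumed available in this version).
transformer.enable_group_offload(
    onload_device="cuda",
    offload_device="cpu",
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
    non_blocking=True,
    record_stream=True,
)
transformer.compile()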