Make group offloading compatible with torch.compile() by sayakpaul · Pull Request #11605 · huggingface/diffusers
What does this PR do?
On H100 with Wan 14B, we get:
| Compile | Offloading | Latency (median over 1 run) |
|---------|------------|-----------------------------|
| False   | False      | 4601.803 ms                 |
| True    | False      | 3766.335 ms                 |
| False   | True       | 4918.042 ms                 |
| True    | True       | 4109.500 ms                 |
On RTX 4090, we get:
| Compile | Offloading | Latency (median over 1 run) |
|---------|------------|-----------------------------|
| False   | True       | 13658.121 ms                |
| True    | True       | 11583.754 ms                |
Code:
```python
import argparse

import torch
from torch.utils import benchmark

from diffusers import AutoModel

torch.set_grad_enabled(False)
torch._dynamo.config.cache_size_limit = 10000


def get_input_dict(**device_dtype_kwargs):
    # height: 480
    # width: 832
    # num_frames: 81
    # max_sequence_length: 512
    hidden_states = torch.randn(1, 16, 21, 60, 104, **device_dtype_kwargs)
    encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)
    timestep = torch.tensor([1.0], **device_dtype_kwargs)
    return {"hidden_states": hidden_states, "encoder_hidden_states": encoder_hidden_states, "timestep": timestep}


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--go", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_parser()

    transformer = AutoModel.from_pretrained(
        "Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    if not args.go:
        transformer.cuda()
    else:
        group_offload_kwargs = {
            "onload_device": "cuda",
            "offload_device": "cpu",
            "offload_type": "block_level",
            "num_blocks_per_group": 1,
            "use_stream": True,
            "non_blocking": True,
        }
        transformer.enable_group_offload(**group_offload_kwargs)
    if args.compile:
        transformer.compile()

    input_kwargs = {"dtype": torch.bfloat16, "device": "cuda"} if not args.go else {"dtype": torch.bfloat16}
    input_dict = get_input_dict(**input_kwargs)

    for _ in range(4):
        _ = transformer(**input_dict)

    latency_timer = benchmark.Timer(
        stmt="transformer(**input_dict)",
        setup="from __main__ import transformer, input_dict",
        num_threads=1,
        label="Go+Compilation inference latency",
    )
    latency_result = latency_timer.blocked_autorange(min_run_time=1)
    latency_ms = latency_result.median * 1e3

    print(f"Compile: {args.compile} offloading: {args.go}")
    print(f"Latency: {latency_ms:.3f} ms (median over {len(latency_result.times)} runs)")
```

As one would expect, using streams for overlapping compute with communication yields the best trade-off. Using `record_stream=True` gives additional speedups.