# test_cogvideox_torch_compile.py: benchmark CogVideoX-2b inference with and without torch.compile.

import argparse

import time

import torch

# Select the "high" float32 matmul precision mode process-wide, before any
# pipeline is loaded or compiled.
# NOTE(review): presumably chosen to trade a little float32 accuracy for speed;
# confirm it has any effect on this float16 pipeline.
torch.set_float32_matmul_precision("high")

from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler

from diffusers.utils import export_to_gif

def load_pipeline(use_compile: bool = False) -> "CogVideoXPipeline":
    """Build the CogVideoX-2b pipeline on CUDA, optionally compiling its transformer.

    The pipeline is loaded in float16, moved to the ``cuda`` device, switched to
    a ``CogVideoXDDIMScheduler`` with ``timestep_spacing="trailing"``, and has
    its progress bar disabled so it does not clutter benchmark output.

    Args:
        use_compile: When true, set the inductor tuning flags below, convert the
            transformer to ``channels_last`` memory format and wrap it with
            ``torch.compile(mode="max-autotune", fullgraph=True)``.

    Returns:
        The configured pipeline.
    """
    # Named ``model_id`` rather than ``id`` so the builtin is not shadowed.
    model_id = "THUDM/CogVideoX-2b"

    pipe = CogVideoXPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
    pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
    pipe.set_progress_bar_config(disable=True)

    if use_compile:
        # Imported under an alias on purpose: a plain ``import torch._inductor.config``
        # here would make ``torch`` a function-local name and break the uses above.
        from torch._inductor import config as inductor_config

        inductor_config.conv_1x1_as_mm = True
        inductor_config.coordinate_descent_tuning = True
        inductor_config.epilogue_fusion = False
        inductor_config.coordinate_descent_check_all_directions = True

        pipe.transformer.to(memory_format=torch.channels_last)
        # NOTE: pipe.vae.to(memory_format=torch.channels_last_3d) does not work.

        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
        # NOTE: compiling the VAE decoder with
        #   pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune")
        # does not work; it fails with torch._dynamo.exc.InternalTorchDynamoError:
        #   "Error: accessing tensor output of CUDAGraphs that has been overwritten
        #   by a subsequent run."
        # The stack trace points at autoencoder_kl_cogvideox.py, line 137, in forward:
        #   self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
        # The error message suggests cloning the tensor outside of torch.compile()
        # or calling torch.compiler.cudagraph_mark_step_begin() before each model
        # invocation.
        # TODO: fix in future

    return pipe

def run_benchmark(num_warmups: int = 2, num_repeats: int = 10, use_compile: bool = False) -> None:
    """Time CogVideoX generation and save the last generated video as a GIF.

    Runs ``num_warmups`` untimed generations (which also absorb compilation
    time when ``use_compile`` is set), then ``num_repeats`` timed generations,
    prints the average wall-clock time per generation and writes the last
    video to ``cogvideox_compile_test_<use_compile>.gif``.

    Args:
        num_warmups: Number of untimed warm-up generations.
        num_repeats: Number of timed generations; must be at least 1.
        use_compile: Forwarded to ``load_pipeline`` and used in the output
            file name.

    Raises:
        ValueError: If ``num_repeats`` is less than 1.
    """
    # Fail fast, before the expensive model load: zero repeats would otherwise
    # end in a ZeroDivisionError after all the warm-up work has been done.
    if num_repeats < 1:
        raise ValueError(f"num_repeats must be at least 1, got {num_repeats}")

    pipe = load_pipeline(use_compile)

    prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."

    def generate():
        # The seed is reset on every call so each run does the same work.
        return pipe(
            prompt=prompt,
            num_frames=48,
            num_inference_steps=50,
            guidance_scale=6,
            generator=torch.manual_seed(42),
        )

    for _ in range(num_warmups):
        generate()

    # CUDA work is launched asynchronously; synchronise on both sides of the
    # timed region so the measurement covers exactly the timed generations.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_repeats):
        video = generate().frames[0]
    torch.cuda.synchronize()
    end = time.perf_counter()

    avg_inference_time = (end - start) / num_repeats
    print(f"Average inference time: {avg_inference_time:.3f} seconds.")

    # Use the parameter, not the script-level ``args`` global, so the function
    # also works when imported and called from other code.
    export_to_gif(video, f"cogvideox_compile_test_{use_compile}.gif")

if __name__ == "__main__":
    # Command-line entry point: read the benchmark settings and run once.
    cli = argparse.ArgumentParser()
    cli.add_argument("--num_warmups", default=2, type=int)
    cli.add_argument("--num_repeats", default=10, type=int)
    # A store_true flag already defaults to False when it is absent.
    cli.add_argument("--use_compile", action="store_true")

    # `args` is deliberately bound at module scope under this exact name;
    # code elsewhere in this file may read it as a global.
    args = cli.parse_args()

    run_benchmark(
        num_warmups=args.num_warmups,
        num_repeats=args.num_repeats,
        use_compile=args.use_compile,
    )