[Feat] TaylorSeer Cache by toilaluan · Pull Request #12648 · huggingface/diffusers

Comparison: Baseline, Baseline-22steps, FBCache, TaylorSeer Cache

FLUX.1-dev

Memory & Speed Metrics (GPU: H100, 50 steps, compiled)

| Prompt Index | Variant | Load Time (s) | Load Memory (GB) | Compile Time (s) | Warmup Time (s) | Main Time (s) | Peak Memory (GB) |
|---|---|---|---|---|---|---|---|
| 0 | baseline | 5.812418 | 31.437537 | 1.162642 | 3.940589 | 11.744348 | 33.851628 |
| 0 | baseline(steps=22) | 5.445316 | 31.469763 | 0.054121 | 2.662160 | 5.271367 | 33.851628 |
| 0 | firstblock(threshold=0.05) | 5.618118 | 31.469763 | 0.053683 | 30.769011 | 8.686095 | 33.928777 |
| 0 | taylorseer(max_order=1, cache_interval=5, disable_cache_before_step=10) | 5.487217 | 31.469763 | 0.051885 | 59.841501 | 4.865684 | 33.852117 |
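
For quick reference, a minimal sketch of how the cached variant above is enabled. It mirrors the API exercised in the reproduction scripts below (`TaylorSeerCacheConfig` plus `transformer.enable_cache`); treat it as a sketch, not canonical usage:

```python
import torch
from diffusers import FluxPipeline, TaylorSeerCacheConfig

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Matches the taylorseer(max_order=1, cache_interval=5, disable_cache_before_step=10)
# row; dtype and lite mode are taken from the benchmark configuration below.
config = TaylorSeerCacheConfig(
    max_order=1,
    cache_interval=5,
    disable_cache_before_step=10,
    taylor_factors_dtype=torch.bfloat16,
    use_lite_mode=True,
)
pipeline.transformer.enable_cache(config)

image = pipeline(
    "Black cat hiding behind a watermelon slice",
    num_inference_steps=50,
    guidance_scale=3.0,
).images[0]
```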

Visual Outputs

*(Image grids: Baseline | Baseline-22steps | FBCache | TaylorSeer Cache, one row per example prompt — see the PR page for the rendered outputs.)*

Analysis of TaylorSeer Configurations

Memory & Speed (no compile)

| Prompt Index | Variant | Steps | Max Order | Cache Interval | Load Time (s) | Load Memory (GB) | Compile Time (s) | Warmup Time (s) | Main Time (s) | Peak Memory (GB) | Speedup vs baseline_50 | Speedup vs baseline_22 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | baseline_50 | 50 | N/A | N/A | 5.68 | 31.44 | 0.00 | 2.38 | 15.32 | 33.85 | 1.00x | 1.00x |
| 0 | baseline_22 | 22 | N/A | N/A | 5.39 | 31.47 | 0.00 | 1.70 | 6.86 | 33.85 | 2.23x | 1.00x |
| 0 | taylor_o0_ci5 | 50 | 0 | 5 | 5.37 | 31.47 | 0.00 | 1.71 | 5.70 | 33.85 | 2.69x | 1.20x |
| 0 | taylor_o0_ci10 | 50 | 0 | 10 | 5.38 | 31.47 | 0.00 | 1.71 | 4.64 | 33.85 | 3.30x | 1.48x |
| 0 | taylor_o0_ci15 | 50 | 0 | 15 | 5.44 | 31.47 | 0.00 | 1.71 | 4.40 | 33.85 | 3.48x | 1.56x |
| 0 | taylor_o1_ci5 | 50 | 1 | 5 | 5.67 | 31.47 | 0.00 | 1.71 | 5.70 | 33.85 | 2.69x | 1.20x |
| 0 | taylor_o1_ci8 | 50 | 1 | 8 | 5.68 | 31.47 | 0.00 | 1.71 | 4.88 | 33.85 | 3.14x | 1.41x |
| 0 | taylor_o1_ci10 | 50 | 1 | 10 | 5.68 | 31.47 | 0.00 | 1.71 | 4.64 | 33.85 | 3.30x | 1.48x |
| 0 | taylor_o1_ci15 | 50 | 1 | 15 | 5.68 | 31.47 | 0.00 | 1.71 | 4.40 | 33.85 | 3.48x | 1.56x |
| 0 | taylor_o2_ci5 | 50 | 2 | 5 | 5.66 | 31.47 | 0.00 | 1.71 | 5.69 | 33.85 | 2.69x | 1.21x |
| 0 | taylor_o2_ci10 | 50 | 2 | 10 | 5.73 | 31.47 | 0.00 | 1.70 | 4.64 | 33.85 | 3.30x | 1.48x |
| 0 | taylor_o2_ci15 | 50 | 2 | 15 | 5.36 | 31.47 | 0.00 | 1.70 | 4.40 | 33.85 | 3.48x | 1.56x |
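
The measured speedups track a simple step-count estimate. Assuming (from the parameter names; the exact schedule lives in the PR's implementation) that a full forward pass runs for the first `disable_cache_before_step` steps and then once every `cache_interval` steps, with Taylor-extrapolated outputs in between, a back-of-envelope sketch with a hypothetical helper:

```python
def estimated_speedup(total_steps: int, cache_interval: int,
                      disable_before: int = 10) -> float:
    """Rough speedup estimate: total steps over full forward passes.

    Assumes a full pass for the first `disable_before` steps, then one full
    pass every `cache_interval` steps; cached (extrapolated) steps are
    treated as free. An approximation, not the PR's exact schedule.
    """
    full_passes = disable_before + len(range(disable_before, total_steps, cache_interval))
    return total_steps / full_passes

# cache_interval=5  -> ~2.8x (measured: 2.69x)
# cache_interval=10 -> ~3.6x (measured: 3.30x)
# cache_interval=15 -> ~3.8x (measured: 3.48x)
for ci in (5, 10, 15):
    print(ci, round(estimated_speedup(50, ci), 2))
```

The estimate slightly overshoots the measured numbers because cached steps are cheap but not free.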

Visual Comparison (`o1_ci5` means `max_order=1, cache_interval=5`)

*(Image: side-by-side grid of outputs per configuration — see the PR page.)*

Reproduction Code

  1. Baselines vs. TaylorSeer variants

```python
import torch
from diffusers import FluxPipeline, TaylorSeerCacheConfig
import time
import os
import matplotlib.pyplot as plt
import pandas as pd
import gc

# Set dynamo config
import torch._dynamo as dynamo
dynamo.config.recompile_limit = 200

prompts = [
    "Black cat hiding behind a watermelon slice, professional studio shot, bright red and turquoise background with summer mystery vibe",
]

# Create output folder
os.makedirs("outputs", exist_ok=True)

# ============================================================================
# CONFIGURATION SECTION - Easily modify these parameters
# ============================================================================

# Fixed config parameters (applied to all TaylorSeer configs)
FIXED_CONFIG = {
    'disable_cache_before_step': 10,  # no caching during the first 10 steps
    'taylor_factors_dtype': torch.bfloat16,
    'use_lite_mode': True,
}

# Variable parameters to test - modify these as needed
# Format: (max_order, cache_interval)
TAYLOR_CONFIGS = [
    (0, 5),
    (0, 10),
    (0, 15),
    (1, 5),
    (1, 8),
    (1, 10),
    (1, 15),
    (2, 5),
    (2, 10),
    (2, 15),
]

# Baseline configurations
BASELINES = [
    {'name': 'baseline_50', 'steps': 50},
    {'name': 'baseline_22', 'steps': 22},
]

# Main inference steps for TaylorSeer variants
MAIN_STEPS = 50
WARMUP_STEPS = 5

# ============================================================================

# Build TaylorSeer configs
taylor_configs = {}
for max_order, cache_interval in TAYLOR_CONFIGS:
    config_name = f'taylor_o{max_order}_ci{cache_interval}'
    taylor_configs[config_name] = TaylorSeerCacheConfig(
        max_order=max_order,
        cache_interval=cache_interval,
        **FIXED_CONFIG
    )

# Collect results
results = []

for i, prompt in enumerate(prompts):
    print(f"\n{'='*80}")
    print(f"Processing Prompt {i}: {prompt[:50]}...")
    print(f"{'='*80}\n")

    images = {}
    baseline_times = {}

    # Run all baseline variants first
    for baseline_config in BASELINES:
        variant = baseline_config['name']
        num_steps = baseline_config['steps']

        print(f"Running {variant} (steps={num_steps})...")

        # Clear cache before loading
        gc.collect()
        torch.cuda.empty_cache()

        # Load pipeline with timing
        start_load = time.time()
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        load_time = time.time() - start_load
        load_mem_gb = torch.cuda.memory_allocated() / (1024 ** 3)

        # Compile with timing
        start_compile = time.time()
        # pipeline.transformer.compile_repeated_blocks(fullgraph=False)
        compile_time = time.time() - start_compile

        # Warmup with 5 steps
        gen_warmup = torch.Generator(device="cuda").manual_seed(181201)
        start_warmup = time.time()
        _ = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=WARMUP_STEPS,
            guidance_scale=3.0,
            generator=gen_warmup
        ).images[0]
        warmup_time = time.time() - start_warmup

        # Main run
        gen_main = torch.Generator(device="cuda").manual_seed(181201)

        torch.cuda.reset_peak_memory_stats()
        start_main = time.time()
        image = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=num_steps,
            guidance_scale=3.0,
            generator=gen_main
        ).images[0]
        end_main = time.time()

        peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
        main_time = end_main - start_main

        # Save image
        image_path = f"outputs/{variant}_prompt{i}.jpg"
        image.save(image_path)
        images[variant] = image

        # Store baseline time
        baseline_times[variant] = main_time

        # Record results
        results.append({
            'Prompt Index': i,
            'Variant': variant,
            'Steps': num_steps,
            'Max Order': 'N/A',
            'Cache Interval': 'N/A',
            'Load Time (s)': f"{load_time:.2f}",
            'Load Memory (GB)': f"{load_mem_gb:.2f}",
            'Compile Time (s)': f"{compile_time:.2f}",
            'Warmup Time (s)': f"{warmup_time:.2f}",
            'Main Time (s)': f"{main_time:.2f}",
            'Peak Memory (GB)': f"{peak_mem_gb:.2f}",
            'Speedup vs baseline_50': '1.00x' if variant == 'baseline_50' else f"{baseline_times['baseline_50']/main_time:.2f}x",
            'Speedup vs baseline_22': '1.00x' if variant == 'baseline_22' else f"{baseline_times.get('baseline_22', main_time)/main_time:.2f}x"
        })

        print(f"  Load: {load_time:.2f}s, Compile: {compile_time:.2f}s, Warmup: {warmup_time:.2f}s")
        print(f"  Main: {main_time:.2f}s, Peak Memory: {peak_mem_gb:.2f} GB\n")

        # Clean up
        pipeline.to("cpu")
        del pipeline
        gc.collect()
        torch.cuda.empty_cache()
        dynamo.reset()

    # TaylorSeer variants with different configurations
    for config_name, tsconfig in taylor_configs.items():
        variant = config_name
        max_order = tsconfig.max_order
        cache_interval = tsconfig.cache_interval
        print(f"Running {variant} (max_order={max_order}, cache_interval={cache_interval})...")

        # Clear cache before loading
        gc.collect()
        torch.cuda.empty_cache()

        # Load pipeline with timing
        start_load = time.time()
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        load_time = time.time() - start_load
        load_mem_gb = torch.cuda.memory_allocated() / (1024 ** 3)

        # Enable TaylorSeer cache
        pipeline.transformer.enable_cache(tsconfig)

        # Compile with timing
        start_compile = time.time()
        # pipeline.transformer.compile_repeated_blocks(fullgraph=False)
        compile_time = time.time() - start_compile

        # Warmup with 5 steps
        gen_warmup = torch.Generator(device="cuda").manual_seed(181201)
        start_warmup = time.time()
        _ = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=WARMUP_STEPS,
            guidance_scale=3.0,
            generator=gen_warmup
        ).images[0]
        warmup_time = time.time() - start_warmup

        # Main run
        gen_main = torch.Generator(device="cuda").manual_seed(181201)

        torch.cuda.reset_peak_memory_stats()
        start_main = time.time()
        image = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=MAIN_STEPS,
            guidance_scale=3.0,
            generator=gen_main
        ).images[0]
        end_main = time.time()

        peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
        main_time = end_main - start_main
        speedup_50 = baseline_times['baseline_50'] / main_time
        speedup_22 = baseline_times['baseline_22'] / main_time

        # Save image
        image_path = f"outputs/{variant}_prompt{i}.jpg"
        image.save(image_path)
        images[variant] = image

        # Record results
        results.append({
            'Prompt Index': i,
            'Variant': variant,
            'Steps': MAIN_STEPS,
            'Max Order': max_order,
            'Cache Interval': cache_interval,
            'Load Time (s)': f"{load_time:.2f}",
            'Load Memory (GB)': f"{load_mem_gb:.2f}",
            'Compile Time (s)': f"{compile_time:.2f}",
            'Warmup Time (s)': f"{warmup_time:.2f}",
            'Main Time (s)': f"{main_time:.2f}",
            'Peak Memory (GB)': f"{peak_mem_gb:.2f}",
            'Speedup vs baseline_50': f"{speedup_50:.2f}x",
            'Speedup vs baseline_22': f"{speedup_22:.2f}x"
        })

        print(f"  Load: {load_time:.2f}s, Compile: {compile_time:.2f}s, Warmup: {warmup_time:.2f}s")
        print(f"  Main: {main_time:.2f}s, Peak Memory: {peak_mem_gb:.2f} GB")
        print(f"  Speedup vs baseline_50: {speedup_50:.2f}x, vs baseline_22: {speedup_22:.2f}x\n")

        # Clean up
        pipeline.to("cpu")
        del pipeline
        gc.collect()
        torch.cuda.empty_cache()
        dynamo.reset()

    # Plot image comparison for this prompt (select key variants)
    key_variants = ['baseline_50', 'baseline_22'] + [list(taylor_configs.keys())[j] for j in range(min(4, len(taylor_configs)))]
    num_variants = len(key_variants)

    fig, axs = plt.subplots(1, num_variants, figsize=(10*num_variants, 10))
    if num_variants == 1:
        axs = [axs]

    for j, var in enumerate(key_variants):
        if var in images:
            axs[j].imshow(images[var])
            axs[j].set_title(f"{var}", fontsize=24)
            axs[j].axis('off')

    plt.tight_layout()
    plt.savefig(f"outputs/comparison_prompt{i}.png", dpi=100)
    plt.close()

# Print results table
print("\n" + "="*140)
print("BENCHMARK RESULTS")
print("="*140 + "\n")

df = pd.DataFrame(results)
print(df.to_string(index=False))

# Save results to CSV
df.to_csv("outputs/benchmark_results.csv", index=False)
print("\nResults saved to outputs/benchmark_results.csv")

# Calculate and display averages per variant
print("\n" + "="*140)
print("AVERAGE METRICS BY VARIANT")
print("="*140 + "\n")

# Convert numeric columns back to float for averaging
numeric_cols = ['Load Time (s)', 'Load Memory (GB)', 'Compile Time (s)',
                'Warmup Time (s)', 'Main Time (s)', 'Peak Memory (GB)']

df_numeric = df.copy()
for col in numeric_cols:
    df_numeric[col] = df_numeric[col].astype(float)

# Group by variant and calculate means
avg_df = df_numeric.groupby('Variant')[numeric_cols + ['Steps']].mean()

# Add configuration info
avg_df['Max Order'] = df.groupby('Variant')['Max Order'].first()
avg_df['Cache Interval'] = df.groupby('Variant')['Cache Interval'].first()

# Calculate average speedups
speedup_50_df = df.groupby('Variant')['Speedup vs baseline_50'].apply(
    lambda x: f"{sum(float(v.rstrip('x')) for v in x) / len(x):.2f}x"
)
speedup_22_df = df.groupby('Variant')['Speedup vs baseline_22'].apply(
    lambda x: f"{sum(float(v.rstrip('x')) for v in x) / len(x):.2f}x"
)
avg_df['Avg Speedup vs baseline_50'] = speedup_50_df
avg_df['Avg Speedup vs baseline_22'] = speedup_22_df

# Reorder columns
avg_df = avg_df[['Steps', 'Max Order', 'Cache Interval'] + numeric_cols
                + ['Avg Speedup vs baseline_50', 'Avg Speedup vs baseline_22']]

# Format numeric columns
avg_df['Steps'] = avg_df['Steps'].apply(lambda x: f"{x:.0f}")
for col in numeric_cols:
    avg_df[col] = avg_df[col].apply(lambda x: f"{x:.2f}")

print(avg_df.to_string())

# Create comprehensive visualizations
fig, axes = plt.subplots(2, 2, figsize=(20, 16))

# Extract data for plotting
variants = []
main_times = []
peak_memories = []
speedups_50 = []
speedups_22 = []
labels = []

for variant in df['Variant'].unique():
    variant_data = df_numeric[df_numeric['Variant'] == variant]
    variants.append(variant)
    main_times.append(variant_data['Main Time (s)'].mean())
    peak_memories.append(variant_data['Peak Memory (GB)'].mean())

    # Calculate average speedups
    speedup_50_values = df[df['Variant'] == variant]['Speedup vs baseline_50'].apply(
        lambda x: float(x.rstrip('x'))
    )
    speedup_22_values = df[df['Variant'] == variant]['Speedup vs baseline_22'].apply(
        lambda x: float(x.rstrip('x'))
    )
    speedups_50.append(speedup_50_values.mean())
    speedups_22.append(speedup_22_values.mean())

    # Create readable labels
    if 'baseline' in variant:
        labels.append(variant)
    else:
        parts = variant.split('_')
        order = parts[1].replace('o', 'O')
        ci = parts[2].replace('ci', 'CI')
        labels.append(f"{order}_{ci}")

# Assign colors
colors = ['#1f77b4', '#ff7f0e'] + ['#2ca02c', '#d62728', '#9467bd', '#8c564b',
                                   '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'] * 3
colors = colors[:len(variants)]

# Plot 1: Main Time Comparison
ax1 = axes[0, 0]
bars1 = ax1.bar(labels, main_times, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
ax1.set_ylabel('Main Generation Time (seconds)', fontsize=12, fontweight='bold')
ax1.set_xlabel('Configuration', fontsize=12, fontweight='bold')
ax1.set_title('Average Generation Time Comparison', fontsize=14, fontweight='bold')
ax1.grid(axis='y', alpha=0.3, linestyle='--')
ax1.tick_params(axis='x', rotation=45)

# Add value labels on bars (time_val avoids shadowing the `time` module)
for bar, time_val in zip(bars1, main_times):
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height, f'{time_val:.2f}s',
             ha='center', va='bottom', fontsize=9, fontweight='bold')

# Plot 2: Peak Memory Comparison
ax2 = axes[0, 1]
bars2 = ax2.bar(labels, peak_memories, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
ax2.set_ylabel('Peak Memory Usage (GB)', fontsize=12, fontweight='bold')
ax2.set_xlabel('Configuration', fontsize=12, fontweight='bold')
ax2.set_title('Average Peak Memory Comparison', fontsize=14, fontweight='bold')
ax2.grid(axis='y', alpha=0.3, linestyle='--')
ax2.tick_params(axis='x', rotation=45)

# Add value labels on bars
for bar, mem in zip(bars2, peak_memories):
    height = bar.get_height()
    ax2.text(bar.get_x() + bar.get_width()/2., height, f'{mem:.2f}',
             ha='center', va='bottom', fontsize=9, fontweight='bold')

# Plot 3: Speedup vs baseline_50
ax3 = axes[1, 0]
bars3 = ax3.bar(labels, speedups_50, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
ax3.axhline(y=1.0, color='red', linestyle='--', linewidth=2, label='Baseline (50 steps)')
ax3.set_ylabel('Speedup Factor', fontsize=12, fontweight='bold')
ax3.set_xlabel('Configuration', fontsize=12, fontweight='bold')
ax3.set_title('Speedup vs Baseline (50 steps)', fontsize=14, fontweight='bold')
ax3.grid(axis='y', alpha=0.3, linestyle='--')
ax3.tick_params(axis='x', rotation=45)
ax3.legend()

# Add value labels on bars
for bar, speedup in zip(bars3, speedups_50):
    height = bar.get_height()
    ax3.text(bar.get_x() + bar.get_width()/2., height, f'{speedup:.2f}x',
             ha='center', va='bottom', fontsize=9, fontweight='bold')

# Plot 4: Speedup vs baseline_22
ax4 = axes[1, 1]
bars4 = ax4.bar(labels, speedups_22, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
ax4.axhline(y=1.0, color='orange', linestyle='--', linewidth=2, label='Baseline (22 steps)')
ax4.set_ylabel('Speedup Factor', fontsize=12, fontweight='bold')
ax4.set_xlabel('Configuration', fontsize=12, fontweight='bold')
ax4.set_title('Speedup vs Baseline (22 steps)', fontsize=14, fontweight='bold')
ax4.grid(axis='y', alpha=0.3, linestyle='--')
ax4.tick_params(axis='x', rotation=45)
ax4.legend()

# Add value labels on bars
for bar, speedup in zip(bars4, speedups_22):
    height = bar.get_height()
    ax4.text(bar.get_x() + bar.get_width()/2., height, f'{speedup:.2f}x',
             ha='center', va='bottom', fontsize=9, fontweight='bold')

plt.tight_layout()
plt.savefig("outputs/metrics_comparison.png", dpi=150, bbox_inches='tight')
plt.close()

print("\n" + "="*140)
print("Benchmark completed! Check the outputs/ folder for results and visualizations.")
print("="*140)
```

  2. Baseline, TaylorSeer, FirstBlockCache

```python
import torch
from diffusers import FluxPipeline, TaylorSeerCacheConfig, FirstBlockCacheConfig, FasterCacheConfig
import time
import os
import matplotlib.pyplot as plt
import pandas as pd
import gc  # for explicit garbage collection

# Set dynamo config
import torch._dynamo as dynamo
dynamo.config.recompile_limit = 200

prompts = [
    "Soaking wet tiger cub taking shelter under a banana leaf in the rainy jungle, close up photo",
]

# Create output folder
os.makedirs("outputs", exist_ok=True)

# Define cache configs
fbconfig = FirstBlockCacheConfig(
    threshold=0.05
)

tsconfig = TaylorSeerCacheConfig(
    cache_interval=5,
    max_order=1,
    disable_cache_before_step=10,
    disable_cache_after_step=48,
    taylor_factors_dtype=torch.bfloat16,
    use_lite_mode=True
)

# Collect results
results = []

for i, prompt in enumerate(prompts):
    images = {}
    for variant in ['baseline', 'baseline_reduce', 'firstblock', 'taylor']:
        # Clear cache before loading
        gc.collect()
        torch.cuda.empty_cache()

        # Load pipeline with timing
        start_load = time.time()
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        load_time = time.time() - start_load
        load_mem_gb = torch.cuda.memory_allocated() / (1024 ** 3)  # GB

        # Enable cache if applicable
        if variant == 'firstblock':
            pipeline.transformer.enable_cache(fbconfig)
        elif variant == 'taylor':
            pipeline.transformer.enable_cache(tsconfig)
        # No cache for baseline and baseline_reduce

        # Compile with timing
        start_compile = time.time()
        pipeline.transformer.compile_repeated_blocks(fullgraph=False)
        compile_time = time.time() - start_compile

        # Warmup run (5 steps)
        gen_warmup = torch.Generator(device="cuda").manual_seed(181201)
        start_warmup = time.time()
        _ = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=5,
            guidance_scale=3.0,
            generator=gen_warmup
        ).images[0]
        warmup_time = time.time() - start_warmup

        # Main run
        steps = 22 if variant == 'baseline_reduce' else 50

        gen_main = torch.Generator(device="cuda").manual_seed(181201)

        torch.cuda.reset_peak_memory_stats()
        start_main = time.time()
        image = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=steps,
            guidance_scale=3.0,
            generator=gen_main
        ).images[0]
        end_main = time.time()

        peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)  # GB
        main_time = end_main - start_main

        # Save image
        image_path = f"outputs/{variant}_prompt{i}.jpg"
        image.save(image_path)
        images[variant] = image

        # Record results
        results.append({
            'Prompt Index': i,
            'Variant': variant,
            'Load Time (s)': load_time,
            'Load Memory (GB)': load_mem_gb,
            'Compile Time (s)': compile_time,
            'Warmup Time (s)': warmup_time,
            'Main Time (s)': main_time,
            'Peak Memory (GB)': peak_mem_gb
        })

        # Clean up
        pipeline.to("cpu")
        del pipeline
        gc.collect()  # Force garbage collection
        torch.cuda.empty_cache()  # Empty CUDA cache after GC
        dynamo.reset()  # Reset Dynamo cache (harmless even if not compiling)

    # Plot image comparison for this prompt
    fig, axs = plt.subplots(1, 4, figsize=(40, 10))
    variants_order = ['baseline', 'baseline_reduce', 'firstblock', 'taylor']
    for j, var in enumerate(variants_order):
        axs[j].imshow(images[var])
        axs[j].set_title(var)
        axs[j].axis('off')
    plt.tight_layout()
    plt.savefig(f"outputs/comparison_prompt{i}.png")
    plt.close()

# Print speed and memory comparison as a table
df = pd.DataFrame(results)
print("Speed and Memory Comparison:")
print(df.to_string(index=False))

# Optionally, plot bar charts for averages
avg_df = df.groupby('Variant').mean().reset_index()
fig, ax1 = plt.subplots(figsize=(10, 6))
ax1.bar(avg_df['Variant'], avg_df['Main Time (s)'], color='b', label='Main Time (s)')
ax1.set_ylabel('Main Time (s)')
ax2 = ax1.twinx()
ax2.plot(avg_df['Variant'], avg_df['Peak Memory (GB)'], color='r', marker='o', label='Peak Memory (GB)')
ax2.set_ylabel('Peak Memory (GB)')
fig.suptitle('Average Speed and Memory Comparison')
fig.legend()
plt.savefig("outputs/metrics_comparison.png")
plt.close()
```