# benchmark_attention_clip.py
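"""Benchmark CLIP attention implementations (eager vs. SDPA vs. FlashAttention-2).

Loads a CLIP checkpoint with each available attention implementation, times the
forward pass over several image/text batch sizes, and prints the results as
markdown tables (see the __main__ block at the bottom). A CUDA device is assumed.
"""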
import sys
import time
import requests
import itertools
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
import torch
import transformers
print("\n## Environment:\n")
print("Python version:", sys.version)
print("Transformers version:", transformers.__version__)
print("Torch version:", torch.__version__)
if torch.cuda.is_available():
print("GPU:", torch.cuda.get_device_name(0))


@torch.no_grad()
def get_model_iteration_time(model, inputs, device, min_iterations=100, min_benchmark_time=4, warm_up_steps=10):
    # time forward passes under autocast until both the minimum number of
    # iterations and the minimum total benchmark time are reached
    with torch.autocast(device):
        # warm-up passes are not timed
        for _ in range(warm_up_steps):
            model(**inputs)

        timings = []
        iterations = 0
        benchmark_time = 0

        torch.cuda.synchronize()
        while benchmark_time < min_benchmark_time or iterations < min_iterations:
            # time iterations in chunks of 10, synchronizing after each forward pass
            for _ in range(10):
                start_time = time.time()
                _ = model(**inputs)
                torch.cuda.synchronize()
                end_time = time.time()
                elapsed_time = end_time - start_time
                # store the time
                timings.append(elapsed_time)
                # update the benchmark time and iterations
                benchmark_time += elapsed_time
                iterations += 1

    # median per-iteration time and a ~95% confidence interval (1.96 x standard error)
    median_time = np.median(timings)
    ci = 1.96 * np.array(timings).std() / np.sqrt(len(timings))
    return median_time, ci
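
# A standalone call might look like (hypothetical `model` and `inputs`, prepared
# as in benchmark() below):
#   median, ci = get_model_iteration_time(model, inputs, device="cuda")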


def prepare_inputs(processor, image_batch_size=None, text_batch_size=None):
    # loading image
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    # replicate the image/text to the requested batch sizes; padding="max_length"
    # keeps the text sequence length constant across batch sizes
    images = [image] * image_batch_size if image_batch_size is not None else None
    texts = ["a photo of 2 cats"] * text_batch_size if text_batch_size is not None else None
    inputs = processor(text=texts, images=images, padding="max_length", return_tensors="pt")
    return inputs


def format_df(df) -> str:
    # format float numbers (keep integer batch-size columns as-is)
    for column in df.columns:
        if "batch_size" in column:
            continue
        if "CI" in column and "%" in column:
            df[column] = df[column].apply(lambda x: f"±{x:.1f}%")
        else:
            df[column] = df[column].apply(lambda x: f"{x:.3f}")

    # rename columns
    columns_mapping = {
        "image_batch_size": "Image batch size",
        "text_batch_size": "Num text labels",
        "Eager": "Eager (s/iter)",
        "FA2": "FA2 (s/iter)",
        "SDPA": "SDPA (s/iter)",
    }
    for column_name, new_column_name in columns_mapping.items():
        if column_name in df.columns:
            df = df.rename(columns={column_name: new_column_name})

    # format as markdown table
    markdown = df.to_markdown(index=False)
    return markdown


def benchmark(models_dict, processor, device, image_batch_sizes=None, text_batch_sizes=None, n_iterations=100):
    image_batch_sizes = image_batch_sizes or [None]
    text_batch_sizes = text_batch_sizes or [None]
    cases = list(itertools.product(image_batch_sizes, text_batch_sizes))

    results = []
    for image_batch_size, text_batch_size in tqdm(cases):
        inputs = prepare_inputs(processor, image_batch_size, text_batch_size).to(device)

        step_results = {}
        if image_batch_size is not None:
            step_results["image_batch_size"] = image_batch_size
        if text_batch_size is not None:
            step_results["text_batch_size"] = text_batch_size

        # "Eager" is measured first, so it is available as the speedup baseline
        for attn_name, model in models_dict.items():
            median_time, confidence_interval = get_model_iteration_time(
                model, inputs, device, min_iterations=n_iterations, min_benchmark_time=4
            )
            step_results[f"{attn_name}"] = median_time
            confidence_interval_percent = (confidence_interval / median_time) * 100
            step_results[f"{attn_name} CI, %"] = confidence_interval_percent
            if attn_name != "Eager":
                step_results[f"{attn_name} speedup"] = step_results["Eager"] / median_time
        results.append(step_results)

    df = pd.DataFrame(results)
    markdown = format_df(df)
    return markdown


def load_models(model_class, checkpoint, dtype, device):
    # load the same checkpoint once per attention implementation the class supports
    models_dict = {
        "Eager": model_class.from_pretrained(checkpoint, attn_implementation="eager", torch_dtype=dtype, device_map=device).eval()
    }
    if model_class._supports_flash_attn_2:
        models_dict["FA2"] = model_class.from_pretrained(checkpoint, attn_implementation="flash_attention_2", torch_dtype=dtype, device_map=device).eval()
    if model_class._supports_sdpa:
        models_dict["SDPA"] = model_class.from_pretrained(checkpoint, attn_implementation="sdpa", torch_dtype=dtype, device_map=device).eval()
    return models_dict
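
# Note (environment assumption, not enforced by this script): "flash_attention_2"
# requires the flash-attn package, a CUDA device, and a half-precision dtype
# (float16/bfloat16); "sdpa" uses PyTorch's built-in scaled_dot_product_attention.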


if __name__ == "__main__":
    import argparse

    from transformers import AutoProcessor, CLIPModel, CLIPTextModel, CLIPVisionModel

    parser = argparse.ArgumentParser()
    parser.add_argument("--n_iterations", type=int, default=100)
    parser.add_argument("--checkpoint", type=str, default="openai/clip-vit-large-patch14")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--dtype", type=str, default="float16")
    args = parser.parse_args()
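
    # Example invocation (assumed usage, mirrors the defaults above):
    #   python benchmark_attention_clip.py --checkpoint openai/clip-vit-large-patch14 --dtype float16 --device cuda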

    benchmark_multimodal = True
    benchmark_text = True
    benchmark_vision = True

    dtype = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }[args.dtype]

    processor = AutoProcessor.from_pretrained(args.checkpoint)

    print("\n## Benchmark results\n")

    # ---------------------------
    # Multi-modal model
    # ---------------------------
    if benchmark_multimodal:
        models_dict = load_models(CLIPModel, args.checkpoint, dtype, args.device)
        result = benchmark(
            models_dict,
            processor,
            image_batch_sizes=[1, 4, 16, 32],
            text_batch_sizes=[4, 16, 32, 64],
            device=args.device,
            n_iterations=args.n_iterations,
        )
        print(f"\n### {CLIPModel.__name__}\n")
        print(result)
        print()

    # ---------------------------
    # Text model
    # ---------------------------
    if benchmark_text:
        models_dict = load_models(CLIPTextModel, args.checkpoint, dtype, args.device)
        result = benchmark(
            models_dict,
            processor,
            text_batch_sizes=[4, 16, 32, 64, 128],
            device=args.device,
            n_iterations=args.n_iterations,
        )
        print(f"\n### {CLIPTextModel.__name__}\n")
        print(result)
        print()

    # ---------------------------
    # Vision model
    # ---------------------------
    if benchmark_vision:
        models_dict = load_models(CLIPVisionModel, args.checkpoint, dtype, args.device)
        result = benchmark(
            models_dict,
            processor,
            image_batch_sizes=[1, 4, 16, 32],
            device=args.device,
            n_iterations=args.n_iterations,
        )
        print(f"\n### {CLIPVisionModel.__name__}\n")
        print(result)
        print()