benchmark_attention_clip.py
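
# Benchmarks Hugging Face Transformers CLIP models with different attention implementations
# (eager vs. SDPA vs. FlashAttention-2) across image/text batch sizes and prints the results
# as markdown tables of median seconds per iteration.
#
# Example invocation (a sketch; the flags and defaults mirror the argparse block below,
# the values shown are illustrative):
#   python benchmark_attention_clip.py --checkpoint openai/clip-vit-large-patch14 \
#       --device cuda --dtype float16 --n_iterations 100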

import itertools
import sys
import time

import numpy as np
import pandas as pd
import requests
import torch
import transformers
from PIL import Image
from tqdm import tqdm

print("\n## Environment:\n")

print("Python version:", sys.version)

print("Transformers version:", transformers.__version__)

print("Torch version:", torch.__version__)

if torch.cuda.is_available():

print("GPU:", torch.cuda.get_device_name(0))

@torch.no_grad()
def get_model_iteration_time(model, inputs, device, min_iterations=100, min_benchmark_time=4, warm_up_steps=10):
    """Time `model(**inputs)` under autocast; return the median seconds per iteration
    and an approximate 95% confidence-interval half-width."""
    with torch.autocast(device):
        # warm up so kernel compilation and caching do not pollute the measurements
        for _ in range(warm_up_steps):
            model(**inputs)

        timings = []
        iterations = 0
        benchmark_time = 0

        torch.cuda.synchronize()
        while benchmark_time < min_benchmark_time or iterations < min_iterations:
            for _ in range(10):
                start_time = time.time()
                _ = model(**inputs)
                torch.cuda.synchronize()
                end_time = time.time()
                elapsed_time = end_time - start_time

                # store the time
                timings.append(elapsed_time)

                # update the benchmark time and iterations
                benchmark_time += elapsed_time
                iterations += 1

    median_time = np.median(timings)
    ci = 1.96 * np.array(timings).std() / np.sqrt(len(timings))
    return median_time, ci

def prepare_inputs(processor, image_batch_size=None, text_batch_size=None):
    # loading image
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    images = [image] * image_batch_size if image_batch_size is not None else None
    texts = ["a photo of 2 cats"] * text_batch_size if text_batch_size is not None else None

    inputs = processor(text=texts, images=images, padding="max_length", return_tensors="pt")
    return inputs
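
# Note: `prepare_inputs` returns a transformers batch (BatchEncoding/BatchFeature); for CLIP it
# typically holds `input_ids`/`attention_mask` when texts are given and `pixel_values` when
# images are given. A quick shape check (illustrative only, not part of the benchmark):
#   inputs = prepare_inputs(processor, image_batch_size=4, text_batch_size=8)
#   print({k: tuple(v.shape) for k, v in inputs.items()})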

def format_df(df) -> str:
    # format numeric columns: CI columns as ±x.x%, batch sizes as plain integers,
    # timings/speedups with three decimals
    for column in df.columns:
        if "CI" in column and "%" in column:
            df[column] = df[column].apply(lambda x: f"±{x:.1f}%")
        elif "batch_size" in column:
            df[column] = df[column].apply(lambda x: f"{int(x)}")
        else:
            df[column] = df[column].apply(lambda x: f"{x:.3f}")

    # rename columns
    columns_mapping = {
        "image_batch_size": "Image batch size",
        "text_batch_size": "Num text labels",
        "Eager": "Eager (s/iter)",
        "FA2": "FA2 (s/iter)",
        "SDPA": "SDPA (s/iter)",
    }
    for column_name, new_column_name in columns_mapping.items():
        if column_name in df.columns:
            df = df.rename(columns={column_name: new_column_name})

    # format as markdown table
    markdown = df.to_markdown(index=False)
    return markdown

def benchmark(models_dict, processor, device, image_batch_sizes=None, text_batch_sizes=None, n_iterations=100):
    image_batch_sizes = image_batch_sizes or [None]
    text_batch_sizes = text_batch_sizes or [None]
    cases = list(itertools.product(image_batch_sizes, text_batch_sizes))

    results = []
    for image_batch_size, text_batch_size in tqdm(cases):
        inputs = prepare_inputs(processor, image_batch_size, text_batch_size).to(device)

        step_results = {}
        if image_batch_size is not None:
            step_results["image_batch_size"] = image_batch_size
        if text_batch_size is not None:
            step_results["text_batch_size"] = text_batch_size

        for attn_name, model in models_dict.items():
            median_time, confidence_interval = get_model_iteration_time(
                model, inputs, device, min_iterations=n_iterations, min_benchmark_time=4
            )
            step_results[f"{attn_name}"] = median_time
            confidence_interval_percent = (confidence_interval / median_time) * 100
            step_results[f"{attn_name} CI, %"] = confidence_interval_percent

            # "Eager" is always benchmarked first, so its timing serves as the speedup baseline
            if attn_name != "Eager":
                step_results[f"{attn_name} speedup"] = step_results["Eager"] / median_time

        results.append(step_results)

    df = pd.DataFrame(results)
    markdown = format_df(df)
    return markdown

def load_models(model_class, checkpoint, dtype, device):
    models_dict = {
        "Eager": model_class.from_pretrained(checkpoint, attn_implementation="eager", torch_dtype=dtype, device_map=device).eval()
    }
    # the private _supports_* flags indicate whether the class implements that attention backend;
    # flash_attention_2 additionally requires the flash-attn package and a half-precision dtype
    if model_class._supports_flash_attn_2:
        models_dict["FA2"] = model_class.from_pretrained(checkpoint, attn_implementation="flash_attention_2", torch_dtype=dtype, device_map=device).eval()
    if model_class._supports_sdpa:
        models_dict["SDPA"] = model_class.from_pretrained(checkpoint, attn_implementation="sdpa", torch_dtype=dtype, device_map=device).eval()
    return models_dict

if __name__ == "__main__":
    import argparse

    from transformers import AutoProcessor, CLIPModel, CLIPTextModel, CLIPVisionModel

    parser = argparse.ArgumentParser()
    parser.add_argument("--n_iterations", type=int, default=100)
    parser.add_argument("--checkpoint", type=str, default="openai/clip-vit-large-patch14")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--dtype", type=str, default="float16")
    args = parser.parse_args()

    benchmark_multimodal = True
    benchmark_text = True
    benchmark_vision = True

    dtype = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }[args.dtype]

    processor = AutoProcessor.from_pretrained(args.checkpoint)

    print("\n## Benchmark results\n")

    # ---------------------------
    # Multi-modal model
    # ---------------------------
    if benchmark_multimodal:
        models_dict = load_models(CLIPModel, args.checkpoint, dtype, args.device)
        result = benchmark(
            models_dict,
            processor,
            image_batch_sizes=[1, 4, 16, 32],
            text_batch_sizes=[4, 16, 32, 64],
            device=args.device,
            n_iterations=args.n_iterations,
        )
        print(f"\n### {CLIPModel.__name__}\n")
        print(result)
        print()

    # ---------------------------
    # Text model
    # ---------------------------
    if benchmark_text:
        models_dict = load_models(CLIPTextModel, args.checkpoint, dtype, args.device)
        result = benchmark(
            models_dict,
            processor,
            text_batch_sizes=[4, 16, 32, 64, 128],
            device=args.device,
            n_iterations=args.n_iterations,
        )
        print(f"\n### {CLIPTextModel.__name__}\n")
        print(result)
        print()

    # ---------------------------
    # Vision model
    # ---------------------------
    if benchmark_vision:
        models_dict = load_models(CLIPVisionModel, args.checkpoint, dtype, args.device)
        result = benchmark(
            models_dict,
            processor,
            image_batch_sizes=[1, 4, 16, 32],
            device=args.device,
            n_iterations=args.n_iterations,
        )
        print(f"\n### {CLIPVisionModel.__name__}\n")
        print(result)
        print()
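
# Each benchmark call prints one markdown table per model class. When both FA2 and SDPA are
# available, the columns after `format_df` look like the following (header only; no measured
# numbers are shown here):
#   | Image batch size | Num text labels | Eager (s/iter) | Eager CI, % | FA2 (s/iter) | FA2 CI, % | FA2 speedup | SDPA (s/iter) | SDPA CI, % | SDPA speedup |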