Add AITER attention backend by lauri9 · Pull Request #12549 · huggingface/diffusers
What does this PR do?
AITER is AMD's centralized repository for high-performance AI operators, such as attention kernels, for ROCm-enabled accelerators. This PR adds support for FlashAttention through AITER by introducing a new attention backend.
Test code for Flux inference is below. It requires aiter>=0.15.0 and a supported ROCm-enabled accelerator.
```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

model_id = "black-forest-labs/FLUX.1-dev"

# Load the transformer and switch its attention implementation to the AITER backend.
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cuda"
)
transformer.set_attention_backend("aiter")

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)

# Move the remaining components to the accelerator (ROCm devices are exposed as "cuda" in PyTorch).
pipe.text_encoder.to("cuda")
pipe.text_encoder_2.to("cuda")
pipe.vae.to("cuda")

prompt = "A cat holding a sign that says 'hello world'"
image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
image.save("output.png")
```
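As an alternative to `set_attention_backend`, diffusers also exposes an `attention_backend` context manager (it was imported in an earlier revision of the snippet above) that switches the backend only for the duration of a block. A minimal sketch reusing the pipeline from the example above, useful e.g. for comparing AITER output against the default backend:

```python
from diffusers import attention_backend

# Run a single call with the AITER backend; outside the context the model
# falls back to whatever backend it was previously configured with.
with attention_backend("aiter"):
    image_aiter = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
image_aiter.save("output_aiter.png")
```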
As a follow-up to this PR, we are interested in eventually enabling AITER backend support for context parallelism across multiple devices, once that feature matures.
Before submitting
- This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- Did you read the contributor guideline?
- Did you read our philosophy doc (important for complex PRs)?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
cc: @sayakpaul @DN6 for review and any comments