
# Flash Bi-directional Linear Attention

This repository implements bi-directional (non-causal) linear attention in Triton.
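Without a causal mask, linear attention can be computed by contracting keys with values first, bringing the cost down from O(T²·d) to O(T·d²) in the sequence length T. Below is a minimal PyTorch sketch of that idea (the ELU+1 feature map follows Katharopoulos et al.; this is illustrative, not this repo's Triton kernel):

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v, eps=1e-6):
    """Non-causal (bi-directional) linear attention, O(T * d^2).

    q, k, v: (batch, heads, seq_len, head_dim).
    """
    # Non-negative feature map (ELU + 1), as in Katharopoulos et al.
    q = F.elu(q) + 1.0
    k = F.elu(k) + 1.0
    # Contract keys with values first: sum_n k_n v_n^T -> (B, H, D, D)
    kv = torch.einsum("bhnd,bhne->bhde", k, v)
    # Per-query normalizer: q_n . (sum_n k_n)
    z = 1.0 / (torch.einsum("bhnd,bhd->bhn", q, k.sum(dim=2)) + eps)
    return torch.einsum("bhnd,bhde,bhn->bhne", q, kv, z)

q = k = v = torch.randn(2, 4, 1024, 64)
out = linear_attention(q, k, v)  # (2, 4, 1024, 64)
```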


This project is currently maintained by an individual and remains a work in progress. As the maintainer is still in the early stages of learning Triton, many implementations may not be optimal. Contributions and suggestions are welcome!

## Update

## Models

The models below are listed roughly in the order they were supported in FBi-LA.

| Year | Model | Title | Paper | Code | fla impl |
|:----:|:------|:------|:-----:|:----:|:--------:|
| 2024 | LinFusion | LinFusion: 1 GPU, 1 Minute, 16K Image | arxiv | official code | |
| 2024 | MLLA | Demystify Mamba in Vision: A Linear Attention Perspective | arxiv | official code | |
| 2023 | Focused-LA | FLatten Transformer: Vision Transformer using Focused Linear Attention | arxiv | official code | |
| 2025 | PolaFormer | PolaFormer: Polarity-aware Linear Attention for Vision Transformers | arxiv | official code | |

More models will be implemented gradually.

## Usage

### Installation

```sh
git clone https://github.com/fla-org/flash-bidirectional-linear-attention.git
pip install -e flash-bidirectional-linear-attention/.
```
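A quick way to verify the editable install (the package imports as `fbi_la`, as in the usage example below):

```python
# Smoke test: the package installed above imports as `fbi_la`
import fbi_la
print(fbi_la.__name__)
```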

### Integrated Models

Several models are integrated into this library and can be called directly. Taking LinFusion as an example:

```python
import torch
from diffusers import AutoPipelineForText2Image
from fbi_la.models import LinFusion

sd_repo = "Lykon/dreamshaper-8"

pipeline = AutoPipelineForText2Image.from_pretrained(
    sd_repo, torch_dtype=torch.float16, variant="fp16"
).to(torch.device("cuda"))

# Mount LinFusion's linear attention onto the pipeline's self-attention layers
linfusion = LinFusion.construct_for(pipeline)

image = pipeline(
    "An astronaut floating in space. Beautiful view of the stars and the universe in the background.",
    generator=torch.manual_seed(123),
).images[0]
```
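The pipeline returns standard PIL images, so the result can be saved directly (the filename here is just an example):

```python
image.save("astronaut.png")
```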

## Benchmarks

Benchmarked on an NVIDIA A800 (80 GB) GPU.

Configuration `B8-H16-D64` (batch size 8, 16 heads, head dim 64); `T` is the sequence length and timings are in milliseconds:

| T | torch_fwd | triton_fwd | torch_bwd | triton_bwd |
|------:|---------:|----------:|----------:|----------:|
| 128 | 0.063488 | 0.049152 | 0.798720 | 0.651264 |
| 256 | 0.080896 | 0.056320 | 0.796672 | 0.625664 |
| 512 | 0.111616 | 0.058368 | 0.798720 | 0.630784 |
| 1024 | 0.169984 | 0.090112 | 0.864256 | 0.719872 |
| 2048 | 0.300032 | 0.151552 | 1.624064 | 0.702464 |
| 4096 | 0.532480 | 0.276480 | 3.058176 | 1.324032 |
| 8192 | 1.005568 | 0.521216 | 5.880320 | 2.556928 |
| 16384 | 1.924608 | 0.980992 | 11.540992 | 5.022208 |
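Timings like these can be collected with `triton.testing.do_bench`. A minimal sketch (the un-normalized forward below is a stand-in torch baseline, not this repo's kernel):

```python
import torch
import triton.testing

def torch_linear_attn_fwd(q, k, v):
    # Un-normalized O(T * D^2) non-causal linear attention, for timing only
    return q @ (k.transpose(-1, -2) @ v)

B, H, D = 8, 16, 64
for T in (128, 1024, 8192):
    q, k, v = (torch.randn(B, H, T, D, device="cuda", dtype=torch.float16)
               for _ in range(3))
    # do_bench returns the measured latency in milliseconds
    ms = triton.testing.do_bench(lambda: torch_linear_attn_fwd(q, k, v))
    print(f"T={T}: torch_fwd {ms:.6f} ms")
```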


## TODO

## Acknowledgments

Thanks to the following repositories for their inspiration: