
SwitchHead: Accelerating Transformers with Mixture-of-Experts Attention


Official implementation of the SwitchHead attention from our NeurIPS 2024 paper.

This repository is a user-friendly implementation of SwitchHead. For the training code, please refer to https://github.com/robertcsordas/moe_attention.

This implementation relies on the CVMM Triton kernel from σ-MoE.

Example

```python
import math
from typing import Tuple, Optional

import torch

import switchhead


class SwitchHeadSelfAttention(torch.nn.Module):
    def __init__(self, d_model: int, *args, **kwargs):
        super().__init__()
        self.norm = torch.nn.LayerNorm(d_model)
        self.attention = switchhead.SwitchHeadRope(d_model, *args, **kwargs)

    def forward(self, x: torch.Tensor, mask: Optional[switchhead.AttentionMask] = None,
                kv_cache: switchhead.KVCache = None) -> Tuple[torch.Tensor, switchhead.KVCache]:
        xn = self.norm(x)
        res, kv_cache = self.attention(xn, xn, xn, mask=mask, kv_cache=kv_cache)
        return x + res, kv_cache


# 243M param model from the paper.
batch_size = 8
context_window = 1024
d_model = 1024
n_layers = 18

x = torch.randn(batch_size, context_window, d_model).cuda()

# RoPE example (default)
attention = SwitchHeadSelfAttention(d_model, n_heads=4, n_experts=4, d_head=100,
                                    init_scale=1 / math.sqrt(n_layers)).cuda()
out, _ = attention(x)

print(out.shape)
```

A simple example can be found in example.py.

Usage

We provide two versions: one with rotary positional encoding (SwitchHeadRope, used in the example above) and one without built-in positional encoding. example.py shows a forward pass with both variants.

SwitchHead does not have an internal residual connection or layernorm. This is to provide greater flexibility for customization. It also requires passing individual tensors for the q, k and v projections. See example.py or the example above for how to use it as simple self-attention; a cross-attention-style sketch follows below.
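Because the q, k and v sources are separate arguments, the same module can also attend over a different sequence. The following is only a minimal sketch of this, reusing the module and tensors from the example above; `memory` is a hypothetical encoder output, and whether RoPE positions are meaningful across two different sequences depends on your use case.

```python
# Hedged sketch: queries come from x, keys and values from a second sequence.
# `memory` is a hypothetical encoder output of shape (batch_size, src_len, d_model).
memory = torch.randn(batch_size, 512, d_model).cuda()

# Reuse the inner SwitchHeadRope module from the wrapper defined above.
cross_out, _ = attention.attention(x, memory, memory, mask=None)
```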

The signature of the __init__ function of the RoPE version is as follows:

```python
def __init__(self, d_model: int, n_heads: int, n_experts: int, dropout: float = 0.0,
             d_head: Optional[int] = None, expert_dropout: float = 0.0, moe_k: int = 2,
             init_scale: float = 1.0, rotate_fraction: float = 0.5, rope_base: float = 10000):
```

The meaning of the arguments:

- d_model: width of the residual stream (model dimension).
- n_heads: number of attention heads.
- n_experts: number of experts per attention head.
- dropout: dropout rate used by the attention module.
- d_head: width of a single head. If None, a default derived from d_model is used.
- expert_dropout: dropout applied to the expert selection.
- moe_k: number of experts selected per token.
- init_scale: scale of the weight initialization (the example above uses 1/sqrt(n_layers)).
- rotate_fraction: fraction of the head dimensions to which RoPE is applied.
- rope_base: base used for computing the RoPE frequencies.
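As an illustrative instantiation with mostly default arguments (the hyperparameter values below are arbitrary, not taken from the paper):

```python
import switchhead

# Illustrative only: a small RoPE variant with 8 heads, 4 experts per head,
# and 2 experts active per token.
att = switchhead.SwitchHeadRope(d_model=512, n_heads=8, n_experts=4, moe_k=2)
```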

The signature of the forward function:

```python
def forward(self, q_src: torch.Tensor, k_src: torch.Tensor, v_src: torch.Tensor,
            mask: Optional[AttentionMask], kv_cache: KVCache = None) -> Tuple[torch.Tensor, KVCache]:
```

The meaning of the arguments:

- q_src, k_src, v_src: inputs from which the queries, keys and values are projected. For self-attention, pass the same tensor for all three (see the example above).
- mask: an optional AttentionMask (see below).
- kv_cache: an optional KV cache returned by a previous forward pass.

The forward pass returns a tuple of (output, updated KV cache). To save memory, the updated KV cache is None if the kv_cache argument was None. Otherwise, it can be fed as kv_cache to the next forward pass.
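As a small sketch of this behavior, reusing the wrapper and input from the example above (constructing an initial non-None cache is not shown here; see the KVCache type in switchhead.py):

```python
# Sketch: with the default kv_cache=None, no cache is built and None is returned.
out, returned_cache = attention(x)
assert returned_cache is None

# To enable caching, pass a non-None initial cache; the returned, updated cache
# can then be fed to the next forward pass.
```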

The AttentionMask has two optional boolean fields; positions where the mask is True are masked out (removed from attention). Fields that are None are ignored.

torch.compile() support

torch.compile() is supported with PyTorch >= 2.3.
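For example (a minimal sketch, assuming PyTorch >= 2.3 and reusing the module and input from the example above):

```python
# Compile the attention wrapper; requires PyTorch >= 2.3.
compiled_attention = torch.compile(attention)
out, _ = compiled_attention(x)
print(out.shape)
```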

Project structure

├───switchhead - the SwitchHead attention implementation. Copy this to your project.
│    ├─  cvmm.py - the CVMM Triton kernel.
│    └─  switchhead.py - the implementation of SwitchHead
│
├───example.py - an example forward pass using both variants.
├───LICENSE - MIT License.
└───README.md - this documentation.

Installation Instructions

The only external dependencies are PyTorch (>= 2.1) and Triton (>= 2.1). Copy the switchhead directory to your project and import it as shown in the examples above.

```
pip3 install -r requirements.txt
```

Known issues

Triton appears to be broken on Volta GPUs when using float16 in PyTorch 2.2 and 2.3 (see the corresponding GitHub issue). Until the PyTorch team fixes the issue, please downgrade to PyTorch 2.1 or disable AMP if you have a Volta GPU.