NVlabs/GatedDeltaNet: [ICLR 2025] Official PyTorch Implementation of Gated Delta Networks: Improving Mamba2 with Delta Rule

Gated Delta Networks: Improving Mamba2 with Delta Rule


Official PyTorch implementation of Gated Delta Networks: Improving Mamba2 with Delta Rule (ICLR '25).


Songlin Yang, Jan Kautz and Ali Hatamizadeh.

For business inquiries, please visit our website and submit the form: NVIDIA Research Licensing

For additional functionality, such as variable-length (varlen) training and inference support, see the FLA (flash-linear-attention) implementation.

📢 Latest Updates


❓ Frequently Asked Questions (FAQ)

1️⃣ Can I use Gated DeltaNet directly from FLA?

Yes! You can import the Gated DeltaNet block directly from FLA. The following script demonstrates how to do so using either FLA or our repository:

```python
>>> USE_FLA = True  # set to False to use this repository's implementation
>>> import torch
>>> if USE_FLA:
...     from fla.layers import GatedDeltaNet
... else:
...     from .gated_delta_net import GatedDeltaNet
...
>>> bs, num_heads, seq_len, hidden_size = 16, 4, 2048, 512
>>> gated_deltanet = GatedDeltaNet(hidden_size=hidden_size, num_heads=num_heads, mode='chunk').bfloat16().cuda()
>>> gated_deltanet
GatedDeltaNet(
  (silu): SiLU()
  (q_proj): Linear(in_features=512, out_features=1024, bias=False)
  (k_proj): Linear(in_features=512, out_features=1024, bias=False)
  (v_proj): Linear(in_features=512, out_features=2048, bias=False)
  (b_proj): Linear(in_features=512, out_features=4, bias=False)
  (a_proj): Linear(in_features=512, out_features=4, bias=False)
  (q_conv1d): ShortConvolution(1024, 1024, kernel_size=(4,), stride=(1,), padding=(3,), groups=1024, bias=False, activation=silu)
  (k_conv1d): ShortConvolution(1024, 1024, kernel_size=(4,), stride=(1,), padding=(3,), groups=1024, bias=False, activation=silu)
  (v_conv1d): ShortConvolution(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048, bias=False, activation=silu)
  (g_proj): Linear(in_features=512, out_features=2048, bias=False)
  (o_norm): FusedRMSNormSwishGate(512, eps=1e-05)
  (o_proj): Linear(in_features=2048, out_features=512, bias=False)
)
>>> x = torch.randn(bs, seq_len, hidden_size).bfloat16().cuda()
>>> y, _, _ = gated_deltanet(x)  # the layer returns a 3-tuple; the first element is the output
>>> y.shape
torch.Size([16, 2048, 512])
```


2️⃣ What is the difference between the FLA Gated DeltaNet kernels and the NVLabs implementation?

FLA kernels are faster and also support variable-length (varlen) training. We strongly recommend using FLA for better performance.

For reference, we also provide FLA-based kernels in this repository. You can find the optimized FLA Gated DeltaNet kernels here.


3️⃣ Will you release the pretrained model weights?

No. We provide only the code implementation.


4️⃣ The dataloader in this repository is designed for SlimPajama-672B, but your models were trained on FineWeb-Edu. Why is that, and should I expect similar results?

For the code release, we used the original Samba repository and included the SlimPajama-672B dataloader to maintain consistency.

Our experiments confirm that SlimPajama-672B produces similar results and trends to those reported in our paper. You can expect comparable performance.


5️⃣ Any guidance for evaluating the models?

Since this codebase is primarily adapted from the Samba codebase, which is designed mainly for training, evaluation can be inconvenient. Notably, the Samba codebase lacks the generation utilities required by many generation-based evaluation tasks.

We recommend first converting your trained model weights to the Hugging Face format provided in the FLA repository. Once converted, you can use FLA for streamlined evaluation.

🌟 Why Gated DeltaNet?

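Gated DeltaNet introduces a novel approach to linear transformers by combining two complementary memory-update mechanisms: gating (as in Mamba2), which enables rapid, adaptive memory erasure, and the delta rule (as in DeltaNet), which enables precise, targeted updates to stored key-value associations.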

Architecture Overview
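As a reading aid for the architecture, the sketch below implements the naive token-by-token form of the gated delta rule as written in the paper, S_t = S_{t-1}(α_t(I - β_t k_t k_t^T)) + β_t v_t k_t^T with output o_t = S_t q_t, where α_t in (0, 1) is the decay gate and β_t in (0, 1) is the write strength. It is a plain-Python illustration under those assumptions, not this repository's chunkwise (`mode='chunk'`) kernels, and all function names are hypothetical. For contrast, it also includes a simplified Mamba2-style gated update, S_t = α_t S_{t-1} + v_t k_t^T, which has no delta correction.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]  # recurrent state S, shape (d_v, d_k)
Vector = List[float]


def matvec(S: Matrix, x: Sequence[float]) -> Vector:
    """Return S @ x for a (d_v, d_k) matrix and a length-d_k vector."""
    return [sum(s * xj for s, xj in zip(row, x)) for row in S]


def gated_delta_step(S: Matrix, q: Sequence[float], k: Sequence[float],
                     v: Sequence[float], alpha: float, beta: float) -> Tuple[Matrix, Vector]:
    """One step of the gated delta rule (hypothetical naive form).

    S_t = alpha * (S_{t-1} - beta * (S_{t-1} k) k^T) + beta * v k^T,
    which equals S_{t-1} (alpha (I - beta k k^T)) + beta v k^T.
    Returns (S_t, o_t) with o_t = S_t q.
    """
    Sk = matvec(S, k)  # value the state currently retrieves for key k
    new_S = [
        [alpha * (S[i][j] - beta * Sk[i] * k[j]) + beta * v[i] * k[j]
         for j in range(len(k))]
        for i in range(len(v))
    ]
    return new_S, matvec(new_S, q)


def gated_linear_step(S: Matrix, q: Sequence[float], k: Sequence[float],
                      v: Sequence[float], alpha: float) -> Tuple[Matrix, Vector]:
    """Simplified Mamba2-style update for contrast: S_t = alpha * S_{t-1} + v k^T."""
    new_S = [
        [alpha * S[i][j] + v[i] * k[j] for j in range(len(k))]
        for i in range(len(v))
    ]
    return new_S, matvec(new_S, q)


def run(step, qs, ks, vs, *gates) -> List[Vector]:
    """Scan a step function over a sequence, starting from a zero state.

    `gates` holds one per-timestep list per gate argument of `step`.
    """
    d_k, d_v = len(ks[0]), len(vs[0])
    S: Matrix = [[0.0] * d_k for _ in range(d_v)]
    outputs: List[Vector] = []
    for t in range(len(qs)):
        S, o = step(S, qs[t], ks[t], vs[t], *(g[t] for g in gates))
        outputs.append(o)
    return outputs


# Write two values under the same unit-norm key, then read with that key.
key = [1.0, 0.0]
qs, ks, vs = [key, key], [key, key], [[1.0, 2.0], [3.0, 4.0]]
print(run(gated_delta_step, qs, ks, vs, [1.0, 1.0], [1.0, 1.0])[-1])  # [3.0, 4.0]: old value replaced
print(run(gated_linear_step, qs, ks, vs, [1.0, 1.0])[-1])  # [4.0, 6.0]: values pile up
```

With α_t = β_t = 1 and a unit-norm key, the delta rule first subtracts what the state already stores under k_t and then writes v_t, so the second value replaces the first; the purely gated update can only shrink the whole state through α_t, so colliding writes accumulate.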

Efficiency

Gated DeltaNet shows exceptional performance in terms of training throughput compared to models like Mamba2 and Samba:

Language Modeling and Reasoning

Our model outperforms competitors of various types (e.g., Transformer, RNN, hybrid) in terms of perplexity and zero-shot accuracy on reasoning benchmarks:

Long-context

Gated DeltaNet also achieves favorable perplexity scores on long-context benchmarks:

🚀 Getting Started

Training Your Model

Launch your training with our streamlined command:

```bash
python ../pretrain.py \
  --train_data_dir ${TRAIN_DATA} \
  --val_data_dir ${VALIDATION_DATA} \
  --output_root ${SAVE_DIR} \
  --exp_name ${NAME} \
  --model_name ${MODEL} \
  --train_config ${CONFIG} \
  --eval_iters ${EVAL_ITERS} \
  --learning_rate ${LR} \
  --micro_batch_size ${MICRO_BATCH_SIZE}
```

💡 Pro Tip: Add --interactive_job --debug for interactive debugging sessions!

Please see this slurm script for training the GatedDeltaNet_H1 model with 0.4B parameters on 15B tokens. Training requires 4 nodes and takes approximately 4 hours. For this run, the validation loss and perplexity curves (at 1x and 2x the training length, to measure length extrapolation) are expected to look as follows:

[Figure: expected validation loss and perplexity curves]
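The perplexity curves are a transform of the loss curves: assuming the reported validation loss is the mean token-level cross-entropy in nats (the usual PyTorch convention), perplexity is exp(loss). A minimal sketch, using a hypothetical helper name that is not part of this repository, which also pools per-batch mean losses weighted by token count before exponentiating:

```python
import math
from typing import Sequence


def corpus_perplexity(batch_mean_losses: Sequence[float], batch_token_counts: Sequence[int]) -> float:
    """Perplexity over several validation batches (hypothetical helper).

    Assumes each loss is a mean token-level cross-entropy in nats. Batches are
    pooled by total negative log-likelihood, so larger batches count more,
    and the pooled average is exponentiated only once at the end.
    """
    if len(batch_mean_losses) != len(batch_token_counts):
        raise ValueError("need one token count per batch loss")
    total_tokens = sum(batch_token_counts)
    if total_tokens <= 0:
        raise ValueError("need at least one token")
    total_nll = sum(loss * n for loss, n in zip(batch_mean_losses, batch_token_counts))
    return math.exp(total_nll / total_tokens)


# A single batch with mean loss ln(20) has perplexity 20, up to float rounding.
print(corpus_perplexity([math.log(20.0)], [2048]))
```

Pooling the log-likelihood first matters: averaging per-batch perplexities instead ignores batch sizes and, by Jensen's inequality, never reports a lower value than the pooled estimate.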

📜 License

Copyright © 2025, NVIDIA Corporation. All rights reserved.

Licensed under the NVIDIA Source Code License-NC. See LICENSE for details.

🙏 Acknowledgements

Built on the shoulders of giants:

⭐ Support Us

If you find this work useful, please consider:

Join us in pushing the boundaries of linear transformers! 🚀

Citation

If you find Gated DeltaNet useful for your work, please consider citing our paper:

@inproceedings{yang2025gated,
title={Gated Delta Networks: Improving Mamba2 with Delta Rule},
author={Songlin Yang and Jan Kautz and Ali Hatamizadeh},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025},
url={https://openreview.net/forum?id=r8H7xhYPwz}
}
