by Eric Alcaide
This blogpost describes how we accelerated the Gated DeltaNet forward pass by 1.15x, following Simon Veitner’s blogpost. Here’s our PR. Once it’s merged, practitioners will get the speedup simply by upgrading their FLA version.
Figure 1: Execution time comparison of FLA (commit: f52529e) and our improved version, layered on top of FLA’s own fused kernels. Benchmarked on NVIDIA H100.
Gated DeltaNet (ICLR 2025) is a linear attention mechanism that combines the delta rule with scalar gating in the state transition. DeltaNet was proposed as an alternative to transformers in 2021 (Schlag et al., ICML 2021) and scaled to hardware-efficient training in 2024 (Yang et al., NeurIPS 2024). Gated DeltaNet introduces an additional decay gate, a mechanism proven effective in sequence models (RetNet, RWKV6, Mamba2). The state update is:
\[S_t = \alpha_t \left(I - \beta_t \, k_t k_t^\top \right) S_{t-1} + \beta_t \, k_t v_t^\top \quad \in \mathbb{R}^{D_k \times D_v}\]
where $\alpha_t = \exp(g_t)$ is a scalar decay gate and $\beta_t$ is a scalar learning rate. The Householder-like term $(I - \beta_t \, k_t k_t^\top)$ performs a rank-1 correction to the memory state, and the gate controls how much of the old state to retain. The output is then simply $o_t = q_t^\top S_t \in \mathbb{R}^{D_v}$.
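To make the recurrence concrete, here is a minimal (deliberately slow) recurrent reference in PyTorch. This is a sketch, not FLA's API: it assumes a single head with inputs laid out as $(T, D)$, and all names are illustrative.

```python
import torch

def gated_delta_rule_ref(q, k, v, beta, g):
    """Naive recurrent reference for the state update above.

    q, k: (T, Dk); v: (T, Dv); beta, g: (T,) per-step scalar gates.
    Returns outputs o of shape (T, Dv).
    """
    T, Dk = k.shape
    Dv = v.shape[-1]
    S = torch.zeros(Dk, Dv, dtype=q.dtype, device=q.device)
    o = torch.empty(T, Dv, dtype=q.dtype, device=q.device)
    for t in range(T):
        alpha = g[t].exp()  # scalar decay gate alpha_t = exp(g_t)
        # rank-1 correction (I - beta k k^T) S, then decay, then write beta k v^T
        S = alpha * (S - beta[t] * torch.outer(k[t], k[t] @ S)) \
            + beta[t] * torch.outer(k[t], v[t])
        o[t] = q[t] @ S  # o_t = q_t^T S_t
    return o
```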
GDN has quickly become a practical building block for production LLMs from big labs. Alibaba’s Qwen3-Next and Qwen3.5 families use Gated DeltaNet in 75% of their layers (a 3:1 hybrid with standard attention). A reference implementation of a high-performance, chunkwise-parallel kernel lives in the FLA repo, which provides highly optimized Triton kernels for the chunked forward and backward passes of GDN and other alternatives to attention.
This blogpost describes our implementation of the optimization proposed by Simon Veitner in this blog post (we refer the reader there for the details of the algorithm), which was also noted independently in Comba.
In the chunked algorithm (chunk length $C = BT$), two coupling matrices appear: an ungated matrix $M$ built from the keys and $\beta$ within each chunk, and its gated counterpart $N$, whose $(i, j)$ entry additionally carries the decay factor $\exp(g^{\mathrm{cum}}_i - g^{\mathrm{cum}}_j)$.
FLA builds $N$ directly (evaluating $BT^2$ exp operations inside the coupling matrix) and then computes $N^{-1}$.
The key observation is that $N$ is a similarity transform of $M$: $N = GMG^{-1}$ where $G = \mathrm{diag}(\exp(g^{\mathrm{cum}}_1), \ldots, \exp(g^{\mathrm{cum}}_C))$. Therefore $N^{-1} = (GMG^{-1})^{-1} = GM^{-1}G^{-1}$.
This means we can compute:
- $M$ (no gating, skipping the $BT^2$ exp ops)
- $M^{-1}$ (same cost as before)
- $G$ and $G^{-1}$ as diagonal scalings in the WY step (only $2 \cdot BT$ exp ops)

Net savings: $BT^2 - 2 \cdot BT$ exp operations per chunk. With $BT = 64$, that’s 3968 fewer exp ops per chunk.
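As a sanity check, here is a small numerical sketch of the identity. It uses a unit-diagonal lower-triangular stand-in for the coupling matrix; the exact signs and masking inside FLA's kernels differ, and only the similarity-transform structure matters here.

```python
import torch

torch.manual_seed(0)
C, Dk = 64, 32  # chunk length (BT) and key dim, illustrative
k = torch.nn.functional.normalize(torch.randn(C, Dk, dtype=torch.float64), dim=-1)
beta = torch.rand(C, dtype=torch.float64)
gcum = torch.cumsum(-0.1 * torch.rand(C, dtype=torch.float64), dim=0)  # cumulative log-decay
eye = torch.eye(C, dtype=torch.float64)

# Ungated coupling M: unit diagonal + strictly-lower beta_i * <k_i, k_j>.
M = eye + torch.tril(beta[:, None] * (k @ k.T), diagonal=-1)

# Gated coupling N: each strictly-lower entry also carries exp(gcum_i - gcum_j),
# i.e. ~BT^2 exp evaluations when built directly.
decay = torch.exp(gcum[:, None] - gcum[None, :])
N = eye + torch.tril(beta[:, None] * (k @ k.T) * decay, diagonal=-1)

# Trick: N = G M G^{-1} with G = diag(exp(gcum)), hence N^{-1} = G M^{-1} G^{-1}.
# Only 2*BT exp evaluations: the two diagonals.
M_inv = torch.linalg.solve_triangular(M, eye, upper=False)
N_inv_tricked = torch.exp(gcum)[:, None] * M_inv * torch.exp(-gcum)[None, :]

N_inv_direct = torch.linalg.solve_triangular(N, eye, upper=False)
assert torch.allclose(N_inv_tricked, N_inv_direct)
```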
The math trick alone doesn’t beat FLA, though. Simon’s original implementation used separate kernels for kkt, solve, and WY, and the HBM round-trip for the intermediate A matrix wiped out the exp savings.
FLA fuses (kkt+solve) into a single kernel (the A matrix stays in registers, never hits HBM). Our implementation layers the trick on top of FLA’s existing fused kernels rather than reimplementing them:
1. (kkt+solve) + tricked WY (training path) – we call FLA’s own fused (kkt+solve) kernel with g=None (ungated), then run our custom WY kernel that applies G/G_inv scaling. The solved A is written to HBM for the backward pass.
2. (kkt+solve+WY) aka “fusemaxxed” (inference path) – all three steps in a single kernel. The A matrix is computed, solved, AND consumed for the WY computation entirely in registers. Zero HBM traffic for A. Used only when torch.is_grad_enabled() == False.
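A host-side sketch of how the two paths could be selected. The function and constant names here are illustrative, not the PR's actual code; the 8K inference crossover comes from the benchmarks below.

```python
import torch

FUSEMAXXED_MAX_T = 8192  # empirical crossover on H100, see benchmarks below

def pick_forward_path(T: int) -> str:
    """Illustrative dispatch policy between the two paths above."""
    if torch.is_grad_enabled():
        # training: FLA's fused (kkt+solve) with g=None, then the tricked WY
        # kernel; the solved A is written to HBM for the backward pass
        return "(kkt+solve) + tricked WY"
    if T <= FUSEMAXXED_MAX_T:
        # short sequences: kernel-launch overhead and HBM traffic dominate
        return "(kkt+solve+WY) fusemaxxed"
    return "(kkt+solve) + tricked WY"
```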
mem_efficient mode

When gradients are enabled, the training path uses (kkt+solve) + tricked WY and saves A to HBM for the backward pass. Additionally, when mem_efficient=False (the default) and T > 2048, the w and u tensors are cached during the forward pass so the backward skips their recomputation. Set mem_efficient=True to trade compute for memory on long sequences.
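For reference, a training-path usage sketch. It assumes the mem_efficient keyword from the PR lands on FLA's chunk_gated_delta_rule, that the op returns (o, final_state), and uses FLA's (B, T, H, D) layout; the exact signature may differ across FLA versions.

```python
import torch
from fla.ops.gated_delta_rule import chunk_gated_delta_rule  # FLA's chunked GDN op

B, T, H, Dk, Dv = 1, 32768, 16, 128, 128
q, k, v = (torch.randn(B, T, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
           for D in (Dk, Dk, Dv))
g = torch.nn.functional.logsigmoid(torch.randn(B, T, H, device="cuda"))  # log decay, g <= 0
beta = torch.rand(B, T, H, device="cuda", dtype=torch.bfloat16)

# mem_efficient=True (kwarg assumed from the PR, default False): recompute w/u
# in the backward instead of caching them, trading compute for memory at long T.
o, _ = chunk_gated_delta_rule(q, k, v, g, beta, mem_efficient=True)
o.float().sum().backward()
```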
All benchmarks on NVIDIA H100, B=1, H * Dh = 2048, bf16. FLA baseline is flash-linear-attention (commit: f52529e). Geometric mean speedups computed across all sequence lengths (1K-128K).
(kkt+solve+WY) fusemaxxed is 1.26-1.40x faster at short sequences (T <= 4K) where kernel launch overhead and HBM traffic dominate. Geo mean: 1.11x (Dh=128), 1.12x (Dh=256).
(kkt+solve) + tricked WY (training path) is 1.14x faster at long sequences (T >= 16K) where the trick eliminates BT^2 exp2 ops per chunk from the kkt computation. At short sequences, it runs at parity with FLA since the kkt+solve kernel is FLA’s own. Geo mean: 1.13x (Dh=128), 1.05x (Dh=256).
Auto-dispatch (inference): automatically picks fusemaxxed for T <= 8K and the fused path for T > 8K. Never regresses vs FLA. Geo mean: 1.21x (Dh=128), 1.14x (Dh=256).
Forward + Backward (with mem_efficient=False): when w and u are cached during the forward pass (for T > 2048), the backward skips their recomputation. Geo mean: 1.02x (Dh=128), 1.02x (Dh=256).
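A minimal timing harness in the spirit of these measurements, assuming FLA and Triton are installed and that chunk_gated_delta_rule accepts the positional arguments shown (a sketch, not the benchmark script used for the figures):

```python
import torch
import triton.testing
from fla.ops.gated_delta_rule import chunk_gated_delta_rule

B, H, Dk, Dv = 1, 16, 128, 128  # H * Dh = 2048, as in the benchmarks

def bench_forward(T: int) -> float:
    q = torch.randn(B, T, H, Dk, device="cuda", dtype=torch.bfloat16)
    k = torch.nn.functional.normalize(
        torch.randn(B, T, H, Dk, device="cuda", dtype=torch.bfloat16), dim=-1)
    v = torch.randn(B, T, H, Dv, device="cuda", dtype=torch.bfloat16)
    g = -torch.rand(B, T, H, device="cuda", dtype=torch.float32)
    beta = torch.rand(B, T, H, device="cuda", dtype=torch.bfloat16)
    with torch.no_grad():  # inference: auto-dispatch may pick fusemaxxed
        return triton.testing.do_bench(lambda: chunk_gated_delta_rule(q, k, v, g, beta))

for T in (1024, 4096, 16384, 65536):
    print(f"T={T}: {bench_forward(T):.3f} ms")
```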
We expect this to be most relevant for those running inference with open-source LLMs that incorporate GDN in up to 75% of their layers.
@misc{alcaide2026accelerating,
title = {Accelerating GatedDeltaNet Inference by 1.15x},
author = {Eric Alcaide},
month = {March},
year = {2026},
url = {https://hypnopump.github.io/post.html?slug=accelerating-gdn-inference}
}
@misc{veitner2026speedup,
title = {Simple Math to Speed Up GDN Prefill},
author = {Simon Veitner},
month = {March},
year = {2026},
url = {https://veitner.bearblog.dev/simple-math-to-speed-up-gdn-prefill/}
}