novitalabs/FlashMLA
forked from deepseek-ai/FlashMLA
Description: FlashMLA: Efficient Multi-head Latent Attention Kernels
License: MIT
Stars: 2
Forks: 0
Open issues: 0
Created: 2025-10-27T08:39:58Z
Pushed: 2025-10-29T01:44:14Z
Default branch: main
Fork: yes
Parent repository: deepseek-ai/FlashMLA
Archived: no
README:
FlashMLA
Introduction
FlashMLA is DeepSeek's library of optimized attention kernels, powering the DeepSeek-V3 and DeepSeek-V3.2-Exp models. This repository contains the following implementations:
Sparse Attention Kernels
*These kernels power DeepSeek Sparse Attention (DSA), as introduced in the DeepSeek-V3.2 paper.*
- Token-level sparse attention for the prefill stage
- Token-level sparse attention for the decoding stage, with FP8 KV cache
Dense Attention Kernels
- Dense attention for the prefill stage
- Dense attention for the decoding stage
News
- 2025.09.29 Release of Sparse Attention Kernels: With the launch of DeepSeek-V3.2, we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFLOPS during prefilling and 410 TFLOPS during decoding. We are also releasing a deep-dive blog post on our new FP8 sparse decoding kernel. Check it out [here](docs/20250929-hopper-fp8-sparse-deep-dive.md).
- 2025.08.01 Kernels for MHA on SM100: Thanks to NVIDIA for the PR contributing MHA forward/backward kernels on SM100!
- 2025.04.22 Deep-Dive Blog: We'd love to share the technical details behind the new FlashMLA kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md).
- 2025.04.22 Performance Update: We're excited to announce the new release of FlashMLA, which delivers a 5% to 15% performance improvement for compute-bound workloads, achieving up to 660 TFLOPS on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one, so simply upgrade for an immediate performance boost! 🚀🚀🚀
Performance
Test & benchmark MLA decoding (Sparse & Dense):
python tests/test_flash_mla_decoding.py
The dense MLA decoding kernel achieves up to 3000 GB/s in memory-bound configurations and 660 TFLOPS in compute-bound configurations on H800 SXM5 with CUDA 12.8. The token-level sparse MLA decoding kernel (which reads an FP8 KV cache but performs the matrix multiplications in bfloat16) achieves 410 TFLOPS in compute-bound configurations on H800 SXM5 with CUDA 12.8, and up to 350 TFLOPS on B200 (where it is not yet fully optimized).
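Whether a decoding configuration is memory-bound or compute-bound depends on how many FLOPs are performed per byte of KV cache read. The sketch below is a back-of-the-envelope estimate, not part of FlashMLA: it assumes MQA mode (576-dim cached vectors used as keys, with their first 512 dims reused as values), counts only the KV-cache read as memory traffic, and ignores softmax FLOPs. The 220 FLOPs/byte reference is simply the ratio of the two dense-decoding figures quoted above, not a hardware specification, and `decode_intensity` is a hypothetical helper.

```python
# Back-of-the-envelope arithmetic intensity for MQA-mode MLA decoding.
# Assumptions (not from FlashMLA): only the KV-cache read counts as memory
# traffic, softmax FLOPs are ignored, and each cached 576-dim vector is the
# key while its first 512 dims are reused as the value.
D_K, D_V = 576, 512
FLOPS_PER_ROW_PER_TOKEN = 2 * D_K + 2 * D_V  # QK^T dot product + PV multiply-accumulate

# Ratio of the two dense-decoding figures quoted above (660 TFLOPS, 3000 GB/s).
QUOTED_RATIO = 660e12 / 3000e9  # 220 FLOPs per byte


def decode_intensity(s_q: int, h_q: int, h_kv: int = 1,
                     bytes_per_token: int = 2 * D_K) -> float:
    """FLOPs per byte of KV cache read, per KV head.

    bytes_per_token defaults to a BF16 cache (576 values x 2 bytes);
    pass 656 for the "FP8 with scale" format.
    """
    rows = s_q * h_q // h_kv  # query rows that reuse each cached token
    return rows * FLOPS_PER_ROW_PER_TOKEN / bytes_per_token


for heads in (16, 64, 128):
    ai = decode_intensity(s_q=1, h_q=heads)
    regime = "compute-bound" if ai > QUOTED_RATIO else "memory-bound"
    print(f"h_q={heads:3d}: {ai:6.1f} FLOPs/byte, likely {regime}")
```

Under these assumptions, 128 query heads at `s_q = 1` give about 242 FLOPs/byte, above the quoted ratio, while 16 heads give about 30, well below it. Passing `bytes_per_token=656` raises the intensity by roughly 1.76x, since the FP8 format is smaller than the 1152-byte BF16 entry.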
Test & benchmark MHA prefill (Dense):
python tests/test_fmha_sm100.py
It achieves up to 1460 TFLOPS in the forward pass and 1000 TFLOPS in the backward pass on B200, as reported by NVIDIA.
Test & benchmark MLA prefill (Sparse):
python tests/test_flash_mla_prefill.py
It achieves up to 640 TFLOPS in the forward pass on H800 SXM5 with CUDA 12.8, and up to 1450 TFLOPS on B200 with CUDA 12.9.
Requirements
- SM90 / SM100 (See the support matrix below)
- CUDA 12.8 and above (CUDA 12.9+ is required for SM100 kernels)
- PyTorch 2.0 and above
Support matrix:

| Kernel | GPU Architecture | MLA Mode [2] | KVCache Format |
| :---: | :---: | :---: | :---: |
| Dense Decoding | SM90 | MQA | BF16 |
| Sparse Decoding | SM90 & SM100 | MQA | FP8 [1] |
| Dense Prefill | SM100 | MHA | |
| Sparse Prefill | SM90 & SM100 | MQA | |
[1]: For more details on using the FP8 KV cache, see the documentation below.
[2]: Here "MLA Mode" refers to the mode used for the MLA computation. MQA stands for Multi-Query Attention mode (i.e., head_dim_k = 576 with head_dim_v = 512), while MHA stands for Multi-Head Attention mode (i.e., head_dim_k = 192 / 128 with head_dim_v = 128). For a detailed explanation of these modes, please refer to the appendix of the DeepSeek-V3.2 paper.
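To make MQA mode concrete, here is a naive pure-Python reference for a single decoding step. It is an illustrative sketch, not the kernel's algorithm, and `mqa_decode_reference` is a hypothetical name. It assumes that every query head attends to the same cached vector per token (576 dims in MLA's MQA mode), that the first `dv` = 512 entries of that vector act as the value, and that the softmax scale is 1/sqrt(d_k) unless one is passed.

```python
import math


def mqa_decode_reference(q, kv, dv=512, softmax_scale=None):
    """Naive MQA-mode attention for one query token (reference sketch).

    q:  list of h_q query vectors, each of length d_k
    kv: list of n cached vectors of length d_k, shared by all query heads;
        the first dv entries of each cached vector serve as the value.
    Returns (out, lse): out[h] has length dv, and lse[h] is the
    log-sum-exp of the scaled attention scores for head h.
    """
    d_k = len(kv[0])
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(d_k)
    out, lse = [], []
    for qh in q:
        # Scaled dot-product scores against every cached token.
        scores = [scale * sum(a * b for a, b in zip(qh, k)) for k in kv]
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        # Weighted sum of the value part (first dv dims) of each cached vector.
        out.append([sum(w * k[c] for w, k in zip(weights, kv)) / total
                    for c in range(dv)])
        lse.append(m + math.log(total))
    return out, lse


# Tiny example: d_k = 3, dv = 2, one query head, two cached tokens.
o, l = mqa_decode_reference([[1.0, 0.0, 0.0]],
                            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
                            dv=2, softmax_scale=1.0)
print(o, l)
```

Because K and V come from the same cached vector, the cache stores only one 576-dim entry per token, which is what makes MLA decoding in MQA mode cheap on memory.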
Installation
```bash
git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla
cd flash-mla
git submodule update --init --recursive
pip install -v .
```
Usage
MLA Decoding
To use the MLA decoding kernels, call get_mla_metadata once before the decoding loop to get the tile scheduler metadata. Then, call flash_mla_with_kvcache in each decoding step. For example:
```python
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

# Once, before the decoding loop: compute the tile scheduler metadata.
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    s_q * h_q // h_kv,
    h_kv,
    h_q,
    is_fp8,
    topk,
)

for i in range(num_layers):
    ...
    # In every decoding step, for each layer i:
    o_i, lse_i = flash_mla_with_kvcache(
        q_i, kvcache_i, block_table, cache_seqlens, dv,
        tile_scheduler_metadata, num_splits,
        is_causal, is_fp8_kvcache, indices,
    )
    ...
```
Where:
- `s_q` is the number of query tokens per query sequence. If MTP (speculative decoding) is disabled, it should be 1.
- `h_kv` is the number of key-value heads.
- `h_q` is the number of query heads.
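The `block_table` and `cache_seqlens` arguments describe a paged KV cache: each sequence's tokens live in fixed-size page blocks, and row `i` of the block table lists the physical blocks of sequence `i` in order. The allocator below is a hypothetical illustration that uses plain Python lists instead of tensors. It assumes blocks are handed out sequentially and that padding entries past a sequence's length are never read; the page block size of 64 is only an example value.

```python
def build_block_table(seqlens: list[int], page_block_size: int):
    """Hypothetical allocator: hand out physical page blocks sequentially.

    Returns (block_table, cache_seqlens). Row i lists, in order, the physical
    blocks holding sequence i; short rows are padded with 0 (assumed unused,
    since cache_seqlens bounds how many tokens of each row are valid).
    """
    blocks_needed = [-(-n // page_block_size) for n in seqlens]  # ceil(n / size)
    width = max(blocks_needed, default=0)
    block_table, next_block = [], 0
    for nb in blocks_needed:
        row = list(range(next_block, next_block + nb))
        next_block += nb
        block_table.append(row + [0] * (width - nb))
    return block_table, list(seqlens)


# Three sequences of 100, 30 and 130 cached tokens with 64-token page blocks.
block_table, cache_seqlens = build_block_table([100, 30, 130], page_block_size=64)
print(block_table)  # [[0, 1, 0], [2, 0, 0], [3, 4, 5]]
```

A real serving system would typically draw blocks from a free list so that sequences can grow and release memory independently; sequential assignment just keeps the example short.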
FP8 KV Cache: If is_fp8_kvcache is set to True, the kernel reads the KV cache in the "FP8 with scale" format (described below). It dequantizes the cache to bfloat16 and performs attention computation in bfloat16. The output is also in bfloat16.
In the "FP8 with scale" format, each token's KV cache entry is 656 bytes, structured as:
- First 512 bytes: the "quantized NoPE" part, containing 512 `float8_e4m3` values.
- Next 16 bytes: scale factors, containing 4 `float32` values. The first `float32` is the scale for the first 128 `float8_e4m3` values, the second for the next 128, and so on.
- Last 128 bytes: the "RoPE" part, containing 64 `bfloat16` values. This part is not quantized, to preserve accuracy.
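The byte layout above can be unpacked in plain Python as a sketch. Several details are assumptions not stated in this README: `float8_e4m3` is taken to be the E4M3 "fn" variant (no infinities, as in PyTorch's `torch.float8_e4m3fn`), all fields are little-endian, and dequantization multiplies each FP8 value by its group's scale. Check these against `tests/quant.py`; the function names here are hypothetical.

```python
import math
import struct

NOPE_DIM = 512   # float8_e4m3 values in the quantized NoPE part
NUM_SCALES = 4   # one float32 scale per 128 NoPE values
GROUP = NOPE_DIM // NUM_SCALES
ROPE_DIM = 64    # bfloat16 values in the RoPE part
TOKEN_BYTES = NOPE_DIM + 4 * NUM_SCALES + 2 * ROPE_DIM  # 512 + 16 + 128 = 656


def fp8_e4m3fn_to_float(b: int) -> float:
    """Decode one byte as float8 E4M3 (assumed 'fn' variant: no infinities)."""
    sign = -1.0 if b & 0x80 else 1.0
    exp = (b >> 3) & 0x0F
    man = b & 0x07
    if exp == 0x0F and man == 0x07:
        return math.nan
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6  # subnormal
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def bf16_to_float(bits: int) -> float:
    """A bfloat16 value is the upper 16 bits of an IEEE float32."""
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]


def dequantize_token(buf: bytes) -> list[float]:
    """Unpack one 656-byte 'FP8 with scale' entry into 576 floats."""
    if len(buf) != TOKEN_BYTES:
        raise ValueError(f"expected {TOKEN_BYTES} bytes, got {len(buf)}")
    scales = struct.unpack_from(f"<{NUM_SCALES}f", buf, NOPE_DIM)
    rope_bits = struct.unpack_from(f"<{ROPE_DIM}H", buf, NOPE_DIM + 4 * NUM_SCALES)
    # Assumed dequantization rule: fp8 value x scale of its 128-value group.
    nope = [fp8_e4m3fn_to_float(b) * scales[i // GROUP]
            for i, b in enumerate(buf[:NOPE_DIM])]
    rope = [bf16_to_float(x) for x in rope_bits]
    return nope + rope  # 512 NoPE + 64 RoPE = 576 = head_dim_k
```

Compared with a BF16 cache (576 x 2 = 1152 bytes per token), this layout stores about 57% as many bytes, while keeping the RoPE part at full bfloat16 precision.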
See tests/quant.py for quantization and dequantization details.
Sparse Attention (`indices` tensor): The indices tensor (if provided) enables token-level sparse attention by instructing the kernel to compute attention only for specified tokens.
- Shape: `indices` should be a 3D tensor of shape `(batch_size, seq_len_q, topk)`.
- Format: `indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * page_block_size + (the offset of token t within the page block)`, where `t` is the `k`-th token for the `j`-th query sequence in the `i`-th batch. Since the index of the page block has already been encoded into…
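Following the formula above, a flat index can be computed from a block table and each selected token's logical position within its sequence. This sketch assumes that "the index of the page block" is the physical block number looked up in the block table; `to_kvcache_indices` is a hypothetical helper, and nested lists stand in for the `(batch_size, seq_len_q, topk)` tensor.

```python
def to_kvcache_indices(block_table, topk_positions, page_block_size):
    """Map top-k logical token positions to flat KV-cache indices.

    block_table[i][b] is the physical page block holding logical block b of
    batch element i; topk_positions[i][j][k] is the logical position of the
    k-th selected token for the j-th query of batch element i.
    """
    out = []
    for i, per_query in enumerate(topk_positions):
        rows = []
        for positions in per_query:
            rows.append([
                # (page block index) * page_block_size + (offset within block)
                block_table[i][pos // page_block_size] * page_block_size
                + pos % page_block_size
                for pos in positions
            ])
        out.append(rows)
    return out


# One sequence whose logical blocks 0 and 1 live in physical blocks 5 and 2.
print(to_kvcache_indices([[5, 2]], [[[0, 63, 64, 100]]], page_block_size=64))
# [[[320, 383, 128, 164]]]
```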
Notability
Notability score: 1.0/10. Trivial fork with 2 stars.