nebius/kvax
Description: A FlashAttention implementation for JAX with support for efficient document mask computation and context parallelism.
Language: Python
License: Apache-2.0
Stars: 167
Forks: 9
Open issues: 3
Created: 2025-01-10T13:34:47Z
Pushed: 2025-11-11T17:22:27Z
Default branch: main
Fork: no
Archived: no
README:
Kvax: fast and easy-to-use flash attention implementation for JAX
Kvax is an open-source library offering fast and efficient attention operations for the JAX framework. Built with Flash Attention 2 algorithms implemented in the Triton language, it is optimised for high-performance attention computation with document masks and supports context parallelism. Kvax is designed to perform exceptionally well in distributed training scenarios on long sequences using FSDP/HSDP sharding.
More technical details are available in our blogpost: https://nebius.com/blog/posts/kvax-open-source-flash-attention-for-jax
Table of Contents:
- [Key Concepts of Kvax Implementation](#key-concepts-of-kvax-implementation)
- [Kvax Features](#kvax-features)
- [Kvax Results](#kvax-results)
- [How to install](#how-to-install)
- [How to use](#how-to-use)
- [Package Description](#package-description)
- [Benchmarks](#benchmarks)
- [Limitations](#limitations)
- [Contributing](#contributing)
- [Citation](#citation)
- [License](#license)
Key Concepts of Kvax Implementation
Document Mask Optimisation
When training transformer models on long sequences, a significant amount of compute is spent on attention operations due to the quadratic complexity of the attention algorithm. The Flash Attention algorithm offers hardware-specific optimisations that significantly reduce latency and memory requirements for these operations.
During training on long sequences, dense packing is often used to maximise compute resource utilisation. In this approach, multiple data points are packed into a single sequence while avoiding cross-sequence attention contamination. The main idea is to calculate only the blocks of attention weights that include tokens which should attend to each other while skipping other blocks. Various methods can efficiently handle this, with PyTorch's FlexAttention being one example. Kvax takes a similar approach to achieve high performance in these scenarios.
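The block-skipping idea can be illustrated with a small pure-Python sketch. It is not Kvax's Triton kernel: the function name `block_mask` is hypothetical, and the brute-force loop over token pairs is only there to make the rule explicit (Kvax builds its mask blockwise without materialising anything of size `seq_len^2`). A block pair needs computing only if at least one query token in it may attend to at least one key token, that is, both tokens share a segment and the causal order holds.

```python
def block_mask(segment_ids, block_size, causal=True):
    """Return needed[q_block][kv_block]: True if the block pair must be computed.

    Illustrative brute-force version; a real kernel derives this from
    per-block segment ranges instead of checking every token pair.
    """
    n = len(segment_ids)
    assert n % block_size == 0, "sequence length must be a multiple of block_size"
    num_blocks = n // block_size
    needed = [[False] * num_blocks for _ in range(num_blocks)]
    for qb in range(num_blocks):
        q_range = range(qb * block_size, (qb + 1) * block_size)
        for kb in range(num_blocks):
            k_range = range(kb * block_size, (kb + 1) * block_size)
            needed[qb][kb] = any(
                segment_ids[q] == segment_ids[k] and (not causal or k <= q)
                for q in q_range
                for k in k_range
            )
    return needed


# Three documents of lengths 4, 8 and 4 packed into one 16-token sequence.
segment_ids = [0] * 4 + [1] * 8 + [2] * 4
mask = block_mask(segment_ids, block_size=4)
computed = sum(sum(row) for row in mask)
total = len(mask) ** 2
print(f"{computed} of {total} blocks need computing")
```

In this toy case only 5 of the 16 blocks are computed, compared with 10 for a plain causal mask over the same sequence.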
Context Parallelism
Using long sequences during training can also lead to high GPU memory consumption for storing layer activations. Context parallelism helps solve this problem, speeding up the computations and reducing the memory required for layer activations.
There are several approaches to implementing context parallelism for transformer architectures, such as RingAttention and the all-gather based method. The all-gather based method, described in the Llama 3 training paper, all-gathers the key and value tensors before the attention computation. This is affordable because GQA keeps the key and value tensors much smaller than the query tensor. This method is particularly well-suited for document masks, and Kvax leverages it in its implementation.
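To make the all-gather scheme concrete, here is a minimal single-process simulation in pure Python. It is a sketch under stated assumptions, not Kvax code: `attention` and `all_gather` are hypothetical helpers, the "ranks" are just list slices, and GQA is not modelled. Each simulated rank keeps only its own query shard, gathers every rank's keys and values, and runs causal attention using global positions. The result should match attention over the unsharded sequence.

```python
import math


def attention(queries, q_pos, keys, values, k_pos):
    """Causal softmax attention of each query against all visible keys."""
    out = []
    for q, qp in zip(queries, q_pos):
        visible = [(k, v) for k, v, kp in zip(keys, values, k_pos) if kp <= qp]
        scores = [
            sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k, _ in visible
        ]
        top = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - top) for s in scores]
        norm = sum(weights)
        dim = len(visible[0][1])
        out.append(
            [sum(w * v[d] for w, (_, v) in zip(weights, visible)) / norm for d in range(dim)]
        )
    return out


def all_gather(shards):
    """Stand-in for the collective: concatenate the shards of all ranks."""
    return [row for shard in shards for row in shard]


seq_len, dim, cp = 8, 4, 2  # cp = number of simulated context-parallel ranks
q = [[math.sin(i * dim + d) for d in range(dim)] for i in range(seq_len)]
k = [[math.cos(i * dim + d) for d in range(dim)] for i in range(seq_len)]
v = [[float(i + d) for d in range(dim)] for i in range(seq_len)]
pos = list(range(seq_len))

# Reference: attention over the full, unsharded sequence.
reference = attention(q, pos, k, v, pos)

# Shard every tensor along the sequence axis, one contiguous slice per rank.
shard = seq_len // cp
split = lambda x: [x[r * shard:(r + 1) * shard] for r in range(cp)]
q_sh, k_sh, v_sh, pos_sh = split(q), split(k), split(v), split(pos)

# Each rank gathers K, V and their positions, then attends with local queries only.
full_k, full_v, full_pos = all_gather(k_sh), all_gather(v_sh), all_gather(pos_sh)
sharded = []
for r in range(cp):
    sharded.extend(attention(q_sh[r], pos_sh[r], full_k, full_v, full_pos))

max_err = max(abs(a - b) for ra, rb in zip(reference, sharded) for a, b in zip(ra, rb))
print("max abs difference:", max_err)
```

Because the queries stay sharded, each rank holds only `seq_len / cp` rows of attention output and activations, while the gathered keys and values are the only tensors replicated across ranks.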
Kvax Features
- Block-wise Attention Masks: Like FlexAttention, our implementation builds the attention mask once per forward-backward pass, reusing it across layers. Our high-performance Triton kernel builds this mask blockwise, and does not require `O(seq_len^2)` GPU memory.
- Optimised Memory Storage: Kvax stores attention masks in block-wise format, requiring `3 * 4 * batch_size * seq_len // block_size * 4` bytes (block_size is typically 64 or 128).
- Skipping Pad Tokens: Kvax skips blocks consisting entirely of padding tokens. See the "How to use" section for details on defining padding tokens.
- Context Parallelism: Kvax balances tokens across GPUs to ensure equal attention operation loads, accounting for causal masks. This feature is described in the Llama 3 training paper and fully integrates with document mask optimisations.
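Two of the features above can be checked with quick arithmetic. The sketch below first evaluates the block-wise mask size formula exactly as stated and compares it with a dense mask, assuming one byte per dense mask element. It then reproduces the load-balancing split described in the Llama 3 paper: the sequence is cut into `2 * cp` chunks and rank `i` receives chunks `i` and `2 * cp - 1 - i`. All function names are hypothetical, and whether Kvax's token permutation uses exactly this ordering is an assumption.

```python
def block_mask_bytes(batch_size, seq_len, block_size):
    """Block-wise mask size, using the formula quoted in the feature list."""
    return 3 * 4 * batch_size * seq_len // block_size * 4


def dense_mask_bytes(batch_size, seq_len, bytes_per_element=1):
    """Size of a dense seq_len x seq_len mask (assumed 1 byte per element)."""
    return batch_size * seq_len * seq_len * bytes_per_element


def balanced_chunks(seq_len, cp):
    """Llama 3 style split: rank i gets chunks i and 2*cp - 1 - i."""
    assert seq_len % (2 * cp) == 0, "seq_len must be divisible by 2 * cp"
    chunk = seq_len // (2 * cp)
    ranks = []
    for i in range(cp):
        j = 2 * cp - 1 - i  # mirrored chunk from the end of the sequence
        first = list(range(i * chunk, (i + 1) * chunk))
        second = list(range(j * chunk, (j + 1) * chunk))
        ranks.append(first + second)
    return ranks


def causal_work(positions):
    """Under a causal mask, the query at position p attends to p + 1 keys."""
    return sum(p + 1 for p in positions)


blockwise = block_mask_bytes(batch_size=1, seq_len=32768, block_size=128)
dense = dense_mask_bytes(batch_size=1, seq_len=32768)
print(f"block-wise: {blockwise} bytes, dense: {dense} bytes")

ranks = balanced_chunks(seq_len=16, cp=2)
balanced = [causal_work(r) for r in ranks]
contiguous = [causal_work(range(0, 8)), causal_work(range(8, 16))]
print("balanced split work per rank:  ", balanced)
print("contiguous split work per rank:", contiguous)
```

For a 32768-token sequence the block-wise mask takes 12288 bytes against 1 GiB for the dense one. For the toy 16-token split, the mirrored chunks give each rank 68 query-key pairs, whereas contiguous halves give 36 and 100. The work count here assumes a pure causal mask; with document masks the per-rank load also depends on where document boundaries fall.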
Kvax Results


More details on Kvax benchmarking and its results can be found in the blogpost.
How to install
Install the latest stable release from pip:
pip install kvax
Note: The automatically installed versions of Triton and JAX-Triton might not be compatible. If you encounter an error while running the provided benchmarks, please ensure that you install compatible versions manually. For benchmarking, we used `triton==3.1` and `jax-triton==0.2.0`.
How to use
First, ensure that the position of every padding token is marked with PADDING_SEGMENT_ID in the query_segment_ids and kv_segment_ids tensors:
    from kvax.utils import PADDING_SEGMENT_ID

    # In this example, the sequence length is 8, and there are 2 padding tokens.
    pad_token_id = 128001
    input_ids = [6151, 0, 52043, 710, 374, 1618, pad_token_id, pad_token_id]
    query_segment_ids = [0, 0, 0, 0, 0, 0, PADDING_SEGMENT_ID, PADDING_SEGMENT_ID]
    kv_segment_ids = [0, 0, 0, 0, 0, 0, PADDING_SEGMENT_ID, PADDING_SEGMENT_ID]
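When several documents are packed into one sequence, the same convention extends naturally: each document gets its own segment ID and the padding tail gets the padding marker. The helper below is a hypothetical sketch, not part of Kvax. So that it runs without Kvax installed, it defines a local stand-in for `PADDING_SEGMENT_ID`; the value `-1` is an assumption, and real code should import the constant from `kvax.utils` as shown above.

```python
# Hypothetical stand-in so this sketch runs without kvax installed.
# The value -1 is an assumption; import the real constant from kvax.utils.
PADDING_SEGMENT_ID = -1


def pack_documents(docs, seq_len, pad_token_id):
    """Pack tokenised documents into one fixed-length sequence.

    Returns (input_ids, segment_ids). Tokens of document i get segment ID i,
    and padding positions get PADDING_SEGMENT_ID.
    """
    input_ids, segment_ids = [], []
    for segment, doc in enumerate(docs):
        input_ids.extend(doc)
        segment_ids.extend([segment] * len(doc))
    num_pad = seq_len - len(input_ids)
    if num_pad < 0:
        raise ValueError("documents do not fit into seq_len")
    input_ids.extend([pad_token_id] * num_pad)
    segment_ids.extend([PADDING_SEGMENT_ID] * num_pad)
    return input_ids, segment_ids


docs = [[6151, 0, 52043], [710, 374, 1618]]
input_ids, segment_ids = pack_documents(docs, seq_len=8, pad_token_id=128001)

# For self-attention, queries and keys/values share the same segment IDs.
query_segment_ids = segment_ids
kv_segment_ids = segment_ids
print(input_ids)
print(segment_ids)
```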
Then, kvax functions can be used in the transformer code:
    import flax.linen as nn

    from kvax.ops import (
        create_attention_mask,
        flash_attention,
    )
    from kvax.utils import (
        attention_specs,
        permute_tokens_context_parallelism,
        unpermute_tokens_context_parallelism,
    )


    class AttentionLayer(nn.Module):
        def __call__(
            self,
            embedding,
            query_positions,
            query_segment_ids,
            kv_positions,
            kv_segment_ids,
            attn_mask,
        ):
            query, key, value = ...
            scale = ...

            # Call the Flash Attention op
            attn_out = flash_attention(
                query=query,
                key=key,
                value=value,
                query_positions=query_positions,
                query_segment_ids=query_segment_ids,
                kv_positions=kv_positions,
                kv_segment_ids=kv_segment_ids,
                mask=attn_mask,
                assume_sequential_positions=self.config.assume_sequential_positions,
                scale=scale,
                # Mesh is defined as a global context
                # mesh=mesh,
            )

            out = ...
            return out


    class Transformer(nn.Module):
        ...

        def setup(self):
            self.attn_layers = [AttentionLayer(...) for _ in…
Excerpt shown — open the source for the full document.
Notability: 5.0/10 (new repo with moderate stars)