amazon-science/dualkv-flash-attn-for-rl
Python
Description: Implementation of DualKV: Shared-Prompt Flash Attention for Efficient RL Training with Large Rollouts and Long Contexts
Language: Python
License: NOASSERTION
Stars: 2
Forks: 3
Open issues: 2
Created: 2026-05-27T17:38:58Z
Pushed: 2026-06-05T18:11:44Z
Default branch: main
Fork: no
Archived: no
README:
DualKV: Shared-Prompt Flash-Attention for RL Training
Code release for *"DualKV: Shared-Prompt Flash-Attention Kernels for Efficient Policy Updates in RL Training"*.
DualKV deduplicates shared prompts in GRPO/DAPO training: instead of computing attention over N*(P+R) tokens (N responses per prompt, prompt length P, response length R), it computes over P + N*R tokens, yielding up to a 6x kernel speedup and a 2x end-to-end throughput improvement on long-context RL workloads. This release includes the custom flash-attention kernels, veRL integration (with Ulysses Sequence Parallelism support), and scripts to reproduce all paper experiments.
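To make the N*(P+R) versus P + N*R comparison concrete, the sketch below counts tokens and causal query-key pairs for one prompt group. It is a back-of-envelope model, not the DualKV kernel: it assumes the prompt attends causally to itself once and that every response token attends to all P prompt tokens plus its own causal response prefix. The example shapes are hypothetical. Because each response still attends to the full prompt (the N*R*P cross term is not deduplicated), the reduction in attention pairs (about 3.9x for these shapes) is smaller than the 6x reduction in tokens.

```python
def token_counts(N: int, P: int, R: int) -> tuple[int, int]:
    """Tokens processed per prompt group: (prompt duplicated, prompt shared)."""
    return N * (P + R), P + N * R


def causal_pairs_duplicated(N: int, P: int, R: int) -> int:
    """Query-key pairs when each of the N sequences carries its own prompt copy."""
    L = P + R
    return N * L * (L + 1) // 2


def causal_pairs_shared(N: int, P: int, R: int) -> int:
    """Query-key pairs when the prompt is encoded once and shared by N responses.

    Assumption (illustrative model): the prompt attends causally to itself once;
    every response token sees all P prompt tokens plus its own causal prefix.
    """
    prompt_self = P * (P + 1) // 2
    per_response = R * P + R * (R + 1) // 2
    return prompt_self + N * per_response


if __name__ == "__main__":
    N, P, R = 16, 32768, 4096  # hypothetical shapes, not taken from the paper
    dup_tok, shared_tok = token_counts(N, P, R)
    dup_pairs = causal_pairs_duplicated(N, P, R)
    shared_pairs = causal_pairs_shared(N, P, R)
    print(f"tokens: {dup_tok:,} -> {shared_tok:,} ({dup_tok / shared_tok:.2f}x fewer)")
    print(f"causal pairs: {dup_pairs:,} -> {shared_pairs:,} "
          f"({dup_pairs / shared_pairs:.2f}x fewer)")
```

With N=1 the two pair counts coincide, which is a quick sanity check that the shared-prompt model only saves work when a prompt has multiple responses.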
Repository Structure
├── flash-attention/      # FlashAttention-2 (commit 41b2ef6) with DualKV kernels applied
├── verl/                 # veRL v0.7.0 with DualKV integration applied
├── experiments/          # Benchmarks, training scripts, reward functions
├── LICENSE               # CC-BY-NC-4.0
└── THIRD_PARTY_LICENSES
Key implementation files:
- Forward kernel: flash-attention/csrc/flash_attn/src/flash_fwd_kernel_dualkv_training.h
- Backward kernel: flash-attention/csrc/flash_attn/src/flash_bwd_kernel_dualkv_training.h
- Python interface: flash-attention/flash_attn/flash_attn_interface.py (search for dualkv)
- veRL actor integration: verl/verl/workers/actor/dp_actor.py (search for _dualkv)
- Attention monkey-patch + SP: verl/verl/models/transformers/monkey_patch.py (DualKV + Ulysses all-to-all)
- SP correctness test: experiments/test_dualkv_sp_correctness.py
Hardware Requirements
| Experiment | GPUs |
|------------|------|
| Kernel benchmarks (Table 1, Table 2) | 1x H100-80GB |
| Qwen3-8B end-to-end (Table 5, Table 8) | 8x H100-80GB |
| Qwen3-14B end-to-end | 8x H100-80GB |
| DAPO end-to-end (Table 7) | 8x H100-80GB |
| Qwen3-30B-A3B multi-node (Table 3) | 16x H100-80GB (2 nodes) |
| Memory scaling sweep | 1x H100-80GB |
Software Environment
| Package | Version |
|---------|---------|
| Python | 3.12 |
| PyTorch | 2.9.0+cu128 |
| CUDA | 12.8 |
| flash-attn | 2.8.4 (included, with DualKV) |
| veRL | 0.7.0 (included, with DualKV) |
| vLLM | 0.12.0 |
| Ray | 2.55.0 |
| Transformers | 4.57.6 |
Setup
git clone dualkv && cd dualkv
python3 -m venv .venv && source .venv/bin/activate
pip install torch==2.9.0 --index-url https://download.pytorch.org/whl/cu128
Install Flash Attention (with DualKV kernels)
cd flash-attention
pip install ninja numpy packaging
git clone --depth 1 https://github.com/NVIDIA/cutlass.git csrc/cutlass
pip install -e . --no-build-isolation
cd ..
Verify: python -c "from flash_attn import flash_attn_dualkv_varlen_func; print('OK')"
Install veRL (with DualKV integration)
cd verl
pip install -e .
cd ..
(Optional) Flash Attention 3
Only needed to reproduce FA3 baseline rows in Table 5 and Table 7:
git clone https://github.com/Dao-AILab/flash-attention.git /tmp/flash-attention-3
cd /tmp/flash-attention-3 && git checkout v3.0.0 && cd hopper && pip install -e .
Verify: python -c "from flash_attn_interface import flash_attn_func; print('FA3 OK')"
(Optional) Prefix Grouper
Only needed to reproduce the Prefix Grouper baseline in Table 2:
pip install git+https://github.com/CASIA-IVA-Lab/PrefixGrouper.git
Remaining Dependencies
pip install vllm==0.12.0 ray==2.55.0 wandb pandas pyarrow
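Once all dependencies are installed, a small check like the one below can compare the environment against the Software Environment table and confirm that the DualKV entry point is importable. It is a convenience sketch, not part of the release: it relies only on names given in this README (flash_attn_dualkv_varlen_func in flash_attn, and flash_attn_interface for the optional FA3 build) plus the standard importlib machinery, and it treats a version as matching when it starts with the pinned string (so 2.9.0+cu128 matches 2.9.0).

```python
import importlib
import importlib.util
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the Software Environment table (prefix match).
EXPECTED = {"torch": "2.9.0", "vllm": "0.12.0", "ray": "2.55.0", "transformers": "4.57.6"}


def installed_version(dist: str):
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


def has_dualkv() -> bool:
    """True if flash_attn imports and exposes flash_attn_dualkv_varlen_func."""
    if importlib.util.find_spec("flash_attn") is None:
        return False
    try:
        mod = importlib.import_module("flash_attn")
    except Exception:  # e.g. a CUDA or ABI mismatch raised at import time
        return False
    return hasattr(mod, "flash_attn_dualkv_varlen_func")


def has_fa3() -> bool:
    """True if the optional FA3 module (flash_attn_interface) is importable."""
    return importlib.util.find_spec("flash_attn_interface") is not None


def check_environment() -> dict:
    """Map each check to (found_version_or_None, passed)."""
    report = {}
    for dist, expected in EXPECTED.items():
        found = installed_version(dist)
        report[dist] = (found, found is not None and found.startswith(expected))
    report["flash_attn (DualKV)"] = (None, has_dualkv())
    report["FA3 (optional)"] = (None, has_fa3())
    return report


if __name__ == "__main__":
    for name, (found, ok) in check_environment().items():
        print(f"{'OK ' if ok else 'BAD'} {name} {found or ''}")
```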
Models and Data
WORKDIR=/path/to/your/workdir
# Models
huggingface-cli download Qwen/Qwen3-8B --local-dir ${WORKDIR}/models/Qwen3-8B
huggingface-cli download Qwen/Qwen3-14B --local-dir ${WORKDIR}/models/Qwen3-14B
huggingface-cli download Qwen/Qwen3-30B-A3B --local-dir ${WORKDIR}/models/Qwen3-30B-A3B
# Data
python experiments/preprocess_longreason.py --local_save_dir ${WORKDIR}/data/longreason
python experiments/preprocess_quality.py --local_save_dir ${WORKDIR}/data/quality
Reproducing Experiments
Set environment before running any script:
export WORKDIR=/path/to/your/workdir
export WANDB_API_KEY=your_key  # optional, scripts fall back to console logging
Notation: mb = micro-batch size (prompt groups per training step), P = prompt length, N = number of responses per prompt, R = response length, SP = Ulysses sequence parallelism degree, DP = data parallelism degree, FA2/FA3 = FlashAttention-2/3.
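Using this notation, the token budget of one micro-batch can be sketched as below. The sketch assumes each of the mb prompt groups is laid out as one prompt followed by its N responses (P + N*R tokens) and that Ulysses SP splits that sequence evenly across the SP ranks before the all-to-all. The class and method names are hypothetical and do not correspond to veRL configuration keys; padding and DP sharding are ignored.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StepShape:
    """Hypothetical container for the README notation (not a veRL config)."""
    mb: int       # prompt groups per micro-batch
    N: int        # responses per prompt
    P: int        # prompt length in tokens
    R: int        # response length in tokens
    SP: int = 1   # Ulysses sequence-parallel degree

    def tokens_duplicated(self) -> int:
        # Standard layout: every response carries its own copy of the prompt.
        return self.mb * self.N * (self.P + self.R)

    def tokens_shared(self) -> int:
        # Shared-prompt layout: one prompt per group, followed by N responses.
        return self.mb * (self.P + self.N * self.R)

    def tokens_per_sp_rank(self, shared: bool = True) -> int:
        # Assumes an even split of the sequence dimension across SP ranks;
        # ceil-divide in case the total is not a multiple of SP.
        total = self.tokens_shared() if shared else self.tokens_duplicated()
        return -(-total // self.SP)


if __name__ == "__main__":
    shape = StepShape(mb=4, N=16, P=32768, R=4096, SP=2)  # hypothetical values
    print("duplicated:", shape.tokens_duplicated())
    print("shared:", shape.tokens_shared())
    print("per SP rank (shared):", shape.tokens_per_sp_rank())
```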
Table 1: Kernel-Level Benchmarks (1x H100 or A100)
Isolated DualKV vs FA2 attention kernel timing (fwd + bwd), fp16.
CUDA_VISIBLE_DEVICES=0 python experiments/reproduce_table1.py
Expected output (H100-80GB):
 N      P | FA2 fwd  FA2 bwd  FA2 f+b |  DK fwd   DK bwd   DK f+b |   fwd    bwd    f+b
28   4096 |    49.4    165.8    215.3 |    34.4     98.7    133.1 | 1.44x  1.68x  1.62x
28  16384 |   425.0   1325.8   1750.8 |   120.1    347.6    467.7 | 3.54x  3.81x  3.74x
16  32768 |   857.7   2645.8   3503.4 |   174.5    504.9    679.4 | 4.91x  5.24x  5.16x
28  32768 |  1500.9   4609.0   6109.9 |   259.8    758.4   1018.2 | 5.78x  6.08x  6.00x
16  65536 |     OOM      OOM      OOM |   454.2   1277.7   1731.8 |   inf    inf    inf
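The speedup columns appear to be the FA2 time divided by the DualKV time for the same N and P. The sketch below recomputes them from the rounded timings in the expected output (the OOM row is omitted). The recomputed values agree with the printed ones to within about 0.01x; for example 857.7 / 174.5 gives 4.92 against the reported 4.91, a gap small enough to come from the timings being rounded before printing.

```python
# Rows copied from the expected H100 output:
# (N, P, FA2 fwd, FA2 bwd, FA2 f+b, DK fwd, DK bwd, DK f+b)
ROWS = [
    (28, 4096, 49.4, 165.8, 215.3, 34.4, 98.7, 133.1),
    (28, 16384, 425.0, 1325.8, 1750.8, 120.1, 347.6, 467.7),
    (16, 32768, 857.7, 2645.8, 3503.4, 174.5, 504.9, 679.4),
    (28, 32768, 1500.9, 4609.0, 6109.9, 259.8, 758.4, 1018.2),
]


def speedups(row):
    """(fwd, bwd, f+b) speedups as FA2 time / DualKV time."""
    _n, _p, *times = row
    fa2, dk = times[:3], times[3:]
    return tuple(a / b for a, b in zip(fa2, dk))


for row in ROWS:
    fwd, bwd, fb = speedups(row)
    print(f"N={row[0]:>2} P={row[1]:>5}: fwd {fwd:.2f}x  bwd {bwd:.2f}x  f+b {fb:.2f}x")
```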
Table 2: Single-Layer DualKV vs Prefix Grouper vs FA2 (1x H100)
Single Qwen3-8B decoder layer fwd+bwd with realistic response lengths. Prefix Grouper is self-implemented (no external package needed).
CUDA_VISIBLE_DEVICES=0 python experiments/reproduce_table2.py
Paper Table 2 reports configs: (P=5K, mb=32), (8K, 16), (16K, 8), (32K, 4). The script sweeps the full P x mb grid and marks paper configs with *.
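The marking behavior described above can be sketched as a Cartesian product over prompt lengths and micro-batch sizes that flags the four paper configurations. The grid axes here are hypothetical (prompt lengths in thousands of tokens, matching how the paper labels them); the script's actual sweep values and output format may differ.

```python
from itertools import product

# Paper Table 2 configs as (prompt length in K tokens, micro-batch size).
PAPER_CONFIGS = {(5, 32), (8, 16), (16, 8), (32, 4)}


def sweep(prompt_lengths_k, micro_batches):
    """Yield (P_k, mb, is_paper_config) over the full P x mb grid."""
    for p_k, mb in product(prompt_lengths_k, micro_batches):
        yield p_k, mb, (p_k, mb) in PAPER_CONFIGS


if __name__ == "__main__":
    # Hypothetical grid axes, chosen to cover the paper configs.
    for p_k, mb, marked in sweep([5, 8, 16, 32], [4, 8, 16, 32]):
        print(f"P={p_k}K mb={mb:>2}{' *' if marked else ''}")
```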
Single-Step Full-Model Benchmark (8x H100)
torchrun --standalone --nproc-per-node 8 experiments/benchmark_qwen3_single_step.py \
  --model ${WORKDIR}/models/Qwen3-8B --path both
Table 5: End-to-End GRPO (Qwen3-8B, 8x H100)
| Config | Script |
|--------|--------|
| FA2 mb=4 (baseline) | bash experiments/run_qwen3_8b_longreason_fa2.sh |
| FA3 mb=4 | bash experiments/run_qwen3_8b_longreason_fa3.sh |
| DualKV mb=4 | bash experiments/run_qwen3_8b_longreason_dualkv_mb4.sh |
| DualKV mb=8 | bash experiments/run_qwen3_8b_longreason_dualkv_mb8.sh |
Table 7: End-to-End DAPO (Qwen3-8B, 8x H100)
| Config | Script |
|--------|--------|
| FA2 mb=4 | `bash…
Notability: 3.0/10. Low-star research repo from Amazon.