basetenlabs/TorchSpec
forked from lightseekorg/TorchSpec
Description: A PyTorch native library for training speculative decoding models
License: MIT
Stars: 0
Forks: 0
Open issues: 0
Created: 2026-05-19T16:37:18Z
Pushed: 2026-05-19T16:39:10Z
Default branch: main
Fork: yes
Parent repository: lightseekorg/TorchSpec
Archived: no
README:
TorchSpec
TorchSpec is a torch-native speculative decoding training framework. We introduce a disaggregated approach to training speculative decoding draft models: inference and training are fully decoupled, and hidden states are streamed directly from inference engine groups to distributed training workers via Mooncake store, so each side can scale independently.
TorchSpec currently includes training flows and examples for:
- Kimi-K2.5
- MiniMax-M2.5
- Qwen3-Coder-Next
🤗 Released Models
Draft models trained with TorchSpec, available on the LightSeek Foundation Hugging Face organization:
- lightseekorg/kimi-k2.5-eagle3
- lightseekorg/kimi-k2.5-eagle3-mla
- lightseekorg/kimi-k2.6-eagle3
- lightseekorg/kimi-k2.6-eagle3-mla
🚀 Blogs
- PyTorch blog: TorchSpec: Speculative Decoding Training at Scale
- Release blog: TorchSpec: Speculative Decoding Training at Scale
Table of Contents
- [Architecture Overview](#architecture-overview)
- [Inference Backend Support](#inference-backend-support)
- [Quick Start](#quick-start)
- [Setup](#setup)
- [Examples](#examples)
- [Training Modes](#training-modes)
- [Checkpoint Conversion](#checkpoint-conversion)
- [Metrics Reporting](#metrics-reporting)
- [Troubleshooting](#troubleshooting)
Architecture Overview
TorchSpec is built around a disaggregated training pipeline:
- Inference engines run the target model and generate its hidden states.
- Mooncake store transfers tensors between inference and training without materializing them on disk.
- Training workers consume streamed hidden states to train speculative decoding draft models.
This separation keeps the training side focused on optimization while letting the inference side scale for hidden-state generation throughput.
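The producer/consumer shape of this pipeline can be sketched with the standard library alone. This is a minimal illustration under stated assumptions, not TorchSpec code: a bounded `queue.Queue` stands in for Mooncake store, lists of floats stand in for hidden-state tensors, and `HiddenStateBatch`, `inference_worker`, `training_worker`, and `run_pipeline` are hypothetical names.

```python
import queue
import threading
from dataclasses import dataclass


@dataclass
class HiddenStateBatch:
    """One unit of streamed work (hypothetical stand-in for a hidden-state tensor)."""
    sample_id: int
    hidden_states: list


def inference_worker(engine_id, sample_ids, store):
    # Stand-in for one inference engine: "run" the target model on each sample
    # and publish the resulting hidden states to the shared in-memory store.
    for sid in sample_ids:
        fake_hidden = [float(sid), float(engine_id)]
        store.put(HiddenStateBatch(sid, fake_hidden))


def training_worker(store, num_expected, consumed):
    # Stand-in for a training worker: pull batches as soon as they are available.
    for _ in range(num_expected):
        batch = store.get()
        consumed.append(batch.sample_id)  # a real worker would run a training step here


def run_pipeline(num_engines=2, samples_per_engine=4):
    # Bounded queue: producers block when the consumer falls behind (backpressure).
    store = queue.Queue(maxsize=4)
    consumed = []
    producers = [
        threading.Thread(
            target=inference_worker,
            args=(e, range(e * samples_per_engine, (e + 1) * samples_per_engine), store),
        )
        for e in range(num_engines)
    ]
    consumer = threading.Thread(
        target=training_worker, args=(store, num_engines * samples_per_engine, consumed)
    )
    consumer.start()
    for p in producers:
        p.start()
    for p in (*producers, consumer):
        p.join()
    return sorted(consumed)
```

Changing `num_engines` scales the inference side without touching the training worker, which is the independence the disaggregated design is after. Nothing is written to disk; batches move through memory from producer to consumer.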
Inference Backend Support
TorchSpec streams hidden states from inference engines into training workers.
| Backend | Support Tier | Status |
|---------|--------------|--------|
| vLLM | First-class | Available |
| TokenSpeed | First-class | In progress |
| SGLang | Best community effort | Available |
| HuggingFace Transformers | Best community effort | Available |
Quick Start
Train an Eagle3 draft model for Qwen3-8B on a single node with 4 GPUs (2 for training and 2 for inference):
./examples/qwen3-8b-single-node/run.sh
Override config values directly from the CLI:
./examples/qwen3-8b-single-node/run.sh training.learning_rate=5e-5 training.num_train_steps=500
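Dotted `key=value` overrides like the ones above map onto nested config sections. The sketch below shows one way such overrides can be applied; it is an assumption for illustration, not TorchSpec's actual config parser (which this excerpt does not show), and `parse_value` and `apply_overrides` are hypothetical helpers.

```python
import ast


def parse_value(raw):
    # YAML-style booleans first, then Python literals (ints, floats, lists);
    # anything else (e.g. a filesystem path) stays a plain string.
    lowered = raw.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        return raw


def apply_overrides(config, overrides):
    """Apply dotted key=value overrides (e.g. training.learning_rate=5e-5) in place."""
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        *parents, leaf = key.split(".")
        node = config
        for part in parents:
            node = node.setdefault(part, {})  # create missing sections on demand
        node[leaf] = parse_value(raw)
    return config
```

For example, `apply_overrides(cfg, ["training.learning_rate=5e-5", "training.num_train_steps=500"])` sets `cfg["training"]["learning_rate"]` to the float `5e-05` and `cfg["training"]["num_train_steps"]` to the int `500`.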
Setup
Quick Setup
# Install with vLLM
./tools/build_conda.sh 1 vllm
micromamba activate torchspec
# Or install with SGLang
./tools/build_conda.sh
micromamba activate torchspec
To install into your current environment instead:
./tools/build_conda.sh current sglang # or 'vllm' or 'both'
Optional: install Flash Attention support:
pip install -e ".[fa]"
Backend-Specific Usage
vLLM
./examples/qwen3-8b-single-node/run.sh --config configs/vllm_qwen3_8b.yaml
SGLang
./examples/qwen3-8b-single-node/run.sh
TorchSpec uses vLLM's Worker Extension mechanism to hook into the model forward pass and capture hidden states directly inside worker processes, which avoids RPC serialization overhead during extraction. For SGLang, TorchSpec instead patches the SGLang codebase to enable hidden-state extraction.
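The general idea of capturing hidden states from inside the forward pass can be illustrated with a framework-free toy. This sketch does not reproduce vLLM's Worker Extension API or TorchSpec's patch; it only mirrors the callback shape of PyTorch's `torch.nn.Module.register_forward_hook` (`hook(module, inputs, output)`), and whether TorchSpec uses that exact hook mechanism is not stated here. `ToyLayer` and `forward_with_capture` are hypothetical. Capturing several layers reflects the fact that EAGLE-3 draft models consume features from multiple target layers.

```python
class ToyLayer:
    """Toy decoder layer: adds a fixed bias to every element (stand-in for real compute)."""

    def __init__(self, bias):
        self.bias = bias
        self._hooks = []

    def register_forward_hook(self, hook):
        # Same callback shape as PyTorch: hook(module, inputs, output).
        self._hooks.append(hook)
        return lambda: self._hooks.remove(hook)  # "handle" that unregisters the hook

    def __call__(self, x):
        out = [v + self.bias for v in x]
        for hook in self._hooks:
            hook(self, x, out)
        return out


def forward_with_capture(layers, capture_ids, x):
    """Run all layers once, recording the outputs of the layers in capture_ids."""
    captured = {}
    removers = []
    for idx in capture_ids:
        def hook(module, inputs, output, idx=idx):
            captured[idx] = list(output)  # copy so the stored state is independent
        removers.append(layers[idx].register_forward_hook(hook))
    try:
        for layer in layers:
            x = layer(x)
    finally:
        for remove in removers:
            remove()  # leave the model unhooked after extraction
    return x, captured
```

Because capture happens in the same process as the forward pass, the hidden states never need to be serialized and sent over RPC just to be observed, which is the overhead the in-worker approach avoids.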
Examples
| Example | Backend | Model |
|---------|---------|-------|
| [hf-quickstart](examples/hf-quickstart/) | HuggingFace | Qwen3-8B |
| [qwen3-8b-single-node](examples/qwen3-8b-single-node/) | Inference engine | Qwen3-8B |
| [kimi-k25-2node-h200](examples/kimi-k25-2node-h200/) | Inference engine | Kimi-K2.5 |
| [kimi-k25-3node-h100](examples/kimi-k25-3node-h100/) | Inference engine | Kimi-K2.5 |
| [minimax-m25-5node-h200](examples/minimax-m25-5node-h200/) | Inference engine | MiniMax-M2.5 |
See [examples/README.md](examples/README.md) for more details about each example.
Training Modes
Resume vs. Continual Training
Both modes use training.load_path, but they restore different states:
| Goal | training.load_path | training.continual_training | What gets restored |
|------|--------------------|-----------------------------|--------------------|
| Resume an interrupted run | Required | false (default) | Model, optimizer, LR scheduler, RNG, and step metadata |
| Start a new run from existing weights | Required | true | Model weights only |
Resume the same run:
training:
  load_path: /path/to/old_run/checkpoints
  output_dir: /path/to/old_run
Start a new run from existing weights:
training:
  load_path: /path/to/old_run/checkpoints
  continual_training: true
  learning_rate: 1e-5
  warmup_ratio: 0.01
  num_epochs: 1
  output_dir: /path/to/new_run
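The difference between the two modes in the table comes down to which parts of the saved state are carried over. A minimal sketch of that restore logic follows; `Checkpoint`, `TrainState`, and `restore` are hypothetical and use plain dicts in place of real FSDP state dicts, so this is not TorchSpec's checkpoint format.

```python
from dataclasses import dataclass, field


@dataclass
class Checkpoint:
    model: dict
    optimizer: dict
    lr_scheduler: dict
    rng: dict
    step: int


@dataclass
class TrainState:
    model: dict = field(default_factory=dict)
    optimizer: dict = field(default_factory=dict)
    lr_scheduler: dict = field(default_factory=dict)
    rng: dict = field(default_factory=dict)
    step: int = 0


def restore(ckpt, continual_training=False):
    """Build a fresh TrainState from ckpt according to the selected mode."""
    state = TrainState()
    state.model = dict(ckpt.model)  # weights are restored in both modes
    if not continual_training:
        # Resume: carry over the full optimization state so the run continues where it stopped.
        state.optimizer = dict(ckpt.optimizer)
        state.lr_scheduler = dict(ckpt.lr_scheduler)
        state.rng = dict(ckpt.rng)
        state.step = ckpt.step
    # Continual training: optimizer, scheduler, RNG, and step counter start fresh,
    # so new settings such as learning_rate and warmup_ratio take effect from step 0.
    return state
```

This is why the continual-training config above can set a new `learning_rate` and `warmup_ratio`: none of the old scheduler state is carried into the new run.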
Checkpoint Conversion
Convert an FSDP checkpoint to HuggingFace format:
python tools/convert_to_hf.py --input-dir ./outputs/my_experiment/iter_0010000/
Vocabulary pruning, which reduces the draft model's lm_head to a smaller token set and emits d2t (draft-to-target) and t2d (target-to-draft) mappings, can be applied either during training or at conversion time.
- Pre-pruning: set draft_vocab_size in your training config. The checkpoint already contains the pruned lm_head and d2t/t2d buffers, so the basic conversion command is enough.
- Post-pruning: train with the full vocabulary, then pass --prune-vocab at conversion time together with a representative dataset to compute token frequencies.
python…
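Frequency-based vocabulary pruning can be sketched as: count target-token occurrences in a representative dataset, keep the `draft_vocab_size` most frequent ids, and record both directions of the mapping. This is an assumption-labeled sketch; TorchSpec's actual selection rule and buffer layout are not shown in this excerpt, and `build_vocab_mapping` is hypothetical. In particular, some EAGLE-3 implementations store d2t as per-index offsets (`target_id - draft_id`) rather than absolute target ids.

```python
from collections import Counter


def build_vocab_mapping(token_ids, target_vocab_size, draft_vocab_size):
    """Keep the most frequent target tokens and build draft<->target mappings."""
    counts = Counter(token_ids)
    # Rank by descending frequency; break ties by token id so the result is deterministic.
    ranked = sorted(counts, key=lambda t: (-counts[t], t))
    kept = sorted(ranked[:draft_vocab_size])  # draft ids follow ascending target-id order
    d2t = list(kept)  # d2t[draft_id] -> target_id
    t2d = [False] * target_vocab_size  # t2d[target_id] -> is this token in the draft vocab?
    for target_id in kept:
        t2d[target_id] = True
    return d2t, t2d
```

At decode time, a draft prediction `i` maps back to target token `d2t[i]`, while `t2d` marks which target tokens the draft head can propose at all. If the dataset contains fewer distinct tokens than `draft_vocab_size`, this sketch simply returns a shorter `d2t`; a production implementation would need a padding rule.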