togethercomputer/CREST
Description: CREST is a training-free test-time steering framework that discovers cognitive heads via simple offline calibration and then rotates activations during decoding to guide the model’s reasoning—preserving norms to avoid per-model hyperparameter tuning. This improves accuracy and reduces tokens across models and datasets.
Language: Python
License: Apache-2.0
Stars: 10
Forks: 1
Open issues: 0
Created: 2025-10-17T13:53:21Z
Pushed: 2025-12-31T05:06:23Z
Default branch: main
Fork: no
Archived: no
README:
CREST: Cognitive REasoning Steering at Test‑time
TL;DR
CREST is a training-free test-time steering framework that discovers cognitive heads via simple offline calibration and then rotates activations during decoding to guide the model’s reasoning—preserving norms to avoid per-model hyperparameter tuning. This improves accuracy and reduces tokens across models and datasets.
What is CREST?
CREST (Cognitive REasoning Steering at Test-time) identifies attention heads whose activations are predictive of different reasoning modes (“cognitive heads”), then steers those heads at inference to suppress inefficient trajectories and encourage effective reasoning—without further training.
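The norm-preserving rotation mentioned in the TL;DR can be illustrated with a small sketch. This is not the repository's implementation: the function name `rotate_preserving_norm`, the angle parameter `theta`, and the choice to rotate inside the plane spanned by the activation and a steering direction are all assumptions made for illustration. It only shows why a rotation, unlike adding a scaled vector, leaves the activation norm unchanged.

```python
import math


def rotate_preserving_norm(h, v, theta):
    """Rotate activation h by theta radians away from direction v.

    Hypothetical helper: a negative theta rotates toward v instead.
    The rotation happens in the plane spanned by h and v, so the
    Euclidean norm of h is unchanged.
    """
    def dot(x, y):
        return sum(a * b for a, b in zip(x, y))

    h_norm = math.sqrt(dot(h, h))
    v_norm = math.sqrt(dot(v, v))
    if h_norm == 0.0 or v_norm == 0.0:
        return list(h)  # nothing to rotate

    v_hat = [x / v_norm for x in v]
    parallel = dot(h, v_hat)  # signed length of h along v
    perp = [x - parallel * y for x, y in zip(h, v_hat)]
    perp_norm = math.sqrt(dot(perp, perp))
    if perp_norm < 1e-12 * h_norm:
        return list(h)  # h is (anti)parallel to v: the plane is undefined

    u_hat = [x / perp_norm for x in perp]  # unit vector orthogonal to v_hat
    angle = math.atan2(perp_norm, parallel)  # current angle between h and v
    new_angle = min(math.pi, max(0.0, angle + theta))
    return [
        h_norm * (math.cos(new_angle) * y + math.sin(new_angle) * z)
        for y, z in zip(v_hat, u_hat)
    ]
```

Because the result is always `h_norm` times a unit vector, the output has the same norm as the input regardless of `theta`. That property is what the README credits with avoiding per-model tuning of a steering magnitude.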
- Token savings with accuracy gains, e.g., R1-7B on MATH500 reaches 92.4% accuracy with 34% fewer tokens; R1-1.5B on AMC23 uses 37.6% fewer tokens at higher accuracy.
- Generalizes across models/datasets (DeepSeek-R1 1.5B/7B/32B, Qwen3-4B/30B, GPT-OSS-20B; MATH500, AIME, AMC23, GPQA-D, LiveCodeBench, Calendar Planning).
- Head ratio “gold default.” Steering roughly the top 38% of heads (ranked by linear-probe accuracy) balances accuracy and token reduction, and is adopted as the default.
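The head-ratio rule in the last bullet amounts to ranking heads by probe accuracy and keeping a fixed fraction. A minimal sketch follows, assuming probe accuracies are available as a dictionary keyed by `(layer, head)`. The function name `select_top_heads`, that data layout, and the rounding rule are assumptions, not the repository's code.

```python
def select_top_heads(probe_accuracy, ratio=0.38):
    """Return the (layer, head) keys with the highest probe accuracy.

    probe_accuracy: dict mapping (layer, head) -> validation accuracy.
    ratio: fraction of heads to keep (0.38 mirrors the README default).
    """
    if not 0.0 < ratio <= 1.0:
        raise ValueError("ratio must be in (0, 1]")
    if not probe_accuracy:
        return []

    # Keep at least one head; rounding is an illustrative choice.
    k = max(1, round(ratio * len(probe_accuracy)))

    # Sort by accuracy (descending), breaking ties by (layer, head).
    ranked = sorted(probe_accuracy.items(), key=lambda item: (-item[1], item[0]))
    return [head for head, _ in ranked[:k]]
```

With eight heads and the default ratio, this keeps the three most accurate ones. The deterministic tie-break makes the selection reproducible across runs.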
Usage
Install
We recommend using the vLLM Docker image:
Inside the container:
    cd CREST/probing/omni_math_rule/evaluation
    pip install evaluate
    cd latex2sympy
    pip install -e .
    cd ..
    pip install -r requirements.txt
    cd ../../
    pip install lighteval
    pip install datasets==3.5.0
    pip install emoji
Otherwise:
    cd CREST
    pip install -r requirements.txt
    cd probing/omni_math_rule/evaluation
    pip install evaluate
    cd latex2sympy
    pip install -e .
    cd ..
    pip install -r requirements.txt
    cd ../../
    pip install lighteval
    pip install datasets==3.5.0
    pip install emoji
### Running Baseline Experiments
Run baseline without steering
bash script/baseline.sh
This script:
- Sets `STEERING=False`
- Evaluates models across multiple datasets
- Saves results in the `results/SteeringFalse/` directory

### Running Steering Experiments
Run with steering enabled
bash script/ours.sh
This script:
- Sets `STEERING=True`
- Configures steering parameters:
  - `steering_number=512`: Number of top steering vectors to use
  - `steering_coef=-4`: Steering coefficient (negative for inhibition)
  - `steering_mode=after_o_proj_norm_threshold`: Steering application mode
- Saves results in the `results/SteeringTrue_numb512_coef-4_mode[mode]/` directory
- **Steering currently requires `--enforce_eager`; CUDA graphs and `torch.compile` are not supported yet.**

#### Steering Vector Zoo

The Steering Vector Zoo contains pre-trained steering vectors that can be applied to modify model behavior during inference. These vectors are learned through probing techniques and stored as PyTorch (.pt) files.
    probing/results/
    ├── [DATASET]/                                # Training dataset (e.g., MATH_train)
    │   └── [MODEL]/                              # Model name (e.g., Qwen3-30B-A3B-Thinking-2507)
    │       └── template-t0-n1-[SIZE]/            # Template and size configuration
    │           └── hidden[MODE]/                 # Steering mode directory
    │               └── mix_others_low_rank_1000/ # Training methods
    │                   └── probe_best.pt         # Steering vector file
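To see which steering vectors are present locally, the layout above can be walked with `pathlib`. The sketch below assumes the five directory levels nest exactly as shown. The function name `index_probe_files` and the dictionary keys are hypothetical, and the code only locates files; it does not assume anything about what `probe_best.pt` contains.

```python
from pathlib import Path

# Directory levels between the results root and probe_best.pt,
# named after the comments in the tree above (names are illustrative).
LEVELS = ("dataset", "model", "template", "hidden_mode", "method")


def index_probe_files(results_root):
    """List every probe_best.pt under results_root with its path components."""
    root = Path(results_root)
    entries = []
    for path in sorted(root.rglob("probe_best.pt")):
        parts = path.relative_to(root).parts[:-1]  # drop the file name
        if len(parts) != len(LEVELS):
            continue  # skip files that do not match the expected nesting
        entry = dict(zip(LEVELS, parts))
        entry["path"] = path
        entries.append(entry)
    return entries
```

Calling `index_probe_files("probing/results")` returns one dictionary per steering vector file, which makes it easy to check that the directory passed as `--steering_vector_path` actually contains vectors for the intended model.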
The system automatically selects the top-k most accurate steering vectors based on:
1. **Probe Accuracy**: Classification performance on the validation set
2. **Layer-Head Coverage**: Distribution across different attention layers and heads
3. **Steering Coefficient**: Magnitude of influence (configurable via `--steering_coef`)

#### Manual Execution
    python main_vllm.py \
        --model_name_or_path "Qwen/Qwen3-30B-A3B-Thinking-2507" \
        --dataset "aime25" \
        --save_dir "results/test/" \
        --use_chat_format \
        --temperature 0.6 \
        --max_tokens 32768 \
        --steering True \
        --steering_vector_path "/path/to/steering/vectors/" \
        --steering_number 512 \
        --steering_coef -4 \
        --steering_mode "after_o_proj_norm_threshold"
#### Required Environment Variables (for steering)
    export STEERING=True
    export STEERING_VECTOR_PATH="/path/to/steering/vectors/"
    export STEERING_NUMBER=512
    export STEERING_COEF=-4
    export STEERING_MODE="after_o_proj_norm_threshold"
    export MODEL_NAME_OR_PATH="Qwen/Qwen3-30B-A3B-Thinking-2507"
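The environment variables above line up one-to-one with flags in the manual command. The sketch below builds that argument list from the environment and fails early if a variable is missing. The pairing of each variable with its flag is an assumption inferred from the matching names, not taken from `script/ours.sh`, and `build_command` is a hypothetical helper.

```python
import os
import shlex

# Assumed pairing of environment variables with main_vllm.py flags,
# inferred from the matching names in the README.
FLAG_FOR_ENV = {
    "MODEL_NAME_OR_PATH": "--model_name_or_path",
    "STEERING": "--steering",
    "STEERING_VECTOR_PATH": "--steering_vector_path",
    "STEERING_NUMBER": "--steering_number",
    "STEERING_COEF": "--steering_coef",
    "STEERING_MODE": "--steering_mode",
}


def build_command(env=None):
    """Return the argv list for main_vllm.py built from steering variables."""
    env = os.environ if env is None else env
    missing = [name for name in FLAG_FOR_ENV if name not in env]
    if missing:
        raise KeyError("missing environment variables: " + ", ".join(missing))

    argv = ["python", "main_vllm.py"]
    for name, flag in FLAG_FOR_ENV.items():
        argv += [flag, str(env[name])]
    return argv


if __name__ == "__main__":
    # Print a copy-pastable command line; dataset and sampling flags
    # from the manual example would still need to be appended.
    print(shlex.join(build_command()))
```

Returning a list rather than a single string keeps values such as `-4` and paths with spaces intact if the command is later passed to `subprocess.run`.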
## Repository overview

The implementation consists of three main components:

### 1. Main Inference Engine (`main_vllm.py`)

- **Purpose**: Core inference script supporting multiple datasets and models
- **Features**:
  - Support for multiple datasets (MATH, GSM, AIME, GPQA, LiveCodeBench, etc.)
  - Batch processing with vLLM backend
  - Configurable steering parameters
  - Multi-model support with automatic model detection

### 2. Steering Implementation (`probing/modeling_utils/vllm/`)

Contains model-specific monkey patches for different architectures:

- **`qwen2/monkey_patch.py`**: DeepSeek-R1-Distill-Qwen models
- **`qwen3/monkey_patch.py`**: Qwen3-4B-Thinking models
- **`qwen3_moe/monkey_patch.py`**: Qwen3-30B-A3B-Thinking models
- **`gpt_oss/monkey_patch.py`**: GPT-OSS models

Each monkey patch implements:

- Attention mechanism modifications
- Layer-wise steering vector application
- Dynamic steering flag management
- Multiple steering modes

### 3. Evaluation System (`probing/get_omni_results.py`)

- **Purpose**: Mathematical reasoning evaluation using OmniMath rules
- **Features**:
  - Parallel evaluation with timeout handling
  - Multiple prediction aggregation
  - Accuracy computation and result saving

## Steering Modes

The system supports four distinct steering modes, each applying steering vectors at different points in the attention mechanism:

### 1. `before_o_proj`

- **Application Point**: Before the output projection in attention
- **Mechanism**: Modifies attention output before linear transformation
- **Use Case**: Early intervention in attention computation

### 2. `after_o_proj`

- **Application Point**: After the output projection in attention
-…