Repo · ByteDance (Doubao/Seed) · published Sep 28, 2025

ByteDance-Seed/AHN

Python

ByteDance-Seed/AHN

Description: AHN: Artificial Hippocampus Networks for Efficient Long-Context Modeling

Language: Python

License: Apache-2.0

Stars: 178

Forks: 5

Open issues: 2

Created: 2025-09-28T03:19:00Z

Pushed: 2025-10-17T08:01:46Z

Default branch: main

Fork: no

Archived: no

README:

AHN: Artificial Hippocampus Networks for Efficient Long-Context Modeling

This repo provides training, inference, and evaluation instructions for the AHN paper, "Artificial Hippocampus Networks for Efficient Long-Context Modeling".

🔔 News

  • 🔥 [2025-10-08]: AHN released!

Introduction

Figure 1: Artificial Hippocampus Networks (AHNs) transform lossless memory into fixed-size compressed representations for long-context modeling. Lossless memory (e.g., attention’s key-value (KV) cache) stores exact input information but grows with sequence length, making it inefficient for long sequences. In contrast, compressed memory (e.g., an RNN's hidden state) maintains a constant size and offers a fixed computational cost per input token, but at the price of information loss. To harness the benefits of both memory types, AHNs continually convert lossless memory outside the sliding attention window into compressed form. AHNs can be instantiated with any RNN-like architecture. The model then integrates both memory types to make predictions across long contexts.
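As a rough illustration of the size trade-off described in Figure 1, the Python sketch below counts the elements one layer must store for a full KV cache versus a sliding-window KV cache plus a fixed-size compressed state. The head count, head dimension, and window length are hypothetical, and the state shape (one `head_dim x head_dim` matrix per KV head, as in DeltaNet-style linear RNNs) is an assumption rather than the actual AHN configuration.

```python
def kv_cache_elems(seq_len: int, num_kv_heads: int, head_dim: int) -> int:
    """Elements held by a lossless KV cache in one layer: a key and a value per token per head."""
    return 2 * seq_len * num_kv_heads * head_dim


def window_plus_state_elems(seq_len: int, window: int, num_kv_heads: int, head_dim: int) -> int:
    """Elements held by a sliding-window KV cache plus a fixed-size compressed state (one layer).

    Assumption: the compressed memory is one head_dim x head_dim matrix per KV head,
    as in linear-attention/DeltaNet-style RNNs; real AHN state shapes may differ.
    """
    window_kv = 2 * min(seq_len, window) * num_kv_heads * head_dim
    # No compression is needed until tokens start falling out of the window.
    state = num_kv_heads * head_dim * head_dim if seq_len > window else 0
    return window_kv + state


if __name__ == "__main__":
    # Hypothetical dimensions, chosen only for illustration.
    heads, dim, window = 2, 128, 4096
    for n in (1_024, 8_192, 65_536, 131_072):
        full = kv_cache_elems(n, heads, dim)
        compact = window_plus_state_elems(n, window, heads, dim)
        print(f"{n:>7,} tokens: full KV {full:>12,} | window + state {compact:>10,}")
```

The first count grows linearly with sequence length; the second stops growing once the sequence exceeds the window.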

Figure 2: (a) Illustration of a model augmented with Artificial Hippocampus Networks (AHNs). In this example, the sliding window length is 3. When the input sequence length is less than or equal to the window length, the model operates identically to a standard Transformer. For longer sequences, AHNs continually compress the tokens that fall outside the window into a compact memory representation. The model then uses both the lossless information within the window and the compressed memory to generate the next token. (b) Self-distillation training framework for AHNs built on an open-weight LLM. During training, the base LLM's weights are frozen, and only the AHNs' parameters are trained.
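The mechanism in Figure 2(a) can be sketched as a toy, pure-Python memory: exact key/value pairs are kept for the last `window` tokens, and each token evicted from the window is folded into a fixed-size matrix state. The compression step here is the delta rule from DeltaNet (one of the AHN modules listed in the Model Zoo), S ← S + β(v − Sk)kᵀ with L2-normalized keys. The class `ToyAHNMemory`, the `beta` value, and the tiny vector size are hypothetical, and how the real model combines window attention with the compressed read-out is not modeled.

```python
import math
from collections import deque
from typing import Deque, List, Sequence, Tuple

Vector = List[float]


class ToyAHNMemory:
    """Sliding window of exact (key, value) pairs plus a fixed-size compressed state.

    Illustrative only: the compression step is the DeltaNet delta rule
    S <- S + beta * (v - S k) k^T, not the repository's AHN code.
    """

    def __init__(self, window: int, dim: int, beta: float = 0.5) -> None:
        self.window = window
        self.beta = beta
        self.kv: Deque[Tuple[Vector, Vector]] = deque()  # lossless memory inside the window
        self.state: List[Vector] = [[0.0] * dim for _ in range(dim)]  # compressed memory S

    @staticmethod
    def _normalize(x: Sequence[float]) -> Vector:
        norm = math.sqrt(sum(a * a for a in x)) or 1.0
        return [a / norm for a in x]

    def _matvec(self, x: Vector) -> Vector:
        """Compute S x."""
        return [sum(s * xj for s, xj in zip(row, x)) for row in self.state]

    def _compress(self, k: Sequence[float], v: Sequence[float]) -> None:
        """Fold one evicted token into S with the delta rule."""
        k = self._normalize(k)
        err = [vi - pi for vi, pi in zip(v, self._matvec(k))]  # v - S k
        for i, e in enumerate(err):
            for j, kj in enumerate(k):
                self.state[i][j] += self.beta * e * kj

    def append(self, k: Sequence[float], v: Sequence[float]) -> None:
        """Add a token; once the window overflows, compress the oldest token."""
        self.kv.append((list(k), list(v)))
        if len(self.kv) > self.window:
            old_k, old_v = self.kv.popleft()
            self._compress(old_k, old_v)

    def read_state(self, q: Sequence[float]) -> Vector:
        """Query the compressed memory: o = S q, with q normalized like the keys."""
        return self._matvec(self._normalize(q))


if __name__ == "__main__":
    mem = ToyAHNMemory(window=3, dim=2, beta=1.0)  # window of 3, as in Figure 2(a)
    keys = [[1.0, 0.0], [0.0, 1.0]]
    for t in range(5):
        mem.append(keys[t % 2], [float(t), 0.0])
    print(len(mem.kv), mem.read_state([0.0, 1.0]))  # 3 [1.0, 0.0]
```

With `window=3`, five tokens leave three in the window and two in the state; because the evicted keys are orthogonal and `beta=1.0`, querying with the second token's key recovers its value exactly. Real AHN modules work on learned projections and lose information once many tokens share the state.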

Usage

Core Modules

├── src/                  # Core AHN implementation
│   └── ahn/
│       ├── rnn/          # AHN modules in RNN-style architecture
│       └── transformer/  # Transformer backbone (e.g., Qwen2.5/3) with AHN integration

Installation

Default environment: Python 3.11, CUDA 12.4, PyTorch 2.5.1+cu124

# Install dependencies and set up AHN for training

# 1. Clone the AHN repository and move into it
git clone https://github.com/ByteDance-Seed/AHN.git
cd AHN

# 2. Install required forked libraries
pip install "git+https://github.com/Seerkfang/flash-linear-attention.git@main#egg=flash-linear-attention"
pip install "git+https://github.com/Seerkfang/LLaMA-Factory.git@main#egg=llamafactory"
# (Optional) Install the forked Mamba version if you plan to use AHN-Mamba2
# MAMBA_FORCE_BUILD=TRUE pip install "git+https://github.com/yuweihao/mamba.git"

# 3. Install AHN in editable mode with training extras
pip install -e ".[train]"
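After installation, a quick way to confirm that the forked dependencies and the editable AHN package are importable is to probe them with `importlib.util.find_spec`, which locates a module without executing it. The import names (`fla`, `llamafactory`, `ahn`, `mamba_ssm`) are assumptions based on the upstream projects and the `src/ahn/` layout; this README does not state them.

```python
import importlib.util
from typing import Dict, Iterable

# Assumed import names for the packages installed above (not stated in the README):
#   flash-linear-attention fork -> "fla", LLaMA-Factory fork -> "llamafactory",
#   this repository (src/ahn/) -> "ahn", optional Mamba fork -> "mamba_ssm".
REQUIRED = ("fla", "llamafactory", "ahn")
OPTIONAL = ("mamba_ssm",)


def check_install(required: Iterable[str] = REQUIRED,
                  optional: Iterable[str] = OPTIONAL) -> Dict[str, bool]:
    """Map each top-level module name to whether it can be found, without importing it."""
    return {name: importlib.util.find_spec(name) is not None for name in (*required, *optional)}


if __name__ == "__main__":
    status = check_install()
    for name, found in status.items():
        kind = "optional" if name in OPTIONAL else "required"
        print(f"{name:<13} {kind:<9} {'found' if found else 'MISSING'}")
```

A missing `mamba_ssm` only matters if you plan to use AHN-Mamba2.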

Model Zoo

| Base model | AHN module | #Params (AHN only) | Checkpoint (AHN only) |
|:---:|:---:|:---:|:---:|
| Qwen2.5-3B-Instruct | Mamba2 | 11.9M | 🤗model |
| Qwen2.5-3B-Instruct | DeltaNet | 11.8M | 🤗model |
| Qwen2.5-3B-Instruct | GatedDeltaNet | 13.0M | 🤗model |
| Qwen2.5-7B-Instruct | Mamba2 | 18.6M | 🤗model |
| Qwen2.5-7B-Instruct | DeltaNet | 18.5M | 🤗model |
| Qwen2.5-7B-Instruct | GatedDeltaNet | 21.3M | 🤗model |
| Qwen2.5-14B-Instruct | Mamba2 | 51.4M | 🤗model |
| Qwen2.5-14B-Instruct | DeltaNet | 51.1M | 🤗model |
| Qwen2.5-14B-Instruct | GatedDeltaNet | 61.0M | 🤗model |
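The Model Zoo counts are small relative to the base models. The sketch below computes AHN parameters as a percentage of each base model, using the nominal sizes from the model names (3B, 7B, 14B) as an approximation; exact Qwen2.5 parameter counts differ somewhat, so treat the percentages as rough.

```python
# AHN-only parameter counts from the Model Zoo table, in millions.
AHN_PARAMS_M = {
    ("Qwen2.5-3B-Instruct", "Mamba2"): 11.9,
    ("Qwen2.5-3B-Instruct", "DeltaNet"): 11.8,
    ("Qwen2.5-3B-Instruct", "GatedDeltaNet"): 13.0,
    ("Qwen2.5-7B-Instruct", "Mamba2"): 18.6,
    ("Qwen2.5-7B-Instruct", "DeltaNet"): 18.5,
    ("Qwen2.5-7B-Instruct", "GatedDeltaNet"): 21.3,
    ("Qwen2.5-14B-Instruct", "Mamba2"): 51.4,
    ("Qwen2.5-14B-Instruct", "DeltaNet"): 51.1,
    ("Qwen2.5-14B-Instruct", "GatedDeltaNet"): 61.0,
}

# Nominal base sizes taken from the model names; exact parameter counts differ somewhat.
NOMINAL_BASE_PARAMS = {
    "Qwen2.5-3B-Instruct": 3e9,
    "Qwen2.5-7B-Instruct": 7e9,
    "Qwen2.5-14B-Instruct": 14e9,
}


def overhead_pct(base: str, module: str) -> float:
    """AHN parameters as a percentage of the (nominal) base-model size."""
    return 100.0 * AHN_PARAMS_M[(base, module)] * 1e6 / NOMINAL_BASE_PARAMS[base]


if __name__ == "__main__":
    for base, module in AHN_PARAMS_M:
        print(f"{base:<22} {module:<14} {overhead_pct(base, module):.2f}%")
```

With these nominal sizes, every listed module adds roughly 0.26% to 0.44% to the parameter count; these are the only parameters trained while the base model stays frozen.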

Inference

1. Merge the base model and AHN weights (example: Qwen2.5-3B-Instruct with GatedDeltaNet)

# Base model (repo_id or local path)
BASE_MODEL=Qwen/Qwen2.5-3B-Instruct

# AHN-only weights (repo_id or local path)
AHN_PATH=ByteDance-Seed/AHN-GDN-for-Qwen-2.5-Instruct-3B

# Output directory for the merged model
MERGED_MODEL_PATH=./merged_ckpt/Qwen-2.5-Instruct-3B-AHN-GDN

python ./examples/scripts/utils/merge_weights.py \
--base-model $BASE_MODEL \
--ahn-path $AHN_PATH \
--output-path $MERGED_MODEL_PATH

This produces a merged model directory at MERGED_MODEL_PATH with both base and AHN parameters.
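This README does not describe how `merge_weights.py` combines the two checkpoints. As a mental model only, assuming the AHN checkpoint contains just the new AHN tensors and their names do not collide with base-model tensor names, merging reduces to a union of two state dictionaries. The helper `merge_state_dicts` and the tensor names in the example are hypothetical, and plain strings stand in for tensors.

```python
from typing import Dict, Mapping, TypeVar

T = TypeVar("T")


def merge_state_dicts(base: Mapping[str, T], ahn: Mapping[str, T]) -> Dict[str, T]:
    """Union of base-model and AHN-only tensors (hypothetical helper, not merge_weights.py).

    Raises KeyError if a name appears in both, since AHN weights are assumed to be
    new parameters added alongside the frozen base model, not replacements.
    """
    overlap = set(base) & set(ahn)
    if overlap:
        raise KeyError(f"AHN checkpoint would overwrite base tensors: {sorted(overlap)[:5]}")
    merged = dict(base)
    merged.update(ahn)
    return merged


if __name__ == "__main__":
    # Made-up tensor names; strings stand in for weight tensors.
    base = {"model.layers.0.self_attn.q_proj.weight": "W_q", "lm_head.weight": "W_out"}
    ahn = {"model.layers.0.self_attn.ahn.w_beta": "W_beta"}
    print(sorted(merge_state_dicts(base, ahn)))
```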

2. Run inference on a single GPU (example: GPU 0)

PROMPT="When was the concept of AI introduced?"
CUDA_VISIBLE_DEVICES=0 python ./examples/scripts/inference.py \
--model $MERGED_MODEL_PATH \
--prompt "$PROMPT"
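To run the two steps above from Python (for instance, inside a batch job), the same command lines can be built and passed to `subprocess.run`. The script paths, flags, and example values are copied from this README; the helper names `merge_cmd`, `inference_cmd`, and `run` are hypothetical, and the commands are not checked against the scripts' argument parsers.

```python
import os
import subprocess
import sys
from typing import Dict, List, Optional, Tuple


def merge_cmd(base_model: str, ahn_path: str, output_path: str) -> List[str]:
    """argv for step 1; script path and flags copied from the README."""
    return [sys.executable, "./examples/scripts/utils/merge_weights.py",
            "--base-model", base_model, "--ahn-path", ahn_path, "--output-path", output_path]


def inference_cmd(model_path: str, prompt: str, gpu: int = 0) -> Tuple[List[str], Dict[str, str]]:
    """argv plus an environment that pins step 2 to a single GPU."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu))
    argv = [sys.executable, "./examples/scripts/inference.py",
            "--model", model_path, "--prompt", prompt]
    return argv, env


def run(argv: List[str], env: Optional[Dict[str, str]] = None) -> None:
    """Run a command from the repository root; raises CalledProcessError on failure."""
    subprocess.run(argv, env=env, check=True)


# Example usage (from the AHN repository root):
#   merged = "./merged_ckpt/Qwen-2.5-Instruct-3B-AHN-GDN"
#   run(merge_cmd("Qwen/Qwen2.5-3B-Instruct", "ByteDance-Seed/AHN-GDN-for-Qwen-2.5-Instruct-3B", merged))
#   run(*inference_cmd(merged, "When was the concept of AI introduced?"))
```

Passing arguments as a list avoids shell quoting, which is what `"$PROMPT"` takes care of in the Bash version.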

Training

Example: training Qwen2.5-3B-Instruct with GatedDeltaNet as the AHN module on ChatQA2, using the default global batch size of 128 across 32 GPUs (~700 steps):

bash ./examples/scripts/train_qwen2.5_3b_ahn_gdn.sh
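Under plain data parallelism, global batch size = per-GPU micro-batch × gradient-accumulation steps × number of GPUs, so the stated configuration works out to 4 samples per GPU per optimizer step. The sketch below does this arithmetic. How the training script actually splits those 4 samples between micro-batch size and accumulation (`per_device_train_batch_size` and `gradient_accumulation_steps` in Hugging Face `TrainingArguments`, which LLaMA-Factory builds on) is not given in this excerpt, so the micro-batch value is an assumption.

```python
def per_gpu_samples(global_batch: int, num_gpus: int) -> int:
    """Samples each data-parallel GPU contributes to one optimizer step."""
    if global_batch % num_gpus:
        raise ValueError("global batch size must be divisible by the number of GPUs")
    return global_batch // num_gpus


def grad_accum_steps(global_batch: int, num_gpus: int, micro_batch: int) -> int:
    """Accumulation steps needed so that micro_batch * steps * num_gpus == global_batch."""
    per_gpu = per_gpu_samples(global_batch, num_gpus)
    if per_gpu % micro_batch:
        raise ValueError("per-GPU samples must be divisible by the micro-batch size")
    return per_gpu // micro_batch


if __name__ == "__main__":
    print(per_gpu_samples(128, 32))      # 4 samples per GPU per step
    print(grad_accum_steps(128, 32, 1))  # 4 accumulation steps with an assumed micro-batch of 1
    print(128 * 700)                     # 89600 samples over ~700 steps (approximate)
```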

Key arguments:

  • --model_name_or_path: base model to augment (e.g., Qwen/Qwen2.5-3B-Instruct, Qwen/Qwen2.5-7B-Instruct)
  • --dataset: training dataset (default: chatqa2)
  • --loss_type: loss function (options: kl, ce)
  • --ahn_implementation: AHN memory module (allowed values: GatedDeltaNet, DeltaNet, Mamba2)
  • --ahn_position: where AHN modules are inserted (options: prefix, random)
  • --sliding_window_type: sliding-window strategy (options: fixed, random)
  • --save_ahn_only: whether to save only the AHN weights (options: True, False)
  • --filter_len: length filter for training data; only samples longer than this value are kept (e.g., 288)
  • --cutoff_len: cutoff sequence length (e.g., 24576)

Debugging

Quick debug on a single GPU (GPU 0):


Notability

notability 4.0/10

New repo, modest traction, ByteDance background.