databricks/flashoptim

Language: Python

License: Apache-2.0

Stars: 252

Forks: 10

Open issues: 4

Created: 2026-02-23T21:45:25Z

Pushed: 2026-04-17T18:17:32Z

Default branch: main

Fork: no

Archived: no

README:

This is the official implementation of FlashOptim: Optimizers for Memory-Efficient Training

By Jose Javier Gonzalez Ortiz, Abhay Gupta, Christopher Rinard, and Davis Blalock.

TL;DR

FlashOptim is a library of drop-in replacements for PyTorch optimizers. It substantially reduces training memory by shrinking the footprint of optimizer states, master weights, and gradients.

For example, for finetuning an 8B model, FlashOptim requires 35% less peak memory and produces checkpoints that are 57% smaller.

Despite operating in reduced precision, FlashOptim does not affect model convergence.

1. Quickstart

To get started you can install flashoptim:

$ pip install flashoptim

Once installed, you can import FlashSGD, FlashSGDW, FlashAdam, FlashAdamW and FlashLion, which follow the standard PyTorch optimizer API. For example, to use FlashAdamW:

import torch
from torch import nn

from flashoptim import FlashAdamW, cast_model

model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10)).cuda()
# cast parameters to bf16
cast_model(model, dtype=torch.bfloat16)

# master_weight_bits=24 (default) means we have 24-bit parameter semantics
optimizer = FlashAdamW(model.parameters(), lr=1e-3)

x = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16)
loss = model(x).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()

That's it! You are now training with 50% less per-parameter memory! For more details on the API and advanced features, keep reading.

2. Key Features

  • Memory Savings. By splitting the weight representation and quantizing the optimizer states, FlashOptim reduces per-parameter memory (e.g., by 57% for Adam) and peak training memory without degrading convergence.
  • Fused Triton Kernels. All compression operations are fused into the update kernel, introducing no practical overhead.
  • Gradient Release. Optionally, parameters can be updated as soon as the gradients are computed, further reducing peak memory.
  • Compressed Checkpoints. Checkpoints can optionally be stored using quantized optimizer states, producing >50% space savings.
  • PyTorch API. The optimizers follow the standard torch.optim.Optimizer interface.

3. Installation

FlashOptim can be installed using pip or uv. Note that FlashOptim is only supported on Linux systems with NVIDIA CUDA GPUs.

# install stable version
pip install flashoptim

# install latest version from source
pip install git+https://github.com/databricks/flashoptim.git

# or install it locally in editable mode for development
git clone https://github.com/databricks/flashoptim.git
cd flashoptim
pip install -e .

4. Usage

> [!NOTE]
> The first optimizer step will be slower than subsequent steps due to Triton kernel JIT compilation. This is a one-time cost per kernel configuration.

Specifying Precision

The master_weight_bits parameter controls the width of the master weights maintained by the optimizer. By default, master weights are 24-bit, narrower than fp32, which saves memory. When training in bf16/fp16, the downcasting is fused into the update kernel, so no separate cast step is needed:

from flashoptim import FlashAdamW

# Default: 24-bit master weights (bf16 param + 8-bit correction term)
optimizer = FlashAdamW(model.parameters(), lr=1e-3)

# 32-bit master weights (bf16 param + 16-bit correction term)
optimizer = FlashAdamW(model.parameters(), lr=1e-3, master_weight_bits=32)

# No master weight correction; parameters stay at native precision
optimizer = FlashAdamW(model.parameters(), lr=1e-3, master_weight_bits=None)
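To give some intuition for the "bf16 param + correction term" wording in the comments above, here is a minimal bit-level sketch in plain Python. It is an illustration only and makes an assumption the README does not state: the correction term is modeled as the next 8 bits of the fp32 mantissa after truncation to bf16. FlashOptim's actual encoding may differ, and the helper names (`f32_bits`, `bits_f32`, `split`, `merge`) are hypothetical.

```python
import struct

def f32_bits(x: float) -> int:
    """Return the IEEE-754 fp32 bit pattern of x as an unsigned integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_f32(b: int) -> float:
    """Interpret an unsigned 32-bit integer as an fp32 value."""
    return struct.unpack("<f", struct.pack("<I", b))[0]

def split(x: float):
    """Split an fp32 value into a 16-bit part and an 8-bit part.

    ASSUMPTION: simple truncation, not FlashOptim's real scheme.
    """
    bits = f32_bits(x)
    hi = bits >> 16          # sign, 8 exponent bits, top 7 mantissa bits (bf16 layout)
    lo = (bits >> 8) & 0xFF  # next 8 mantissa bits, standing in for a correction term
    return hi, lo

def merge(hi: int, lo: int) -> float:
    """Rebuild a 24-bit approximation of the original fp32 value."""
    return bits_f32((hi << 16) | (lo << 8))

x = bits_f32(f32_bits(1.2345678))  # round the literal to fp32 first
hi, lo = split(x)

bf16_only = bits_f32(hi << 16)     # what the 16-bit part alone represents
restored = merge(hi, lo)           # 16-bit part plus 8-bit correction

err16 = abs(x - bf16_only)
err24 = abs(x - restored)
print(f"bf16-only error: {err16:.3e}, 24-bit error: {err24:.3e}")
```

Under this toy model, the 24-bit reconstruction keeps 15 mantissa bits instead of 7, so small updates that would be lost in bf16 alone are more likely to survive.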

The exact behavior depends on the dtype of the parameters passed to the optimizer:

  • bf16/fp16 parameters: Optimizer states (moments) are quantized to 8-bit. The master_weight_bits setting controls master weight precision and fuses the downcasting into the update kernel:
      • master_weight_bits=24 (default): 8-bit correction terms for 24-bit master weights, narrower than fp32 while preserving convergence
      • master_weight_bits=32: 16-bit correction terms for full 32-bit master weight semantics
      • master_weight_bits=None: no master weight correction; optimizer states are still quantized, but parameters stay at their native precision
  • fp32 parameters: Optimizer states (moments) are quantized to 8-bit to reduce memory. Parameters are already full precision, so master_weight_bits is not applicable.
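The bit widths listed above can be tallied into a rough per-parameter footprint. The sketch below is back-of-the-envelope arithmetic, not a measurement. It assumes an Adam-style optimizer with two moment tensors, a baseline of fp32 weights with fp32 moments, and it ignores gradients, activations, and any quantization metadata such as scales. It does not attempt to reproduce the percentages quoted elsewhere in this README, and `bits_per_param` is a hypothetical helper.

```python
def bits_per_param(param_bits: int, correction_bits: int,
                   num_states: int, state_bits: int) -> int:
    """Bits stored per parameter for weights plus optimizer states.

    ASSUMPTION: gradients and quantization metadata are not counted.
    """
    return param_bits + correction_bits + num_states * state_bits

# Assumed baseline: fp32 weights and two fp32 Adam moments.
baseline = bits_per_param(32, 0, num_states=2, state_bits=32)

# bf16 parameters with 8-bit moments, per the settings described above.
flash_24 = bits_per_param(16, 8, num_states=2, state_bits=8)     # master_weight_bits=24
flash_32 = bits_per_param(16, 16, num_states=2, state_bits=8)    # master_weight_bits=32
flash_none = bits_per_param(16, 0, num_states=2, state_bits=8)   # master_weight_bits=None

for name, bits in [("baseline", baseline), ("24-bit", flash_24),
                   ("32-bit", flash_32), ("None", flash_none)]:
    print(f"{name:>8}: {bits // 8} bytes per parameter")
```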

To cast a model's parameters and buffers to bf16, use the cast_model helper. By default, normalization layers with running statistics are kept in fp32 for training stability. Forward pre-hooks automatically upcast the inputs of modules kept in fp32:

from flashoptim import cast_model

# Cast all parameters to bf16 (normalization layers kept in fp32 by default)
cast_model(model, dtype=torch.bfloat16)

# Terminal layers (e.g., lm_head) - kept fp32, output stays fp32
cast_model(model, dtype=torch.bfloat16, full_precision_layers=["lm_head", "*.head"])

# Middle layers - kept fp32 but output recast to bf16
cast_model(model, dtype=torch.bfloat16, full_precision_recast_layers=["target"])

# Module references work too
cast_model(model, full_precision_layers=[model.lm_head])

> [!NOTE]
> Layer names are matched with fnmatch against the full dotted module name, so "head" matches a top-level model.head but not model.decoder.head. Use "*.head" for nested modules.
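The matching rule in the note can be checked with Python's standard fnmatch module. The snippet below is a small sketch that runs the two patterns against a made-up list of dotted module names. It uses `fnmatchcase` for platform-independent behavior, and whether FlashOptim calls `fnmatch` or `fnmatchcase` internally is not stated here. The module names and the `matching` helper are hypothetical.

```python
from fnmatch import fnmatchcase

# Hypothetical dotted names, as they would appear relative to the model root.
module_names = ["head", "lm_head", "decoder.head", "decoder.layers.0.head"]

def matching(pattern: str, names: list[str]) -> list[str]:
    """Return the names that match a shell-style pattern in full."""
    return [name for name in names if fnmatchcase(name, pattern)]

top_level = matching("head", module_names)  # exact name only
nested = matching("*.head", module_names)   # any prefix ending in ".head"

print(top_level)
print(nested)
```

Because "*.head" requires a dot before "head", it does not match the top-level module. Covering both cases takes two patterns, such as ["head", "*.head"].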

Weight Decay

FlashOptim follows PyTorch's convention of separating L2 regularization from decoupled weight decay via separate classes:

| Optimizer | Weight Decay Style | PyTorch Equivalent |
|-----------|-------------------|-------------------|
| FlashAdam | L2 regularization (coupled) | torch.optim.Adam |
| FlashAdamW | Decoupled | torch.optim.AdamW |
| FlashSGD | L2 regularization (coupled) | torch.optim.SGD |
| FlashSGDW | Decoupled | - |
| FlashLion | Decoupled | - |
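The two styles in the table differ in where the decay enters the update. The sketch below shows the first Adam step on a single scalar parameter in plain Python. It is not FlashOptim code, and the function name `adam_first_step` and the hyperparameter values are chosen only for illustration. With coupled L2 regularization the decay term is added to the gradient and then rescaled by Adam's normalization. With decoupled decay the parameter is shrunk directly by a factor of (1 - lr * weight_decay).

```python
import math

def adam_first_step(p: float, g: float, lr: float = 0.1, wd: float = 0.1,
                    beta1: float = 0.9, beta2: float = 0.999,
                    eps: float = 1e-8, decoupled: bool = False) -> float:
    """One Adam update from zero-initialized moments (illustrative only)."""
    if decoupled:
        p = p * (1.0 - lr * wd)  # AdamW style: multiplicative shrink of the parameter
    else:
        g = g + wd * p           # Adam style: L2 term folded into the gradient
    m = (1.0 - beta1) * g        # first moment after one step
    v = (1.0 - beta2) * g * g    # second moment after one step
    m_hat = m / (1.0 - beta1)    # bias correction at step 1
    v_hat = v / (1.0 - beta2)
    return p - lr * m_hat / (math.sqrt(v_hat) + eps)

coupled = adam_first_step(1.0, 0.5, decoupled=False)
decoupled = adam_first_step(1.0, 0.5, decoupled=True)
print(f"coupled: {coupled:.4f}, decoupled: {decoupled:.4f}")
```

In this example the coupled variant's decay is absorbed by the gradient normalization, so the step size stays close to lr. The decoupled variant applies the same normalized step and additionally shrinks the parameter. For plain SGD without momentum the two formulations coincide, which is why the distinction matters mainly for adaptive or momentum-based optimizers.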

For decoupled optimizers (FlashAdamW, FlashSGDW, FlashLion), weight decay is applied as a multiplicative factor on the parameters, matching PyTorch's AdamW
