Model · ByteDance (Doubao/Seed) · published Dec 24, 2025

ByteDance-Seed/cryofm-v2

License: apache-2.0 · Library: cryofm

CryoFM2: A Generative Foundation Model for Cryo-EM Densities

Overview

CryoFM2 is a flow-based generative foundation model for cryo-EM density maps. It is pretrained on curated EMDB half maps to learn general priors of high-quality cryo-EM densities and can be fine-tuned for downstream tasks.

The model learns a continuous mapping from a simple Gaussian distribution to the complex distribution of cryo-EM densities, enabling stable generation and flexible adaptation. CryoFM2 can also act as a Bayesian prior, integrating naturally with task-specific likelihoods to support applications such as anisotropy-aware refinement, non-uniform reconstruction, and controlled density modification.
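To make the "continuous mapping" from Gaussian noise concrete, the following is a minimal sketch of Euler integration of a velocity field, the kind of solver selected by `method="euler"` in the sampling examples of this card. It is a toy under stated assumptions, not CryoFM2's implementation: `toy_velocity` is a hypothetical closed-form field standing in for the trained network, the target distribution is a single point, and time is assumed to run from t = 0 (noise) to t = 1 (data).

```python
import random


def toy_velocity(x, t, target):
    """Hypothetical velocity for the straight path x_t = (1 - t) * x_0 + t * target.

    A trained network would replace this closed form.
    """
    return [(m - xi) / (1.0 - t) for xi, m in zip(x, target)]


def euler_sample(velocity, x0, num_steps, target):
    """Integrate dx/dt = velocity(x, t) from t = 0 to t = 1 with fixed Euler steps."""
    x = list(x0)
    dt = 1.0 / num_steps
    for k in range(num_steps):
        t = k * dt  # evaluated at the start of each step, so t stays below 1
        v = velocity(x, t, target)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


random.seed(0)
target = [0.5, -1.0, 2.0]                         # stand-in for a "data" sample
noise = [random.gauss(0.0, 1.0) for _ in target]  # draw from the Gaussian source
sample = euler_sample(toy_velocity, noise, num_steps=200, target=target)
print(sample)
```

Because the toy path is a straight line, the Euler steps land on the target up to floating-point error. A learned velocity field is generally curved, which is why the number of steps matters in practice.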

Model Details

The pretrained model can be fine-tuned for downstream tasks such as density map enhancement and post-processing.

Pre-training Architecture: [figure not captured]

Fine-tuning Architecture (for EMhancer/EMReady style post-processing): [figure not captured]

Architecture

  • Architecture Type: 3D UNet
  • Input Size: 64×64×64 voxels
  • Input Channels: 2 for pre-trained model, 3 for fine-tuned model
  • Output Channels: 1
  • Down Blocks: DownBlock3D, DownBlock3D, AttnDownBlock3D, AttnDownBlock3D
  • Up Blocks: AttnUpBlock3D, AttnUpBlock3D, UpBlock3D, UpBlock3D
  • Block Output Channels: (64, 128, 256, 512)
  • Layers per Block: 2
  • Attention Head Dimension: 8
  • Normalization: GroupNorm (32 groups)
  • Activation: SiLU
  • Time Embedding: Positional encoding
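The hyperparameters listed above imply some simple bookkeeping, which the sketch below works through. Two points are assumptions rather than facts from this card: that every down block except the last halves the spatial resolution (the convention used by similarly named blocks in Hugging Face diffusers), and that the number of attention heads equals the channel count divided by the attention head dimension. The GroupNorm divisibility figure is plain arithmetic.

```python
# Values copied from the architecture list.
INPUT_SIDE = 64
BLOCK_OUT_CHANNELS = (64, 128, 256, 512)
DOWN_BLOCKS = ("DownBlock3D", "DownBlock3D", "AttnDownBlock3D", "AttnDownBlock3D")
NORM_GROUPS = 32
ATTENTION_HEAD_DIM = 8


def describe_levels(input_side=INPUT_SIDE):
    """Return one dict per encoder level under the assumptions stated above."""
    levels = []
    side = input_side
    last_index = len(DOWN_BLOCKS) - 1
    for i, (name, channels) in enumerate(zip(DOWN_BLOCKS, BLOCK_OUT_CHANNELS)):
        has_attention = name.startswith("Attn")
        levels.append({
            "block": name,
            "side": side,  # assumed side length of the volume this block sees
            "channels": channels,
            "channels_per_group": channels // NORM_GROUPS,
            # Assumed relation: heads = channels / head dimension.
            "heads": channels // ATTENTION_HEAD_DIM if has_attention else 0,
        })
        if i != last_index:
            side //= 2  # assumed: all but the last block downsample by 2
    return levels


levels = describe_levels()
for level in levels:
    print(level)
```

Under these assumptions, attention would only run on the two coarsest grids, where the number of voxels is small enough for it to be affordable.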

Model Variants

1. cryofm2-pretrain: Unconditional pretrained model for general density map generation
2. cryofm2-emhancer: Fine-tuned model for density map enhancement (EMhancer style)
3. cryofm2-emready: Fine-tuned model for density map enhancement (EMReady style)
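A small lookup table can keep the three variants straight when scripting. The sketch below records, for each variant, the class name and `output_tag` value used by the sampling examples in this card, and builds the `config.yaml` and `model.safetensors` paths those examples load. The helper `resolve_variant` and the `root` argument are hypothetical names introduced for illustration; they are not part of the cryofm package.

```python
from pathlib import Path

# variant -> (class name used in this card's examples, output_tag or None)
VARIANTS = {
    "cryofm2-pretrain": ("CryoFM2Uncond", None),
    "cryofm2-emhancer": ("CryoFM2Cond", 1),
    "cryofm2-emready": ("CryoFM2Cond", 0),
}


def resolve_variant(root, variant):
    """Return file paths and loading hints for one variant (hypothetical helper)."""
    if variant not in VARIANTS:
        raise ValueError(f"unknown variant {variant!r}; expected one of {sorted(VARIANTS)}")
    class_name, output_tag = VARIANTS[variant]
    model_dir = Path(root) / variant
    return {
        "model_dir": model_dir,
        "config": model_dir / "config.yaml",
        "weights": model_dir / "model.safetensors",
        "class_name": class_name,
        "output_tag": output_tag,
    }


info = resolve_variant("path/to/cryofm-v2", "cryofm2-emhancer")
print(info["class_name"], info["output_tag"], info["weights"])
```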

Play with CryoFM2

Installation

Set up the environment and install the package before using CryoFM2:

# Clone the repository
git clone https://github.com/ByteDance-Seed/cryofm.git
cd cryofm

# Create a new conda environment for CryoFM (recommended)
conda create -n cryofm python=3.10 -y
conda activate cryofm

# Install CryoFM
pip install .
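After installing, a quick check that the modules imported by the examples in this card can be found may save a failed run later. This is a minimal sketch using only the standard library; the list of module names is taken from the import statements in this card, and `missing_modules` is a hypothetical helper name.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be located in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Top-level modules imported by the usage examples in this card.
REQUIRED = ("torch", "mmengine", "cryofm")

missing = missing_modules(REQUIRED)
print("missing:", ", ".join(missing) if missing else "none")
```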

Unconditional Generation (Explore Training Data Distribution)

Generate samples from the pretrained model to explore the learned data distribution:

Pretrained Model:

import torch
from mmengine import Config

from cryofm.core.utils.mrc_io import save_mrc
from cryofm.core.utils.sampling_fm import sample_from_fm
from cryofm.projects.cryofm2.lit_modules import CryoFM2Uncond

# Update the path to your model directory
model_dir = "path/to/cryofm-v2/cryofm2-pretrain"
cfg = Config.fromfile(f"{model_dir}/config.yaml")
lit_model = CryoFM2Uncond.load_from_safetensors(f"{model_dir}/model.safetensors", cfg=cfg)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

lit_model = lit_model.to(device)
lit_model.eval()

def v_xt_t(_xt, _t):
    return lit_model(_xt, _t)

# Enable bfloat16 for faster inference if your GPU supports it
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    out = sample_from_fm(
        v_xt_t,
        lit_model.noise_scheduler,
        method="euler",
        num_steps=200,
        num_samples=3,
        device=lit_model.device,
        side_shape=64,
    )

# Apply normalization if configured
if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
    out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean

# Save generated samples
for i in range(3):
    save_mrc(out[i].float().cpu().numpy(), f"sample-{i}.mrc", voxel_size=1.5)

Fine-tuned Models (EMhancer/EMReady):

import torch
from mmengine import Config

from cryofm.core.utils.mrc_io import save_mrc
from cryofm.core.utils.sampling_fm import sample_from_fm
from cryofm.projects.cryofm2.lit_modules import CryoFM2Cond

# Choose style: "emhancer" or "emready"
style = "emhancer"
model_dir = f"path/to/cryofm-v2/cryofm2-{style}"
cfg = Config.fromfile(f"{model_dir}/config.yaml")
lit_model = CryoFM2Cond.load_from_safetensors(f"{model_dir}/model.safetensors", cfg=cfg)
output_tag = 1 if style == "emhancer" else 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

lit_model = lit_model.to(device)
lit_model.eval()

def v_xt_t(_xt, _t):
    bs = _xt.shape[0]
    unconditional_generation_conds = {
        "input_cond": None,
        "output_cond": torch.tensor([output_tag] * bs).to(device),
        "vol_cond": None,  # dimension should be [bs, d, h, w]
    }
    return lit_model(_xt, _t, generation_conds=unconditional_generation_conds)

# Enable bfloat16 for faster inference if your GPU supports it
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    out = sample_from_fm(
        v_xt_t,
        lit_model.noise_scheduler,
        method="euler",
        num_steps=200,
        num_samples=3,
        device=lit_model.device,
        side_shape=64,
    )

# Apply normalization if configured
if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
    out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean

# Save generated samples
for i in range(3):
    save_mrc(out[i].float().cpu().numpy(), f"{style}-sample-{i}.mrc", voxel_size=1.5)

Density Map Modification

CryoFM2 supports various density map modification operations using the pretrained model as a Bayesian prior. Supported operators include:

  • denoise: Remove noise from density maps
  • inpaint: Fill missing regions (e.g., missing wedge)
  • denoise inpaint: Combined denoising and inpainting
  • non-uniform weight: Apply non-uniform weighting during reconstruction
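One way to read the operators above is as different measurement models combined with the same prior. The sketch below is a deliberately simple illustration of that idea and not CryoFM2's algorithm: it swaps the learned prior for an independent zero-mean Gaussian per component, models the measurement as `y = mask * x + noise`, and returns the closed-form posterior mean. All names and values here are hypothetical.

```python
def posterior_mean(observed, mask, prior_var, noise_var):
    """Per-component posterior mean for y = mask * x + noise.

    Assumes x ~ N(0, prior_var) and noise ~ N(0, noise_var), independently
    for every component. Where mask is 0 nothing was measured, so the
    estimate falls back to the prior mean (0).
    """
    shrink = prior_var / (prior_var + noise_var)
    return [shrink * y if m else 0.0 for y, m in zip(observed, mask)]


observed = [2.0, -1.0, 0.0, 4.0]
full_mask = [1, 1, 1, 1]     # denoise-like: every component measured, all noisy
partial_mask = [1, 1, 0, 0]  # inpaint-like: last two components never measured

denoised = posterior_mean(observed, full_mask, prior_var=1.0, noise_var=1.0)
inpainted = posterior_mean(observed, partial_mask, prior_var=1.0, noise_var=1.0)
print(denoised)
print(inpainted)
```

With this toy prior the unmeasured components simply revert to zero. The motivation for a learned prior is that it can propose plausible values there instead.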

Basic Usage:

python -m cryofm.projects.cryofm2.uncond_sampling \
    -i1 half_map_1.mrc \
    -i2 half_map_2.mrc \
    -o ./output \
    --model-dir path/to/cryofm-v2/cryofm2-pretrain \
    --op denoise \
    --norm-grad \
    --use-lamb-w

Inpainting additionally requires the path to a RELION star file, passed with --data-starfile-path:

python -m cryofm.projects.cryofm2.uncond_sampling \
    -i1 half_map_1.mrc \
    -i2 half_map_2.mrc \
    -o ./output \
    --model-dir path/to/cryofm-v2/cryofm2-pretrain \
    --op inpaint \
    --data-starfile-path path/to/relion_data.star \
    --norm-grad \
    --use-lamb-w
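When running the two commands above over many maps, it can help to assemble the argument list in Python. The sketch below builds only the flags that appear in this card and covers only the two `--op` values shown in its commands (`denoise` and `inpaint`); the command-line spellings of the other operators are not given here, so the helper rejects them rather than guessing. `build_uncond_sampling_cmd` is a hypothetical helper, not part of the cryofm package.

```python
import shlex


def build_uncond_sampling_cmd(half_map_1, half_map_2, out_dir, model_dir, op, starfile=None):
    """Assemble the uncond_sampling command line using only flags shown in this card."""
    if op not in ("denoise", "inpaint"):
        raise ValueError("only 'denoise' and 'inpaint' are covered by this sketch")
    if op == "inpaint" and starfile is None:
        raise ValueError("inpaint requires a RELION star file")
    cmd = [
        "python", "-m", "cryofm.projects.cryofm2.uncond_sampling",
        "-i1", half_map_1,
        "-i2", half_map_2,
        "-o", out_dir,
        "--model-dir", model_dir,
        "--op", op,
    ]
    if op == "inpaint":
        cmd += ["--data-starfile-path", starfile]
    cmd += ["--norm-grad", "--use-lamb-w"]
    return cmd


cmd = build_uncond_sampling_cmd(
    "half_map_1.mrc",
    "half_map_2.mrc",
    "./output",
    "path/to/cryofm-v2/cryofm2-pretrain",
    op="inpaint",
    starfile="path/to/relion_data.star",
)
print(shlex.join(cmd))
```

The returned list can be passed directly to `subprocess.run(cmd, check=True)`, which avoids shell quoting problems with paths that contain spaces.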

Density Map Post-Processing

CryoFM2 provides fine-tuned models for density map enhancement in different styles, similar to EMhancer and EMReady.

EMhancer Style Enhancement

python -m cryofm.projects.cryofm2.cond_sampling \
    -i input_map.mrc \
    -o ./output_emhancer \
    --model-dir…
