ByteDance-Seed/cryofm-v2
CryoFM2: A Generative Foundation Model for Cryo-EM Densities
Overview
CryoFM2 is a flow-based generative foundation model for cryo-EM density maps. It is pretrained on curated EMDB half maps to learn general priors of high-quality cryo-EM densities and can be fine-tuned for downstream tasks.
The model learns a continuous mapping from a simple Gaussian distribution to the complex distribution of cryo-EM densities, enabling stable generation and flexible adaptation. CryoFM2 can also act as a Bayesian prior, integrating naturally with task-specific likelihoods to support applications such as anisotropy-aware refinement, non-uniform reconstruction, and controlled density modification.
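To make the "Gaussian to data" mapping concrete, here is a minimal one-dimensional sketch of flow-based sampling with Euler integration. It is not CryoFM2's code. The learned 3D UNet velocity is replaced by a closed-form velocity field for a toy Gaussian target, and the names `velocity`, `euler_sample`, `mu`, and `sigma` are hypothetical. It assumes the straight-line path x_t = (1 - t) * x0 + t * x1 with independent endpoints.

```python
import numpy as np


def velocity(x, t, mu=2.0, sigma=0.5):
    """Closed-form marginal velocity for x_t = (1 - t) * x0 + t * x1,
    with x0 ~ N(0, 1) and x1 ~ N(mu, sigma^2) drawn independently.
    ASSUMPTION: a toy stand-in for a learned velocity network."""
    var_t = (1.0 - t) ** 2 + (t * sigma) ** 2
    return mu + (t * sigma**2 - (1.0 - t)) / var_t * (x - t * mu)


def euler_sample(num_samples=20000, num_steps=200, seed=0):
    """Integrate dx/dt = velocity(x, t) from t = 0 to t = 1 with Euler steps."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(num_samples)  # samples from the Gaussian source
    dt = 1.0 / num_steps
    for k in range(num_steps):
        x = x + dt * velocity(x, k * dt)
    return x


samples = euler_sample()
print(f"mean={samples.mean():.3f}, std={samples.std():.3f}")
```

The samples should end up close to the toy target's mean and standard deviation (2.0 and 0.5). In CryoFM2 the same kind of integration runs over 64×64×64 volumes, with the network supplying the velocity.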
Model Details
The pretrained model can be fine-tuned for downstream tasks such as density map enhancement and post-processing.
Pre-training Architecture:
Fine-tuning Architecture (for EMhancer/EMReady style post-processing):
Architecture
- Architecture Type: 3D UNet
- Input Size: 64×64×64 voxels
- Input Channels: 2 for pre-trained model, 3 for fine-tuned model
- Output Channels: 1
- Down Blocks: DownBlock3D, DownBlock3D, AttnDownBlock3D, AttnDownBlock3D
- Up Blocks: AttnUpBlock3D, AttnUpBlock3D, UpBlock3D, UpBlock3D
- Block Output Channels: (64, 128, 256, 512)
- Layers per Block: 2
- Attention Head Dimension: 8
- Normalization: GroupNorm (32 groups)
- Activation: SiLU
- Time Embedding: Positional encoding
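The listed hyperparameters can be cross-checked with a few lines of arithmetic. The sketch below copies the values from the list above. The head-count and resolution calculations rest on assumptions that this card does not confirm: that "attention head dimension" means channels per head, and that every down block except the last halves each side.

```python
# Values copied from the architecture list above.
block_out_channels = (64, 128, 256, 512)
norm_num_groups = 32
attention_head_dim = 8
input_side = 64

# GroupNorm requires each channel count to be divisible by the group count.
assert all(c % norm_num_groups == 0 for c in block_out_channels)
channels_per_group = [c // norm_num_groups for c in block_out_channels]

# ASSUMPTION: head dimension = channels per head, so heads = channels // 8.
# Only the last two stages are attention blocks in the listed layout.
heads_in_attn_blocks = [c // attention_head_dim for c in block_out_channels[2:]]

# ASSUMPTION: each down block except the last halves the side length,
# so stage i operates on a cube of side input_side / 2**i.
sides = [input_side // 2**i for i in range(len(block_out_channels))]

print("channels per group:", channels_per_group)
print("heads in attention blocks:", heads_in_attn_blocks)
print("side length per stage:", sides)
```

Under these assumptions, attention runs only on the two coarsest grids, which keeps the number of attended voxels manageable.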
Model Variants
1. cryofm2-pretrain: Unconditional pretrained model for general density map generation
2. cryofm2-emhancer: Fine-tuned model for density map enhancement (EMhancer style)
3. cryofm2-emready: Fine-tuned model for density map enhancement (EMReady style)
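For scripting, the three variants can be kept in a small lookup table. The sketch below is a hypothetical convenience structure, not part of the cryofm package. The class names and style tags come from the generation examples in this document, and the input channel counts come from the architecture list.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class VariantInfo:
    lit_module: str             # class name used in this document's loading examples
    input_channels: int         # 2 for the pretrained model, 3 for fine-tuned models
    output_tag: Optional[int]   # style tag from the fine-tuned example, None if unused


# Hypothetical registry keyed by the variant directory name.
VARIANTS = {
    "cryofm2-pretrain": VariantInfo("CryoFM2Uncond", 2, None),
    "cryofm2-emhancer": VariantInfo("CryoFM2Cond", 3, 1),
    "cryofm2-emready": VariantInfo("CryoFM2Cond", 3, 0),
}


def describe(name: str) -> VariantInfo:
    """Return the variant's metadata, or raise a clear error for unknown names."""
    if name not in VARIANTS:
        raise KeyError(f"unknown variant {name!r}; expected one of {sorted(VARIANTS)}")
    return VARIANTS[name]


print(describe("cryofm2-emhancer"))
```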
Play with CryoFM2
Installation
Before using CryoFM2, you need to set up the environment and install the package. Follow these steps to get started:
# Clone the repository
git clone https://github.com/ByteDance-Seed/cryofm.git
cd cryofm
# Create a new conda environment for CryoFM (recommended)
conda create -n cryofm python=3.10 -y
conda activate cryofm
# Install CryoFM
pip install .
Unconditional Generation (Explore Training Data Distribution)
Generate samples from the pretrained model to explore the learned data distribution:
Pretrained Model:
import torch
from mmengine import Config
from cryofm.core.utils.mrc_io import save_mrc
from cryofm.core.utils.sampling_fm import sample_from_fm
from cryofm.projects.cryofm2.lit_modules import CryoFM2Uncond
# Update the path to your model directory
model_dir = "path/to/cryofm-v2/cryofm2-pretrain"
cfg = Config.fromfile(f"{model_dir}/config.yaml")
lit_model = CryoFM2Uncond.load_from_safetensors(f"{model_dir}/model.safetensors", cfg=cfg)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lit_model = lit_model.to(device)
lit_model.eval()
def v_xt_t(_xt, _t):
    return lit_model(_xt, _t)

# Enable bfloat16 for faster inference if your GPU supports it
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    out = sample_from_fm(
        v_xt_t,
        lit_model.noise_scheduler,
        method="euler",
        num_steps=200,
        num_samples=3,
        device=lit_model.device,
        side_shape=64
    )

# Apply normalization if configured
if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
    out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean

# Save generated samples
for i in range(3):
    save_mrc(out[i].float().cpu().numpy(), f"sample-{i}.mrc", voxel_size=1.5)

Fine-tuned Models (EMhancer/EMReady):
import torch
from mmengine import Config
from cryofm.core.utils.mrc_io import save_mrc
from cryofm.core.utils.sampling_fm import sample_from_fm
from cryofm.projects.cryofm2.lit_modules import CryoFM2Cond
# Choose style: "emhancer" or "emready"
style = "emhancer"
model_dir = f"path/to/cryofm-v2/cryofm2-{style}"
cfg = Config.fromfile(f"{model_dir}/config.yaml")
lit_model = CryoFM2Cond.load_from_safetensors(f"{model_dir}/model.safetensors", cfg=cfg)
output_tag = 1 if style == "emhancer" else 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lit_model = lit_model.to(device)
lit_model.eval()
def v_xt_t(_xt, _t):
    bs = _xt.shape[0]
    unconditional_generation_conds = {
        "input_cond": None,
        "output_cond": torch.tensor([output_tag] * bs).to(device),
        "vol_cond": None,  # dimension should be [bs, d, h, w]
    }
    return lit_model(_xt, _t, generation_conds=unconditional_generation_conds)

# Enable bfloat16 for faster inference if your GPU supports it
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    out = sample_from_fm(
        v_xt_t,
        lit_model.noise_scheduler,
        method="euler",
        num_steps=200,
        num_samples=3,
        device=lit_model.device,
        side_shape=64
    )

# Apply normalization if configured
if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
    out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean

# Save generated samples
for i in range(3):
    save_mrc(out[i].float().cpu().numpy(), f"{style}-sample-{i}.mrc", voxel_size=1.5)

Density Map Modification
CryoFM2 supports various density map modification operations using the pretrained model as a Bayesian prior. Supported operators include:
- denoise: Remove noise from density maps
- inpaint: Fill missing regions (e.g., missing wedge)
- denoise inpaint: Combined denoising and inpainting
- non-uniform weight: Apply non-uniform weighting during reconstruction
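The role of a task-specific likelihood can be illustrated with a toy example. The sketch below is not CryoFM2's implementation. It assumes a Gaussian noise model and a binary real-space mask as a stand-in for the observed region (a missing wedge would be a mask in Fourier space), and the function name `log_likelihood_grad` is hypothetical. With no prior involved, gradient steps on the likelihood alone recover only the observed voxels. The unobserved region is what the generative prior has to fill in.

```python
import numpy as np


def log_likelihood_grad(x, y, noise_std, mask=None):
    """Gradient with respect to x of log N(y; mask * x, noise_std^2 * I).
    mask=None corresponds to plain denoising (every voxel observed)."""
    if mask is None:
        mask = np.ones_like(x)
    return mask * (y - mask * x) / noise_std**2


rng = np.random.default_rng(0)
truth = rng.standard_normal((8, 8, 8))

# Observe only half of the volume (toy "inpaint" operator, noise-free for clarity).
mask = np.zeros_like(truth)
mask[:, :, :4] = 1.0
y = mask * truth

# Plain gradient ascent on the likelihood, with no prior term.
x = np.zeros_like(truth)
for _ in range(100):
    x = x + 0.1 * log_likelihood_grad(x, y, noise_std=1.0, mask=mask)

observed_error = np.abs(x[:, :, :4] - truth[:, :, :4]).max()
print(f"max error on observed voxels: {observed_error:.2e}")
print("unobserved voxels untouched:", bool(np.all(x[:, :, 4:] == 0)))
```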
Basic Usage:
python -m cryofm.projects.cryofm2.uncond_sampling \
    -i1 half_map_1.mrc \
    -i2 half_map_2.mrc \
    -o ./output \
    --model-dir path/to/cryofm-v2/cryofm2-pretrain \
    --op denoise \
    --norm-grad \
    --use-lamb-w
For inpainting tasks, you need to provide a RELION starfile path:
python -m cryofm.projects.cryofm2.uncond_sampling \
    -i1 half_map_1.mrc \
    -i2 half_map_2.mrc \
    -o ./output \
    --model-dir path/to/cryofm-v2/cryofm2-pretrain \
    --op inpaint \
    --data-starfile-path path/to/relion_data.star \
    --norm-grad \
    --use-lamb-w
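When running many maps, the command lines above can be assembled programmatically. The helper below is a hypothetical wrapper, not part of the cryofm package. It uses only the flags shown in the two commands above and builds the argument list without executing it. The check that an inpainting operator needs a starfile follows the note above, and the substring test on the operator name is an assumption.

```python
import shlex


def build_uncond_sampling_cmd(half_map_1, half_map_2, out_dir, model_dir, op, starfile=None):
    """Assemble the argument list for the uncond_sampling command shown above."""
    # ASSUMPTION: any operator whose name contains "inpaint" needs the starfile.
    if "inpaint" in op and starfile is None:
        raise ValueError("inpainting operators need a RELION starfile path")
    cmd = [
        "python", "-m", "cryofm.projects.cryofm2.uncond_sampling",
        "-i1", half_map_1,
        "-i2", half_map_2,
        "-o", out_dir,
        "--model-dir", model_dir,
        "--op", op,
    ]
    if starfile is not None:
        cmd += ["--data-starfile-path", starfile]
    cmd += ["--norm-grad", "--use-lamb-w"]
    return cmd


cmd = build_uncond_sampling_cmd(
    "half_map_1.mrc", "half_map_2.mrc", "./output",
    "path/to/cryofm-v2/cryofm2-pretrain", "inpaint",
    starfile="path/to/relion_data.star",
)
print(shlex.join(cmd))
```

The returned list can be passed to `subprocess.run(cmd, check=True)` once the paths point at real files.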
Density Map Post-Processing
CryoFM2 provides fine-tuned models for density map enhancement in different styles, similar to EMhancer and EMReady.
EMhancer Style Enhancement
python -m cryofm.projects.cryofm2.cond_sampling \
    -i input_map.mrc \
    -o ./output_emhancer \
    --model-dir…