Repo · ByteDance (Doubao/Seed) · published Dec 1, 2025

ByteDance-Seed/Adversarial-Flow-Models

Python



Language: Python

License: MIT

Stars: 79

Forks: 0

Open issues: 1

Created: 2025-12-01T07:52:51Z

Pushed: 2026-04-18T03:57:11Z

Default branch: main

Fork: no

Archived: no

README:

Adversarial Flow Models

This repository contains the official PyTorch implementation of both discrete and continuous Adversarial Flow Models.

> **Adversarial Flow Models**

> Shanchuan Lin, Ceyuan Yang, Zhijie Lin, Hao Chen, Haoqi Fan
> ByteDance Seed

> **Continuous Adversarial Flow Models**

> Shanchuan Lin, Ceyuan Yang, Zhijie Lin, Hao Chen, Haoqi Fan
> ByteDance Seed

Colab Notebooks

AFMs

Train

1. Install requirements: pip install -r requirements_afm.txt
2. Download the VAE and other miscellaneous checkpoints to the root directory.
3. Download dit.py from the original DiT repo and place it at models/afm/dit/dit.py.
4. Configure your dataset. Instructions are provided in the Dataloading section below.
5. Run the training configurations provided in configs/train/afm.

  • Replace TORCHRUN with the torchrun command for your GPU configuration.
  • Make sure exp.gpu equals your total number of GPUs so that the per-rank batch size is calculated correctly.
  • You can set a smaller exp.bsz for local debugging.
  • The training schedule is provided in Table 11 of the AFM paper. The current approach still requires manual intervention, a limitation we hope to improve in future work.

TORCHRUN main.py configs/train/train_1nfe.yaml
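
The relationship between exp.gpu and exp.bsz can be illustrated with a minimal sketch. It assumes that exp.bsz is the global batch size and that it is divided evenly across ranks; the function name per_rank_batch_size is hypothetical and is not taken from this repository, so check main.py for the actual calculation.

```python
def per_rank_batch_size(global_bsz: int, num_gpus: int) -> int:
    """Split a global batch size evenly across GPUs.

    Assumption: the trainer derives the per-rank batch size as
    exp.bsz / exp.gpu. The real script may round or validate differently.
    """
    if num_gpus <= 0:
        raise ValueError("num_gpus must be positive")
    if global_bsz % num_gpus != 0:
        raise ValueError(
            f"global batch size {global_bsz} is not divisible by {num_gpus} GPUs"
        )
    return global_bsz // num_gpus


if __name__ == "__main__":
    # Example values only: a global batch of 256 on 8 GPUs.
    print(per_rank_batch_size(256, 8))  # prints 32
```

Under this assumption, a mismatch between exp.gpu and the number of launched processes would silently change the effective global batch size, which is why the two should agree.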

Evaluate

1. Download pre-trained AFM checkpoints, or use your own.
2. Generate 50K samples for FID evaluation.

TORCHRUN main.py configs/generate/afm/generate_1nfe.yaml

3. Use /misc/pack_npz.py to pack the generated samples into an npz file.
4. Use the ADM evaluation suite to evaluate FID.
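
For orientation, the packing step can be sketched as below. This is not the repository's /misc/pack_npz.py. It is a minimal illustration that assumes the evaluator reads a single uint8 array of shape (N, H, W, 3) stored under the key arr_0, which is the layout produced by calling np.savez with one positional array. The function name pack_npz and its arguments are hypothetical.

```python
import numpy as np


def pack_npz(images, out_path):
    """Stack H x W x 3 uint8 images and save them as one array in an .npz file.

    Assumption: the evaluator expects the samples under the key "arr_0".
    """
    arr = np.stack([np.asarray(img, dtype=np.uint8) for img in images])
    if arr.ndim != 4 or arr.shape[-1] != 3:
        raise ValueError(f"expected shape (N, H, W, 3), got {arr.shape}")
    # A positional array passed to np.savez is stored under the name "arr_0".
    np.savez(out_path, arr)
    return arr.shape


if __name__ == "__main__":
    # Four random 32x32 RGB images stand in for generated samples.
    rng = np.random.default_rng(0)
    fake = [rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8) for _ in range(4)]
    print(pack_npz(fake, "samples.npz"))
```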

CAFMs

Train

1. Install requirements: pip install -r requirements_cafm.txt
2. Download the VAE and other miscellaneous checkpoints to the root directory.
3. Download model files.

  • Download sit.py from the original SiT repo and place it at models/cafm/sit/sit.py.
  • Download jit.py from the original JiT repo and place it at models/cafm/jit/jit.py.
  • Z-Image does not require downloading any model code.

4. Configure your dataset. Instructions are provided in the Dataloading section below.
5. Download pre-trained checkpoints.

6. Run the training configurations provided in configs/train/cafm.

  • Replace TORCHRUN with the torchrun command for your GPU configuration.
  • Make sure exp.gpu equals your total number of GPUs so that the per-rank batch size is calculated correctly.
  • You can set a smaller exp.bsz for local debugging.

TORCHRUN main.py configs/train/cafm/train_cafm_sit.yaml
TORCHRUN main.py configs/train/cafm/train_cafm_jit.yaml
TORCHRUN main.py configs/train/cafm/train_cafm_zimage.yaml

Evaluate

1. Download pre-trained CAFM checkpoints, or use your own.
2. We do not provide generation or evaluation code. Please use the SiT and JiT codebases for generation and evaluation.

  • You may need /misc/convert_to_jit_format.py to convert checkpoints saved by our JiT training to their format.
  • Note that the FID logged by our training script is only a rough estimate. You will get a better FID using their official evaluation code.

Dataloading

For our official training, we pack the ImageNet and T2I datasets into parquet format. The dataloading code in /data is provided for reference only. You can implement your own dataset loading logic.

For ImageNet, implement the dataset as an IterableDataset with a forever loop that returns a dictionary with keys image and label. The image should be a PyTorch tensor of shape (3, H, W) with range [0, 1]. The dataset class should accept our transform, which handles resizing, cropping, and normalization to [-1, 1]. The label should be the class index in the range [0, 999].
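
The ImageNet contract described above can be sketched as follows. This is an illustrative stand-in and not the code in /data: the class ToyImageNetDataset and the helper to_minus_one_one are hypothetical, the images are random tensors instead of decoded files, and the transform only normalizes (the repository's transform also resizes and crops). Whether the trainer expects the label as a Python int or as a tensor is an assumption here.

```python
import torch
from torch.utils.data import IterableDataset


def to_minus_one_one(image: torch.Tensor) -> torch.Tensor:
    """Map a tensor from [0, 1] to [-1, 1] (normalization only, no resize/crop)."""
    return image * 2.0 - 1.0


class ToyImageNetDataset(IterableDataset):
    """Hypothetical dataset that yields random image/label pairs forever."""

    def __init__(self, transform=None, resolution=256):
        super().__init__()
        self.transform = transform
        self.resolution = resolution

    def __iter__(self):
        while True:  # forever loop: the training loop decides when to stop
            # Random values in [0, 1) stand in for a decoded image.
            image = torch.rand(3, self.resolution, self.resolution)
            # Class index in [0, 999].
            label = int(torch.randint(0, 1000, (1,)))
            if self.transform is not None:
                image = self.transform(image)
            yield {"image": image, "label": label}


if __name__ == "__main__":
    dataset = ToyImageNetDataset(transform=to_minus_one_one, resolution=64)
    sample = next(iter(dataset))
    print(sample["image"].shape, sample["label"])
```

A real implementation would replace the random tensor with decoded image data and would partition its source across ranks and workers.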

For ImageNet, CAFM SiT training also supports offline dataloading. In this mode, the dataset should return a dictionary with keys latent and label. The latent should be a PyTorch tensor of shape (4, 32, 32). The training script will automatically skip VAE encoding and use the offline latents.

For T2I, implement the dataset as an IterableDataset with a forever loop that returns a dictionary with keys image and text. The image should be a tensor of shape (3, H, W); batching of different aspect ratios is supported. The text should be a string.

Note that the IterableDataset must internally check the current rank and worker id to handle distributed partitioning.…
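
One common way to do this partitioning is to give every (rank, worker) pair a disjoint, strided slice of the data. The sketch below shows only that index arithmetic in plain Python. The function names are hypothetical, and this is not necessarily how the code in /data does it. In a PyTorch dataset, rank and world_size would typically come from torch.distributed.get_rank() and torch.distributed.get_world_size(), and worker_id and num_workers from torch.utils.data.get_worker_info().

```python
def global_worker_index(rank, world_size, worker_id, num_workers):
    """Return (shard_index, num_shards) for one dataloader worker on one rank."""
    if not (0 <= rank < world_size):
        raise ValueError("rank must be in [0, world_size)")
    if not (0 <= worker_id < num_workers):
        raise ValueError("worker_id must be in [0, num_workers)")
    return rank * num_workers + worker_id, world_size * num_workers


def shard(items, rank, world_size, worker_id, num_workers):
    """Select the strided slice of items owned by this rank and worker."""
    index, total = global_worker_index(rank, world_size, worker_id, num_workers)
    return items[index::total]


if __name__ == "__main__":
    files = list(range(10))  # stand-in for a list of data files
    # Rank 1 of 2, worker 0 of 2 owns shard index 2 out of 4 shards.
    print(shard(files, rank=1, world_size=2, worker_id=0, num_workers=2))  # prints [2, 6]
```

Because every shard is disjoint and together the shards cover all items, no sample is duplicated across GPUs or dataloader workers within one pass over the data.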

