Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docker/common/uv-pytorch.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4036,7 +4036,7 @@ requires-dist = [
{ name = "torchvision", marker = "sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'diffusion'", index = "https://download.pytorch.org/whl/cpu" },
{ name = "torchvision", marker = "sys_platform == 'linux' and extra == 'diffusion'", index = "https://download.pytorch.org/whl/cu130" },
{ name = "transformer-engine", extras = ["pytorch"], marker = "extra == 'cuda'", specifier = ">=2.14.1" },
{ name = "transformers", specifier = "==5.8.1" },
{ name = "transformers", specifier = "==5.12.0" },
{ name = "wandb", specifier = ">=0.26.1" },
]
provides-extras = ["diffusion", "diffusion-kernels", "cuda", "cuda-source", "extra", "fa", "fla", "delta-databricks", "moe", "vlm", "cli", "s3", "msc", "all"]
Expand Down Expand Up @@ -7579,7 +7579,7 @@ wheels = [

[[package]]
name = "transformers"
version = "5.8.1"
version = "5.12.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "huggingface-hub" },
Expand All @@ -7593,9 +7593,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "typer" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e7/e6/4134ea2fbea322cddc7ffc94a0d8ee47fe32ce8e876b320cd37d88edfc4d/transformers-5.8.1.tar.gz", hash = "sha256:4dd5b6de4105725104d84fd6abd74b305f4debfc251b38c648ee5dd087cf543b", size = 8532019, upload-time = "2026-05-13T03:21:57.234Z" }
sdist = { url = "https://files.pythonhosted.org/packages/0c/f9/4552e2ba55db1c943aea0d4c09a32e9cbe5445b9eabe9856900de503dc8f/transformers-5.12.0.tar.gz", hash = "sha256:f0cf42ae1464c2eb41e7e0e66d7fd4b66145f48af17093b4cc0b2e9781faa7f4", size = 8923020, upload-time = "2026-06-12T14:39:20.43Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fc/b1/8be7e7ef0b5200491312201918b6125ef9c9df9dd0f0240ccef9ac824e6b/transformers-5.8.1-py3-none-any.whl", hash = "sha256:5340fb95962162cdfdae5cc91d7f8fedd92ed75216c1154c5e1f590fcf56dd0e", size = 10632882, upload-time = "2026-05-13T03:21:52.876Z" },
{ url = "https://files.pythonhosted.org/packages/ff/1f/d385913c38e900d23b728a4188fee625f2fccb306aeebc59be9d91404a5a/transformers-5.12.0-py3-none-any.whl", hash = "sha256:500be9eb644ede81c3103eee7687fc36d05dd75d1c76686c3820b26396fe7c7c", size = 11150246, upload-time = "2026-06-12T14:39:17.009Z" },
]

[[package]]
Expand Down
18 changes: 12 additions & 6 deletions docs/guides/gradient-checkpointing.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,21 @@ Use `true` or `full` for full activation checkpointing. Use `selective` for PyTo
> **Note (MoE/expert parallelism):** Selective AC is designed for dense transformers and generally does **not** help Mixture-of-Experts models with expert parallelism. In an MoE block the experts dominate the cost (they are cheap to recompute but expensive to store), and the expert-parallel dispatch/communication is opaque to the selective policy, so it is recomputed regardless. As a result, selective AC tends to add activation memory without a corresponding speedup for MoE, matching what reference implementations such as TorchTitan observe. Prefer **full** activation checkpointing (`true`/`full`) for MoE; selective remains supported for MoE and FSDP2 as an opt-in.

### Configure Programmatically

```python
from nemo_automodel.components.distributed.config import FSDP2Config
from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
from nemo_automodel import NeMoAutoModelForCausalLM
from nemo_automodel.components.distributed.config import DistributedSetup

config = FSDP2Config(activation_checkpointing=True)
# Use activation_checkpointing="selective" for FSDP2 selective checkpointing.
# device_mesh is created elsewhere (e.g. by the recipe via setup_distributed)
manager = FSDP2Manager(config, device_mesh=device_mesh, moe_mesh=moe_mesh)
model = manager.parallelize(model)
distributed_setup = DistributedSetup.build(
strategy="fsdp2",
activation_checkpointing=True,
)

model = NeMoAutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B",
distributed_setup=distributed_setup,
)
```

## Combine with Linear-Cut Cross-Entropy (LC-CE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

"""Extract per-layer activations from NeMo AutoModel using the training code path.

Uses the same distributed setup as training (EP, FSDP, backend config) and
registers forward hooks on decoder layers to capture hidden states. Saves
Uses the same distributed environment and mesh context as training (EP, FSDP,
backend config) and registers forward hooks on decoder layers to capture hidden states. Saves
activations and final logits for comparison against HF Transformers.

Run via torchrun:
Expand Down Expand Up @@ -84,33 +84,29 @@ def main():
sys.argv = ["extract_nemo_activations.py", "--config", args.config] + extra
cfg = parse_args_and_load_config()

# --- Distributed setup ---
# --- Distributed environment and mesh context ---
from nemo_automodel._transformers.utils import apply_cache_compatibility_patches
from nemo_automodel.components.loggers.log_utils import setup_logging
from nemo_automodel.recipes._dist_setup import setup_distributed
from nemo_automodel.recipes._dist_utils import create_distributed_setup_from_config
from nemo_automodel.recipes.llm.train_ft import build_distributed, build_model
from nemo_automodel.shared.te_patches import apply_te_patches

dist_env = build_distributed(cfg.get("dist_env", {}))
setup_logging()
apply_cache_compatibility_patches()
apply_te_patches()
dist_setup = setup_distributed(cfg, world_size=dist_env.world_size)
distributed_setup = create_distributed_setup_from_config(cfg, world_size=dist_env.world_size)
mesh_context = distributed_setup.mesh_context

if dist_setup.cp_size > 1 and cfg.get("model.backend.rope_fusion", False):
if mesh_context.cp_size > 1 and cfg.get("model.backend.rope_fusion", False):
cfg.model.backend.rope_fusion = False

# --- Build model ---
model = build_model(
cfg.model,
cfg_peft=None,
seed=cfg.get("seed", 42),
device_mesh=dist_setup.device_mesh,
moe_mesh=dist_setup.moe_mesh,
distributed_config=dist_setup.strategy_config,
pipeline_config=dist_setup.pipeline_config,
cfg_moe=dist_setup.moe_config,
activation_checkpointing=dist_setup.activation_checkpointing,
distributed_setup=distributed_setup,
)
model.eval()

Expand Down
97 changes: 97 additions & 0 deletions examples/diffusion/finetune/flux2_t2i_flow.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
seed: 42

wandb:
project: flux2-finetuning
mode: online
name: flux2_finetune_run

dist_env:
backend: "nccl"
init_method: "env://"
timeout_minutes: 30

model:
pretrained_model_name_or_path: "black-forest-labs/FLUX.2-dev"
mode: "finetune"
cache_dir: null
attention_backend: "flash"

optim:
learning_rate: 1e-5
optimizer:
weight_decay: 1e-4
betas: [0.9, 0.999]
foreach: false
fused: true

performance:
check_loss: false
grad_clip_foreach: true

lr_scheduler:
lr_decay_style: cosine
lr_warmup_steps: 500
min_lr: 0.0

# Adjust dp_size to the total number of GPUs
fsdp:
dp_size: 8
tp_size: 1
cp_size: 1
pp_size: 1
activation_checkpointing: false
cpu_offload: false
defer_fsdp_grad_sync: true
enable_compile: false
enable_fsdp2_prefetch: true
fsdp2_backward_prefetch_depth: 3
fsdp2_forward_prefetch_depth: 2

flow_matching:
adapter_type: "flux2"
adapter_kwargs:
guidance_scale: 3.5
use_guidance_embeds: true
# DreamBooth default: weighting_scheme="none", i.e. timesteps are sampled uniformly and no loss weighting is applied.
timestep_sampling: "uniform"
logit_mean: 0.0
logit_std: 1.0
# shift=3.0 matches FlowMatchEulerDiscreteScheduler shift for FLUX.2-dev.
flow_shift: 3.0
mix_uniform_ratio: 0.0
sigma_min: 0.0
sigma_max: 1.0
num_train_timesteps: 1000
cfg_dropout_prob: 0.0
i2v_prob: 0.0
use_loss_weighting: false
log_interval: 100
summary_log_interval: 10

step_scheduler:
num_epochs: 1000
local_batch_size: 1
global_batch_size: 8
ckpt_every_steps: 500
log_every: 10
# max_steps: null # Set to limit training to a specific number of steps
save_checkpoint_every_epoch: false

data:
dataloader:
_target_: nemo_automodel.components.datasets.diffusion.build_text_to_image_multiresolution_dataloader
cache_dir: PATH_TO_YOUR_DATA
train_text_encoder: false
num_workers: 2
base_resolution: [512, 512]
dynamic_batch_size: false
shuffle: true
drop_last: false

checkpoint:
enabled: true
checkpoint_dir: PATH_TO_YOUR_CKPT_DIR
model_save_format: safetensors
save_consolidated: final
diffusers_compatible: false
restore_from: null
124 changes: 124 additions & 0 deletions examples/diffusion/finetune/flux2_t2i_flow_lora.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
seed: 42

wandb:
project: flux2-finetuning
mode: online
name: flux2_lora_run

dist_env:
backend: "nccl"
init_method: "env://"

model:
pretrained_model_name_or_path: "black-forest-labs/FLUX.2-dev"
# Required when peft block is present — selects pre-inject hook and is
# passed to auto_diffusion_pipeline for model-specific setup.
model_type: "flux2"
mode: "finetune"
cache_dir: null
attention_backend: "flash"

# ── LoRA / PEFT configuration ─────────────────────────────────────────────────
# Defaults aligned with diffusers/examples/dreambooth/train_dreambooth_lora_flux2.py
# Target modules cover both double-stream (transformer_blocks) and
# single-stream (single_transformer_blocks) attention projections.
# Run `for n, m in model.named_modules(): print(n)` to verify exact names.
peft:
_target_: nemo_automodel.components._peft.lora.PeftConfig
dim: 4
alpha: 4
dropout: 0.0
target_modules:
- "*.to_q"
- "*.to_k"
- "*.to_v"
- "*.to_out.0"
- "*.to_qkv_mlp_proj"
- "*.single_transformer_blocks.*.attn.to_out"

optim:
learning_rate: 1e-4
optimizer:
weight_decay: 1e-4
betas: [0.9, 0.999]
foreach: false
fused: true

performance:
check_loss: false
grad_clip_foreach: true

lr_scheduler:
lr_decay_style: cosine
lr_warmup_steps: 500
min_lr: 0.0

# FSDP2 with param_dtype=None when LoRA enabled:
# base weights (bf16): sharded, no grad comms
# LoRA weights (fp32): sharded, grad allreduced by FSDP2
fsdp:
dp_size: 8
tp_size: 1
cp_size: 1
pp_size: 1
activation_checkpointing: false
cpu_offload: false
defer_fsdp_grad_sync: true
enable_compile: false
enable_fsdp2_prefetch: true
fsdp2_backward_prefetch_depth: 3
fsdp2_forward_prefetch_depth: 2

flow_matching:
adapter_type: "flux2"
adapter_kwargs:
guidance_scale: 3.5
use_guidance_embeds: true
# DreamBooth default: weighting_scheme="none", i.e. timesteps are sampled uniformly and no loss weighting is applied.
# Alternative: timestep_sampling: "logit_normal" (SD3 paper recommendation for fine-tuning).
timestep_sampling: "uniform"
logit_mean: 0.0
logit_std: 1.0
# shift=3.0 matches FlowMatchEulerDiscreteScheduler shift for FLUX.2-dev.
# Dynamic shifting applies at inference per resolution; training uses the base shift.
flow_shift: 3.0
mix_uniform_ratio: 0.0
sigma_min: 0.0
sigma_max: 1.0
num_train_timesteps: 1000
# DreamBooth does not apply CFG dropout during training.
# Set > 0 only if you want unconditional guidance support.
cfg_dropout_prob: 0.0
i2v_prob: 0.0
use_loss_weighting: false
log_interval: 100
summary_log_interval: 10

step_scheduler:
num_epochs: 1000
local_batch_size: 1
global_batch_size: 4
ckpt_every_steps: 500
log_every: 10
# max_steps: null # Set to limit training to a specific number of steps
save_checkpoint_every_epoch: false

data:
dataloader:
_target_: nemo_automodel.components.datasets.diffusion.build_text_to_image_multiresolution_dataloader
cache_dir: PATH_TO_YOUR_DATA
train_text_encoder: false
num_workers: 0
base_resolution: [512, 512]
dynamic_batch_size: false
shuffle: true
drop_last: false

checkpoint:
enabled: true
checkpoint_dir: PATH_TO_YOUR_CKPT_DIR
# LoRA saves adapter_model.safetensors + adapter_config.json
# (ignores model_save_format — PEFT format always used for LoRA)
model_save_format: torch_save
save_consolidated: false
restore_from: null
29 changes: 29 additions & 0 deletions examples/diffusion/generate/configs/generate_flux2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
model:
pretrained_model_name_or_path: "black-forest-labs/FLUX.2-dev"
checkpoint: null
lora_weights: null
lora_scale: 1.0


inference:
num_inference_steps: 28
guidance_scale: 3.5
height: 512
width: 512
dtype: "bfloat16"
max_samples: 10
prompts:
- "A cat sitting on a windowsill watching the rain"
pipeline_kwargs: {}

output:
output_dir: "./inference_outputs"

distributed: null

vae:
enable_slicing: false
enable_tiling: false
enable_cpu_offload: true

seed: 42
7 changes: 2 additions & 5 deletions examples/diffusion/generate/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
# Pipeline class name -> output type mapping
_PIPELINE_OUTPUT_TYPES = {
"FluxPipeline": "image",
"Flux2Pipeline": "image",
"QwenImagePipeline": "image",
"WanPipeline": "video",
"HunyuanVideoPipeline": "video",
Expand Down Expand Up @@ -350,7 +351,7 @@ def load_lora_weights_into_pipeline(pipe, cfg):


def _load_sharded_fsdp_checkpoint(transformer, sharded_dir, torch_dtype=torch.bfloat16):
"""Load sharded FSDP/DCP checkpoint into a transformer module.
"""Load sharded FSDP1 .distcp checkpoint into a transformer module.

Creates a temporary gloo process group for single-GPU loading if
torch.distributed is not already initialized.
Expand Down Expand Up @@ -379,18 +380,14 @@ def _load_sharded_fsdp_checkpoint(transformer, sharded_dir, torch_dtype=torch.bf
try:
transformer.to(device="cuda", dtype=torch_dtype)
fsdp_transformer = FSDP(transformer, use_orig_params=True)

FSDP.set_state_dict_type(
fsdp_transformer,
StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
)

model_state = fsdp_transformer.state_dict()
dist_load(state_dict=model_state, storage_reader=FileSystemReader(sharded_dir))
fsdp_transformer.load_state_dict(model_state)

# Unwrap back to the original module for inference
return fsdp_transformer.module
finally:
if init_dist:
Expand Down
Loading
Loading