Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# DFlash streaming speculative-decoding training for MiniMax-M2.7 (229B MoE) — MULTI-NODE.
#
# Same streaming transport / dispatch as hf_streaming_eagle3_multi_node.yaml: task_1
# splits N nodes into K serve replicas + (N-K) DDP trainers via SERVE_NODES; hidden
# states move serve -> trainer over NIXL RDMA. DFlash just consumes a different set of
# captured layers and trains a block-diffusion draft instead of an autoregressive one.
# See common/eagle3/train_eagle_streaming.sh for dispatch, rendezvous, and sharding.
#
# MiniMax-M2.7 specifics:
# - 229B MoE, trust_remote_code, needs TP=4 per serve replica (each replica = 1 node).
# - Custom chat_template_train.jinja for the Ministral chat format.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
# - The DFlash draft is Qwen3-architecture (5 layers); target/draft architectures
# are independent by design.
# - Mask token id 200054 (reserved row in MiniMax-M2.7 embedding).
# - YaRN rope_scaling at export for long context (factor = 196608/4096 = 48).
#
# Capture ids: build_target_layer_ids(num_orig=64, num_draft=5)=[1,16,31,46,61]
# -> +1 for embedding = [2,17,32,47,62], append final layer 64.
# 6 captured = 5 aux layers + final output for self-logit distillation.
#
# 3-step pipeline:
# task_0: Build input conversations (jsonl)
# task_1: Streaming train — 2 serve nodes (TP=4 each) + 2 trainer nodes (4 GPU each)
# task_2: vLLM smoke test with DFlash speculative decoding
#
# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036)
#
# Usage:
# uv run launch.py --yaml examples/MiniMax/MiniMax-M2.7-DFlash/hf_streaming_dflash_multi_node.yaml --yes
# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_streaming_dflash_multi_node.yaml --yes

job_name: MiniMax-M2.7-DFlash_streaming_multi_node
pipeline:
allow_to_fail: false
skip: false
note:

global_vars:
hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7

# Step 1: Build input conversations
task_0:
script: common/eagle3/make_dataset.sh
args:
- -f modules/Model-Optimizer/examples/dataset/example_data_config.yaml
- --full-conversations
slurm_config:
_factory_: "slurm_factory"
nodes: 1
ntasks_per_node: 1
gpus_per_node: 4
container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10

# Step 2: Streaming DFlash training — 2 serve replicas (TP=4) + 2 trainer nodes (4 GPU each).
task_1:
script: common/eagle3/train_eagle_streaming.sh
args:
- --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml
- model.model_name_or_path=<<global_vars.hf_model>>
- model.use_fake_base_for_offline=true
- model.trust_remote_code=true
- data.mode=streaming
- data.data_path=/scratchspace/data/train.jsonl
- data.chat_template=examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja
- training.output_dir=/scratchspace/dflash
- training.training_seq_len=4096
- training.disable_tqdm=true
- training.ar_validate_steps=500000
- training.answer_only_loss=true
- training.num_train_epochs=1
- training.max_steps=5000
- training.per_device_train_batch_size=2
- training.learning_rate=1.2e-3
- training.warmup_steps=100
- training.logging_steps=100
- training.save_steps=400
- training.ddp_timeout=3600
# dp_shard_size=1 stops main.py from creating a ParallelismConfig; the
# accelerate config below owns FSDP2 sharding.
- training.dp_shard_size=1
- training.bf16=false
# vLLM container has no tensorboard -> init crash.
- training.report_to=none
- dflash.dflash_self_logit_distillation=true
- dflash.dflash_block_size=8
- dflash.dflash_num_anchors=512
- dflash.dflash_loss_decay_factor=4.0
- dflash.dflash_architecture_config.num_hidden_layers=5
- dflash.dflash_mask_token_id=200054
# YaRN rope_scaling injected at export for long context (factor = 196608/4096 = 48).
- dflash.dflash_export_rope_scaling.type=yarn
- dflash.dflash_export_rope_scaling.factor=48.0
- dflash.dflash_export_rope_scaling.original_max_position_embeddings=4096
- dflash.dflash_export_rope_scaling.beta_fast=1.0
- dflash.dflash_export_rope_scaling.beta_slow=1.0
- dflash.dflash_export_rope_scaling.mscale=1.0
- dflash.dflash_export_rope_scaling.mscale_all_dim=1.0
environment:
- HF_MODEL_CKPT: <<global_vars.hf_model>>
- EAGLE_CAPTURE_IDS: "[2,17,32,47,62,64]"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: minimax seems to have 62 layers. What does 64 means here?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — MiniMax-M2.7 has 62 hidden layers, not 64. I miscounted from the prior DFlash M3 work (which is 64 layers). Fixed in b71297f: recalculated EAGLE_CAPTURE_IDS from build_target_layer_ids(62, 5)[2,17,31,45,60,62].

- SERVE_NODES: "2"
- SERVE_TP: "4"
- STREAMING_NUM_WORKERS: "1"
- EXPORT_EXTRA_ARGS: "--trust_remote_code"
- SERVE_MAX_MODEL_LEN: "4160"
- SERVE_MAX_NUM_SEQS: "4"
- SERVE_GPU_MEM_UTIL: "0.8"
- SERVE_READY_TIMEOUT: "2400"
- SERVE_EXTRA_ARGS: "--trust-remote-code"
- VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: "1200"
- VLLM_ENGINE_ITERATION_TIMEOUT_S: "1200"
- OVERRIDE_TRANSFORMERS: "4.57.1"
- ACCELERATE_CONFIG: examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml
- PATCH_FSDP2_BUFFERS_TF457: "1"
- MIXED_PRECISION: "no"
slurm_config:
_factory_: "slurm_factory"
nodes: 4
segment: 4
ntasks_per_node: 1
gpus_per_node: 4
container: vllm/vllm-openai:nightly

# Step 3: vLLM smoke test with DFlash speculative decoding
task_2:
script: common/specdec/vllm_smoke_test.sh
environment:
- HF_MODEL_CKPT: <<global_vars.hf_model>>
- DRAFT_CKPT_DIR: /scratchspace/dflash
- SPEC_METHOD: "dflash"
- NUM_SPEC_TOKENS: "7"
- TP_SIZE: "4"
slurm_config:
_factory_: "slurm_factory"
nodes: 1
ntasks_per_node: 1
gpus_per_node: 8
container: vllm/vllm-openai:nightly
Loading