Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions tools/launcher/common/eagle3/train_eagle_streaming.sh
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ fi
pip install --no-cache-dir -e modules/Model-Optimizer/
pip install --no-cache-dir -r modules/Model-Optimizer/examples/speculative_decoding/requirements.txt
pip install --no-cache-dir 'datasets' 'huggingface-hub>=1.2.1'

# Some trust_remote_code models pin an older transformers (e.g. MiniMax-M2.7
# needs 4.57.x whose modeling code is incompatible with the 5.x that the
# requirements pull in). Must run AFTER the requirements install to win.
if [ -n "${OVERRIDE_TRANSFORMERS:-}" ]; then
pip install --no-cache-dir "transformers==${OVERRIDE_TRANSFORMERS}"
fi

export PATH=$PATH:/workspace/.local/bin

###################################################################################################
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# DFlash streaming speculative-decoding training for MiniMax-M2.7 (229B MoE) — MULTI-NODE.
#
# Same streaming transport / dispatch as hf_streaming_eagle3_multi_node.yaml: task_1
# splits N nodes into K serve replicas + (N-K) DDP trainers via SERVE_NODES; hidden
# states move serve -> trainer over NIXL RDMA. DFlash just consumes a different set of
# captured layers and trains a block-diffusion draft instead of an autoregressive one.
# See common/eagle3/train_eagle_streaming.sh for dispatch, rendezvous, and sharding.
#
# MiniMax-M2.7 specifics:
# - 229B MoE, trust_remote_code, needs TP=4 per serve replica (each replica = 1 node).
# - Custom chat_template_train.jinja for the MiniMax chat format.
# - The DFlash draft is Qwen3-architecture (5 layers); target/draft architectures
# are independent by design.
# - Mask token id 200054 (reserved row in MiniMax-M2.7 embedding).
# - YaRN rope_scaling at export for long context (factor = 196608/4096 = 48).
#
# Capture ids: build_target_layer_ids(num_orig=62, num_draft=5)=[1,16,30,44,59]
# -> +1 for embedding = [2,17,31,45,60], append final layer 62.
# 6 captured = 5 aux layers + final output for self-logit distillation.
#
# 3-step pipeline:
# task_0: Build input conversations (jsonl)
# task_1: Streaming train — 2 serve nodes (TP=4 each) + 2 trainer nodes (4 GPU each)
# task_2: vLLM smoke test with DFlash speculative decoding
#
# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036)
#
# Usage:
# uv run launch.py --yaml examples/MiniMax/MiniMax-M2.7-DFlash/hf_streaming_dflash_multi_node.yaml --yes
# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_streaming_dflash_multi_node.yaml --yes

job_name: MiniMax-M2.7-DFlash_streaming_multi_node
pipeline:
allow_to_fail: false
skip: false
note:

global_vars:
hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7

# Step 1: Build input conversations
task_0:
script: common/eagle3/make_dataset.sh
args:
- -f modules/Model-Optimizer/examples/dataset/example_data_config.yaml
- --full-conversations
slurm_config:
_factory_: "slurm_factory"
nodes: 1
ntasks_per_node: 1
gpus_per_node: 4
container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10

# Step 2: Streaming DFlash training — 2 serve replicas (TP=4) + 2 trainer nodes (4 GPU each).
task_1:
script: common/eagle3/train_eagle_streaming.sh
args:
- --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml
- model.model_name_or_path=<<global_vars.hf_model>>
- model.use_fake_base_for_offline=true
- model.trust_remote_code=true
- data.mode=streaming
- data.data_path=/scratchspace/data/train.jsonl
- data.chat_template=examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja
- training.output_dir=/scratchspace/dflash
- training.training_seq_len=4096
- training.disable_tqdm=true
- training.ar_validate_steps=500000
- training.answer_only_loss=true
- training.num_train_epochs=1
- training.max_steps=5000
- training.per_device_train_batch_size=2
- training.learning_rate=1.2e-3
- training.warmup_steps=100
- training.logging_steps=100
- training.save_steps=400
- training.ddp_timeout=3600
# dp_shard_size=1 stops main.py from creating a ParallelismConfig; the
# accelerate config below owns FSDP2 sharding.
- training.dp_shard_size=1
- training.bf16=false
# vLLM container has no tensorboard -> init crash.
- training.report_to=none
- dflash.dflash_self_logit_distillation=true
- dflash.dflash_block_size=8
- dflash.dflash_num_anchors=512
- dflash.dflash_loss_decay_factor=4.0
- dflash.dflash_architecture_config.num_hidden_layers=5
- dflash.dflash_mask_token_id=200054
# YaRN rope_scaling injected at export for long context (factor = 196608/4096 = 48).
- dflash.dflash_export_rope_scaling.type=yarn
- dflash.dflash_export_rope_scaling.factor=48.0
- dflash.dflash_export_rope_scaling.original_max_position_embeddings=4096
- dflash.dflash_export_rope_scaling.beta_fast=1.0
- dflash.dflash_export_rope_scaling.beta_slow=1.0
- dflash.dflash_export_rope_scaling.mscale=1.0
- dflash.dflash_export_rope_scaling.mscale_all_dim=1.0
environment:
- HF_MODEL_CKPT: <<global_vars.hf_model>>
- EAGLE_CAPTURE_IDS: "[2,17,31,45,60,62]"
- SERVE_NODES: "2"
- SERVE_TP: "4"
- STREAMING_NUM_WORKERS: "1"
- EXPORT_EXTRA_ARGS: "--trust_remote_code"
- SERVE_MAX_MODEL_LEN: "4160"
- SERVE_MAX_NUM_SEQS: "4"
- SERVE_GPU_MEM_UTIL: "0.8"
- SERVE_READY_TIMEOUT: "2400"
- SERVE_EXTRA_ARGS: "--trust-remote-code"
- VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: "1200"
- VLLM_ENGINE_ITERATION_TIMEOUT_S: "1200"
- OVERRIDE_TRANSFORMERS: "4.57.1"
- ACCELERATE_CONFIG: examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml
- PATCH_FSDP2_BUFFERS_TF457: "1"
- MIXED_PRECISION: "no"
slurm_config:
_factory_: "slurm_factory"
nodes: 4
segment: 4
ntasks_per_node: 1
gpus_per_node: 4
container: vllm/vllm-openai:nightly

# Step 3: vLLM smoke test with DFlash speculative decoding
task_2:
script: common/specdec/vllm_smoke_test.sh
environment:
- HF_MODEL_CKPT: <<global_vars.hf_model>>
- DRAFT_CKPT_DIR: /scratchspace/dflash
- SPEC_METHOD: "dflash"
- NUM_SPEC_TOKENS: "7"
- TP_SIZE: "4"
slurm_config:
_factory_: "slurm_factory"
nodes: 1
ntasks_per_node: 1
gpus_per_node: 8
container: vllm/vllm-openai:nightly
Loading