diff --git a/tools/launcher/common/eagle3/train_eagle_streaming.sh b/tools/launcher/common/eagle3/train_eagle_streaming.sh
index 2f1e165062c..65ebc52b429 100755
--- a/tools/launcher/common/eagle3/train_eagle_streaming.sh
+++ b/tools/launcher/common/eagle3/train_eagle_streaming.sh
@@ -92,6 +92,14 @@ fi
 pip install --no-cache-dir -e modules/Model-Optimizer/
 pip install --no-cache-dir -r modules/Model-Optimizer/examples/speculative_decoding/requirements.txt
 pip install --no-cache-dir 'datasets' 'huggingface-hub>=1.2.1'
+
+# Optional pin (skipped when OVERRIDE_TRANSFORMERS is unset or empty): some trust_remote_code
+# models need an older transformers (e.g. MiniMax-M2.7 needs 4.57.x; its modeling code is
+# incompatible with the 5.x that the requirements pull in). Must run AFTER the requirements install to win.
+if [ -n "${OVERRIDE_TRANSFORMERS:-}" ]; then
+    pip install --no-cache-dir "transformers==${OVERRIDE_TRANSFORMERS}"
+fi
+
 export PATH=$PATH:/workspace/.local/bin
 
 ###################################################################################################
diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_streaming_dflash_multi_node.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_streaming_dflash_multi_node.yaml
new file mode 100644
index 00000000000..0242db1a914
--- /dev/null
+++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_streaming_dflash_multi_node.yaml
@@ -0,0 +1,138 @@
+# DFlash streaming speculative-decoding training for MiniMax-M2.7 (229B MoE) — MULTI-NODE.
+#
+# Same streaming transport / dispatch as hf_streaming_eagle3_multi_node.yaml: task_1
+# splits N nodes into K serve replicas + (N-K) DDP trainers via SERVE_NODES; hidden
+# states move serve -> trainer over NIXL RDMA. DFlash just consumes a different set of
+# captured layers and trains a block-diffusion draft instead of an autoregressive one.
+# See common/eagle3/train_eagle_streaming.sh for dispatch, rendezvous, and sharding.
+#
+# MiniMax-M2.7 specifics:
+#   - 229B MoE, trust_remote_code, needs TP=4 per serve replica (each replica = 1 node).
+#   - Custom chat_template_train.jinja for the MiniMax chat format (data.chat_template in task_1).
+#   - The DFlash draft is Qwen3-architecture (5 layers); target/draft architectures
+#     are independent by design.
+#   - Mask token id 200054 (reserved row in MiniMax-M2.7 embedding; dflash.dflash_mask_token_id).
+#   - YaRN rope_scaling at export for long context (factor = 196608/4096 = 48; dflash.dflash_export_rope_scaling.*).
+#
+# Capture ids: build_target_layer_ids(num_orig=62, num_draft=5) = [1,16,30,44,59]
+#   -> +1 for embedding = [2,17,31,45,60], then append final layer 62 => EAGLE_CAPTURE_IDS=[2,17,31,45,60,62].
+#   6 captured = 5 aux layers + final output for self-logit distillation.
+#
+# 3-step pipeline:
+#   task_0: Build input conversations (jsonl)
+#   task_1: Streaming train — 2 serve nodes (TP=4 each) + 2 trainer nodes (4 GPU each)
+#   task_2: vLLM smoke test with DFlash speculative decoding
+#
+# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036)
+#
+# Usage:
+#   uv run launch.py --yaml examples/MiniMax/MiniMax-M2.7-DFlash/hf_streaming_dflash_multi_node.yaml --yes
+#   uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_streaming_dflash_multi_node.yaml --yes
+
+job_name: MiniMax-M2.7-DFlash_streaming_multi_node
+pipeline:
+  allow_to_fail: false
+  skip: false
+  note:
+
+  global_vars:
+    hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7
+
+  # Step 1: Build input conversations (jsonl)
+  task_0:
+    script: common/eagle3/make_dataset.sh
+    args:
+      - -f modules/Model-Optimizer/examples/dataset/example_data_config.yaml
+      - --full-conversations
+    slurm_config:
+      _factory_: "slurm_factory"
+      nodes: 1
+      ntasks_per_node: 1
+      gpus_per_node: 4
+      container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10
+
+  # Step 2: Streaming DFlash training — 2 serve replicas (TP=4) + 2 trainer nodes (4 GPU each).
+  task_1:
+    script: common/eagle3/train_eagle_streaming.sh
+    args:
+      - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml
+      - model.model_name_or_path=<<global_vars.hf_model>>
+      - model.use_fake_base_for_offline=true
+      - model.trust_remote_code=true
+      - data.mode=streaming
+      - data.data_path=/scratchspace/data/train.jsonl
+      - data.chat_template=examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja
+      - training.output_dir=/scratchspace/dflash
+      - training.training_seq_len=4096
+      - training.disable_tqdm=true
+      - training.ar_validate_steps=500000
+      - training.answer_only_loss=true
+      - training.num_train_epochs=1
+      - training.max_steps=5000
+      - training.per_device_train_batch_size=2
+      - training.learning_rate=1.2e-3
+      - training.warmup_steps=100
+      - training.logging_steps=100
+      - training.save_steps=400
+      - training.ddp_timeout=3600
+      # dp_shard_size=1 stops main.py from creating a ParallelismConfig; the
+      # accelerate config below owns FSDP2 sharding.
+      - training.dp_shard_size=1
+      - training.bf16=false
+      # vLLM container has no tensorboard -> init crash.
+      - training.report_to=none
+      - dflash.dflash_self_logit_distillation=true
+      - dflash.dflash_block_size=8
+      - dflash.dflash_num_anchors=512
+      - dflash.dflash_loss_decay_factor=4.0
+      - dflash.dflash_architecture_config.num_hidden_layers=5
+      - dflash.dflash_mask_token_id=200054
+      # YaRN rope_scaling injected at export for long context (factor = 196608/4096 = 48).
+      - dflash.dflash_export_rope_scaling.type=yarn
+      - dflash.dflash_export_rope_scaling.factor=48.0
+      - dflash.dflash_export_rope_scaling.original_max_position_embeddings=4096
+      - dflash.dflash_export_rope_scaling.beta_fast=1.0
+      - dflash.dflash_export_rope_scaling.beta_slow=1.0
+      - dflash.dflash_export_rope_scaling.mscale=1.0
+      - dflash.dflash_export_rope_scaling.mscale_all_dim=1.0
+    environment:
+      - HF_MODEL_CKPT: <<global_vars.hf_model>>
+      - EAGLE_CAPTURE_IDS: "[2,17,31,45,60,62]"
+      - SERVE_NODES: "2"
+      - SERVE_TP: "4"
+      - STREAMING_NUM_WORKERS: "1"
+      - EXPORT_EXTRA_ARGS: "--trust_remote_code"
+      - SERVE_MAX_MODEL_LEN: "4160"
+      - SERVE_MAX_NUM_SEQS: "4"
+      - SERVE_GPU_MEM_UTIL: "0.8"
+      - SERVE_READY_TIMEOUT: "2400"
+      - SERVE_EXTRA_ARGS: "--trust-remote-code"
+      - VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: "1200"
+      - VLLM_ENGINE_ITERATION_TIMEOUT_S: "1200"
+      - OVERRIDE_TRANSFORMERS: "4.57.1"
+      - ACCELERATE_CONFIG: examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml
+      - PATCH_FSDP2_BUFFERS_TF457: "1"
+      - MIXED_PRECISION: "no"
+    slurm_config:
+      _factory_: "slurm_factory"
+      nodes: 4
+      segment: 4
+      ntasks_per_node: 1
+      gpus_per_node: 4
+      container: vllm/vllm-openai:nightly
+
+  # Step 3: vLLM smoke test with DFlash speculative decoding
+  task_2:
+    script: common/specdec/vllm_smoke_test.sh
+    environment:
+      - HF_MODEL_CKPT: <<global_vars.hf_model>>
+      - DRAFT_CKPT_DIR: /scratchspace/dflash
+      - SPEC_METHOD: "dflash"
+      - NUM_SPEC_TOKENS: "7"
+      - TP_SIZE: "4"
+    slurm_config:
+      _factory_: "slurm_factory"
+      nodes: 1
+      ntasks_per_node: 1
+      gpus_per_node: 8  # NOTE(review): task_0/task_1 request 4 GPUs per node and TP_SIZE is 4 — confirm 8 is intended
+      container: vllm/vllm-openai:nightly