specdec(recipe): add MiniMax-M2.7-DFlash streaming multi-node pipeline #1835
base: main
Changes from 1 commit
3721cc1
b71297f
53bfa52
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,138 @@ | ||
| # DFlash streaming speculative-decoding training for MiniMax-M2.7 (229B MoE) — MULTI-NODE. | ||
| # | ||
| # Same streaming transport / dispatch as hf_streaming_eagle3_multi_node.yaml: task_1 | ||
| # splits N nodes into K serve replicas + (N-K) DDP trainers via SERVE_NODES; hidden | ||
| # states move serve -> trainer over NIXL RDMA. DFlash just consumes a different set of | ||
| # captured layers and trains a block-diffusion draft instead of an autoregressive one. | ||
| # See common/eagle3/train_eagle_streaming.sh for dispatch, rendezvous, and sharding. | ||
| # | ||
| # MiniMax-M2.7 specifics: | ||
| # - 229B MoE, trust_remote_code, needs TP=4 per serve replica (each replica = 1 node). | ||
| # - Custom chat_template_train.jinja for the MiniMax chat format. | ||
| # - The DFlash draft is Qwen3-architecture (5 layers); target/draft architectures | ||
| # are independent by design. | ||
| # - Mask token id 200054 (reserved row in MiniMax-M2.7 embedding). | ||
| # - YaRN rope_scaling at export for long context (factor = 196608/4096 = 48). | ||
| # | ||
| # Capture ids: build_target_layer_ids(num_orig=64, num_draft=5)=[1,16,31,46,61] | ||
| # -> +1 for embedding = [2,17,32,47,62], append final layer 64. | ||
| # 6 captured = 5 aux layers + final output for self-logit distillation. | ||
| # | ||
| # 3-step pipeline: | ||
| # task_0: Build input conversations (jsonl) | ||
| # task_1: Streaming train — 2 serve nodes (TP=4 each) + 2 trainer nodes (4 GPU each) | ||
| # task_2: vLLM smoke test with DFlash speculative decoding | ||
| # | ||
| # Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) | ||
| # | ||
| # Usage: | ||
| # uv run launch.py --yaml examples/MiniMax/MiniMax-M2.7-DFlash/hf_streaming_dflash_multi_node.yaml --yes | ||
| # uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_streaming_dflash_multi_node.yaml --yes | ||
|
|
||
| job_name: MiniMax-M2.7-DFlash_streaming_multi_node | ||
| pipeline: | ||
| allow_to_fail: false | ||
| skip: false | ||
| note: | ||
|
|
||
| global_vars: | ||
| hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 | ||
|
|
||
| # Step 1: Build input conversations | ||
| task_0: | ||
| script: common/eagle3/make_dataset.sh | ||
| args: | ||
| - -f modules/Model-Optimizer/examples/dataset/example_data_config.yaml | ||
| - --full-conversations | ||
| slurm_config: | ||
| _factory_: "slurm_factory" | ||
| nodes: 1 | ||
| ntasks_per_node: 1 | ||
| gpus_per_node: 4 | ||
| container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10 | ||
|
|
||
| # Step 2: Streaming DFlash training — 2 serve replicas (TP=4) + 2 trainer nodes (4 GPU each). | ||
| task_1: | ||
| script: common/eagle3/train_eagle_streaming.sh | ||
| args: | ||
| - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml | ||
| - model.model_name_or_path=<<global_vars.hf_model>> | ||
| - model.use_fake_base_for_offline=true | ||
| - model.trust_remote_code=true | ||
| - data.mode=streaming | ||
| - data.data_path=/scratchspace/data/train.jsonl | ||
| - data.chat_template=examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja | ||
| - training.output_dir=/scratchspace/dflash | ||
| - training.training_seq_len=4096 | ||
| - training.disable_tqdm=true | ||
| - training.ar_validate_steps=500000 | ||
| - training.answer_only_loss=true | ||
| - training.num_train_epochs=1 | ||
| - training.max_steps=5000 | ||
| - training.per_device_train_batch_size=2 | ||
| - training.learning_rate=1.2e-3 | ||
| - training.warmup_steps=100 | ||
| - training.logging_steps=100 | ||
| - training.save_steps=400 | ||
| - training.ddp_timeout=3600 | ||
| # dp_shard_size=1 stops main.py from creating a ParallelismConfig; the | ||
| # accelerate config below owns FSDP2 sharding. | ||
| - training.dp_shard_size=1 | ||
| - training.bf16=false | ||
| # vLLM container has no tensorboard -> init crash. | ||
| - training.report_to=none | ||
| - dflash.dflash_self_logit_distillation=true | ||
| - dflash.dflash_block_size=8 | ||
| - dflash.dflash_num_anchors=512 | ||
| - dflash.dflash_loss_decay_factor=4.0 | ||
| - dflash.dflash_architecture_config.num_hidden_layers=5 | ||
| - dflash.dflash_mask_token_id=200054 | ||
| # YaRN rope_scaling injected at export for long context (factor = 196608/4096 = 48). | ||
| - dflash.dflash_export_rope_scaling.type=yarn | ||
| - dflash.dflash_export_rope_scaling.factor=48.0 | ||
| - dflash.dflash_export_rope_scaling.original_max_position_embeddings=4096 | ||
| - dflash.dflash_export_rope_scaling.beta_fast=1.0 | ||
| - dflash.dflash_export_rope_scaling.beta_slow=1.0 | ||
| - dflash.dflash_export_rope_scaling.mscale=1.0 | ||
| - dflash.dflash_export_rope_scaling.mscale_all_dim=1.0 | ||
| environment: | ||
| - HF_MODEL_CKPT: <<global_vars.hf_model>> | ||
| - EAGLE_CAPTURE_IDS: "[2,17,32,47,62,64]" | ||
|
Contributor
QQ: MiniMax seems to have 62 layers. What does 64 mean here?
Contributor Author
Good catch: MiniMax-M2.7 has 62 hidden layers, not 64. I miscounted from the prior DFlash M3 work (which has 64 layers). Fixed in b71297f: recalculated
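For reference, the capture-id arithmetic described in the file header can be sketched as below. This is a minimal sketch, not the project's code: `evenly_spaced_layer_ids` is a hypothetical stand-in for `build_target_layer_ids`, and its spacing rule (ids spread evenly over `[1, num_orig - 3]`) is an assumption inferred from the one example in the header comment. It reproduces the ids from the first commit (`num_orig=64`) and takes the layer count as a parameter so the same steps can be rerun for a different count.

```python
def evenly_spaced_layer_ids(num_orig: int, num_draft: int) -> list[int]:
    """Hypothetical stand-in for build_target_layer_ids.

    Assumption: ids are spread evenly over [1, num_orig - 3], which matches
    the header example (num_orig=64, num_draft=5) -> [1, 16, 31, 46, 61].
    """
    if num_draft < 2:
        raise ValueError("this sketch only covers num_draft >= 2")
    start, end = 1, num_orig - 3
    span = end - start
    return [round(start + i * span / (num_draft - 1)) for i in range(num_draft)]


def capture_ids(num_orig: int, num_draft: int) -> list[int]:
    """Shift each aux id by +1 for the embedding output, then append the final layer."""
    aux = [i + 1 for i in evenly_spaced_layer_ids(num_orig, num_draft)]
    return aux + [num_orig]


# Values from the first commit, which assumed 64 layers.
print(capture_ids(64, 5))  # [2, 17, 32, 47, 62, 64]
```

The final entry is simply `num_orig`, so a wrong layer count shows up directly as the last capture id, which is what the question above picked up on.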
||
| - SERVE_NODES: "2" | ||
| - SERVE_TP: "4" | ||
| - STREAMING_NUM_WORKERS: "1" | ||
| - EXPORT_EXTRA_ARGS: "--trust_remote_code" | ||
| - SERVE_MAX_MODEL_LEN: "4160" | ||
| - SERVE_MAX_NUM_SEQS: "4" | ||
| - SERVE_GPU_MEM_UTIL: "0.8" | ||
| - SERVE_READY_TIMEOUT: "2400" | ||
| - SERVE_EXTRA_ARGS: "--trust-remote-code" | ||
| - VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: "1200" | ||
| - VLLM_ENGINE_ITERATION_TIMEOUT_S: "1200" | ||
| - OVERRIDE_TRANSFORMERS: "4.57.1" | ||
| - ACCELERATE_CONFIG: examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml | ||
| - PATCH_FSDP2_BUFFERS_TF457: "1" | ||
| - MIXED_PRECISION: "no" | ||
| slurm_config: | ||
| _factory_: "slurm_factory" | ||
| nodes: 4 | ||
| segment: 4 | ||
| ntasks_per_node: 1 | ||
| gpus_per_node: 4 | ||
| container: vllm/vllm-openai:nightly | ||
|
|
||
| # Step 3: vLLM smoke test with DFlash speculative decoding | ||
| task_2: | ||
| script: common/specdec/vllm_smoke_test.sh | ||
| environment: | ||
| - HF_MODEL_CKPT: <<global_vars.hf_model>> | ||
| - DRAFT_CKPT_DIR: /scratchspace/dflash | ||
| - SPEC_METHOD: "dflash" | ||
| - NUM_SPEC_TOKENS: "7" | ||
| - TP_SIZE: "4" | ||
| slurm_config: | ||
| _factory_: "slurm_factory" | ||
| nodes: 1 | ||
| ntasks_per_node: 1 | ||
| gpus_per_node: 4 | ||
| container: vllm/vllm-openai:nightly | ||
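The two pieces of arithmetic the recipe relies on (the serve/trainer node split and the YaRN factor) can be checked with a short sketch. The helper names `split_nodes` and `yarn_factor` are hypothetical and exist only for illustration; the numbers come from the recipe itself (`nodes: 4`, `SERVE_NODES: "2"`, and `196608/4096`).

```python
def split_nodes(total_nodes: int, serve_nodes: int) -> tuple[int, int]:
    """Return (serve replicas, trainer nodes) for an N-node job with K serve nodes."""
    if not 0 < serve_nodes < total_nodes:
        raise ValueError("need at least one serve node and one trainer node")
    return serve_nodes, total_nodes - serve_nodes


def yarn_factor(target_context: int, original_context: int) -> float:
    """YaRN scaling factor as used in the recipe: target length / original length."""
    return target_context / original_context


serve, trainers = split_nodes(total_nodes=4, serve_nodes=2)
print(serve, trainers)            # 2 2
print(yarn_factor(196608, 4096))  # 48.0
```

Both results match the values written into task_1: 2 serve replicas plus 2 trainer nodes, and `dflash_export_rope_scaling.factor=48.0`.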