feat(moe): mxfp4-resident MoE experts for DeepSeek-V4-Flash LoRA#2548
Open
excepshenal wants to merge 24 commits into
Open
feat(moe): mxfp4-resident MoE experts for DeepSeek-V4-Flash LoRA#2548excepshenal wants to merge 24 commits into
excepshenal wants to merge 24 commits into
Conversation
Signed-off-by: Daniel <dshen@crusoeenergy.com>
Signed-off-by: Daniel <dshen@crusoeenergy.com>
Keep frozen routed-expert base weights packed as fp4-e2m1 + e8m0 block scales (the DeepSeek V4 Flash checkpoint format) at steady state during LoRA training, dequantizing on the fly in forward and backward via a custom grouped-GEMM autograd function that saves only the packed tensors. Opt in with peft.expert_weight_format: mxfp4 (torch_mm experts backend only). Packing happens at PEFT-swap time when weights are materialized, or after checkpoint load for meta-initialized models. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> Signed-off-by: Daniel <dshen@crusoeenergy.com>
7ebb84e to
2c7ff04
Compare
Contributor
|
/ok to test 2c7ff04 |
Extend mxfp4-resident expert storage to FROZEN (non-LoRA-targeted) routed experts, which is where the memory win actually lands for LoRA-on-attention recipes. Adds GroupedExpertsMXFP4 (frozen, packed base, dequant-on-the-fly) and convert_frozen_experts_to_mxfp4(), invoked from the PEFT-application step when peft.expert_weight_format=mxfp4. Packing of both frozen and LoRA-targeted experts is deferred until after checkpoint load. Format-specific pack/unpack/GEMM logic is factored into MXFP4ExpertStorageMixin (shared by the frozen and LoRA variants) as the seam for a future int4 codec. Only experts are quantized; all other weights stay bf16. v1 supports the torch_mm GroupedExperts backend (DeepEP experts are skipped with a warning). Adds an example DSV4-Flash LoRA+mxfp4 config. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> Signed-off-by: Daniel <dshen@crusoeenergy.com>
Make MXFP4ExpertStorageMixin register its packed storage parameters at module init (not only post-load), driven by a _PACKED_SUFFIXES list so the helper is format-driven rather than hardcoding the two tensor names. Enables building the model packed-at-init for the passthrough load path. Signed-off-by: Daniel <dshen@crusoeenergy.com>
Phase B: load experts directly as packed fp4 so they are never materialized in bf16, capping the load-time peak (the actual blocker for one-node DSV4-Flash). - DeepSeekV4StateDictAdapter gains expert_storage_format='mxfp4': from_hf skips expert dequant and aggregates per-expert packed int8 + e8m0 scales into *_packed/*_scales keys. Concatenating gate||up along the output dim and stacking experts on dim 0 is layout-preserving (packing is along the contraction dim), so no unpacking is needed. - GroupedExpertsMXFP4(passthrough=True) registers meta packed placeholders from config at init (via the init-capable register_packed_base_weight), so the packed checkpoint loads straight in with no bf16 storage. - convert_frozen_experts_to_mxfp4(passthrough=...) plumbs the mode; kept opt-in so the validated bf16-then-pack path stays the default until the real checkpoint confirms Phase B end to end. Tests: synthetic fp4 checkpoint proves passthrough emits packed keys (no bf16, no orphaned scales) and decodes bit-identically to the bf16 dequant+aggregate path; packed-at-init registers correct meta shapes/dtypes. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> Signed-off-by: Daniel <dshen@crusoeenergy.com>
…ation Address organization: the quantization primitives belong alongside fp8/qat/qlora, not under moe/. Mirrors how fp8.py holds the format/config while the layers that use it live elsewhere. - Move fp4_utils.py -> quantization/mxfp4.py. - Export mxfp4 from quantization/__init__. - GroupedExpertsMXFP4 / MXFP4ExpertStorageMixin stay in moe/ (GroupedExperts subclasses = model layers); GroupedExpertsLoRAMXFP4 stays in _peft (LoRA module, delegates all precision logic to the mixin). - Add a real-checkpoint-gated test validating passthrough against /raid0/data/models/DeepSeek-V4-Flash: confirms int8+e8m0 layout and bit-exact decode vs the bf16 path on actual checkpoint bytes (self-skips when absent). Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> Signed-off-by: Daniel <dshen@crusoeenergy.com>
- DeepSeekV4StateDictAdapter.to_hf now splits *_packed/*_scales params back into per-expert checkpoint keys (mirror of the from_hf aggregation), so the DCP loader can enumerate destination tensors; quantization placeholder step is bypassed for already-packed experts. - infrastructure: when peft.expert_weight_format=mxfp4, set the adapter to passthrough mode and convert experts packed-at-init before load, so the fp4 checkpoint loads straight into packed params (no bf16 experts). - Tests: from_hf->to_hf round-trip recovers per-expert packed keys bit-exactly. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> Signed-off-by: Daniel <dshen@crusoeenergy.com>
_init_weights touched module.gate_and_up_projs, which passthrough GroupedExpertsMXFP4 modules don't have (only *_packed/*_scales, filled from the checkpoint). Skip init for mxfp4-resident experts. Found by an end-to-end single-GPU run on real DeepSeek-V4-Flash weights (2 train steps, finite loss). Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> Signed-off-by: Daniel <dshen@crusoeenergy.com>
nn.Parameter() defaults requires_grad=True, which raises 'only Tensors of floating point dtype can require gradients' for the int8/e8m0 packed tensors of mxfp4-resident experts before the subsequent requires_grad assignment runs. Pass requires_grad at construction. Identical behavior for float params; unblocks EP sharding of packed experts. Validated by a 2-GPU ep_size=2 run on real DeepSeek-V4-Flash (both steps, finite loss, 55 GiB/rank vs 124 single-GPU). Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> Signed-off-by: Daniel <dshen@crusoeenergy.com>
Root-caused the single-GPU correctness bug: at world_size=1 the MoE parallelizer is skipped, the packed e8m0 expert scales are never applied, and experts decode to unscaled fp4 (~100x too large) — silently wrong (loss 19.60 vs bf16 17.28). Diagnosed via dequant-norm debug: ep-sharded experts std=0.025 (correct) vs world=1 std=2.48 (raw fp4 grid, scale==1). mxfp4 is only correct when experts are EP-sharded (validated: ep_size=2 matches the bf16 baseline within 0.04%). Guard world_size=1 with a clear error instead of training on garbage. Multi-GPU ep_size>1 is the supported (and only sensible, given model size) path. Removed debug instrumentation. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> Signed-off-by: Daniel <dshen@crusoeenergy.com>
Two issues made training appear to hang at the end: 1. No process-group teardown. run_train_validation_loop returned but the distributed process group was never destroyed, so NCCL / the elastic agent waited and the process never exited. Add a best-effort barrier + destroy_process_group in main()'s finally. 2. End-of-training validation could deadlock under expert parallelism. The MoE expert forward issues EP collectives; if DP ranks see uneven validation shard sizes they call those collectives a different number of times and hang. Drive the validation loop by a global-MIN "does every rank still have a batch?" all-reduce so all ranks run the same number of forwards. Validated on 2xGPU DSV4-Flash mxfp4 LoRA: validation completes and the process exits cleanly (exit 0). Note: validation still runs over the full val set at the last step (is_ckpt_step -> is_last_step); bounding it (eval_iters cap) is a possible follow-up for snappier end-of-training. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> Signed-off-by: Daniel <dshen@crusoeenergy.com>
MXFP4GroupedMM.forward dequantized to [E,N,K] then did .transpose(-2,-1).contiguous() to feed torch._grouped_mm — a full bf16 weight copy per forward. torch._grouped_mm accepts the transposed view directly (cuBLAS transB, verified bit-identical), so pass the view and skip the copy. Microbench (E=32,N=K=4096): the op goes 4.1ms -> 0.4ms (copy was pure waste). End-to-end DSV4-Flash 8-GPU ep8, 4096-token packed seq: ~4630 -> ~4970 tps (+7%); mxfp4 slowdown vs bf16 improves from ~1.67x to ~1.56x. Loss unchanged. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> Signed-off-by: Daniel <dshen@crusoeenergy.com>
Signed-off-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com>
Contributor
|
/ok to test 6463d77 |
Add GroupedExpertsDeepEPMXFP4 (frozen) and GroupedExpertsDeepEPLoRAMXFP4 (LoRA-on-experts) so DeepSeek V4 routed experts can stay packed as fp4-e2m1 + e8m0 block scales while using the DeepEP fused all-to-all token dispatch. mxfp4 only changes the two post-dispatch grouped GEMMs (dequant on the fly via MXFP4GroupedMM); dispatch/combine are unchanged. Requires backend.experts=torch_mm. - Lift the mxfp4+DeepEP guards in patch_moe_module and convert_frozen_experts_to_mxfp4 (TE experts still unsupported). - Hoist _init_packed_placeholders into MXFP4ExpertStorageMixin so the torch and DeepEP frozen variants share the passthrough placeholder path. - Add example recipe deepseek_v4_flash_hellaswag_lora_mxfp4_deepep.yaml. - Add unit tests: guard lifting + GEMM-substitution numerics via a mock dispatcher. Validated on 8xH200 EP=8: 12-step train + validation + checkpoint, finite loss/grad_norm, ~40 GiB/rank. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Daniel Shen <dshen@crusoe.ai>
Makes routed-expert (and shared-expert) LoRA actually work with mxfp4-resident base weights on the full DeepSeek-V4-Flash, fixing three issues that blocked it: - Load wiring: GroupedExperts*LoRAMXFP4 now support a passthrough mode that registers packed base placeholders at init (instead of deferred bf16-load-then-pack), so the packed fp4 checkpoint loads straight in via _aggregate_experts_packed. This avoids the bf16 _aggregate_experts re-stack that OOM'd at load (~137 GiB/rank) and keeps the steady-state footprint at the packed ~49 GiB/rank. Threaded through patch_moe_module / apply_lora_to_linear_modules; infra keeps the adapter in packed mode for both frozen and LoRA experts. - Adapter dtype: the LoRA grouped GEMMs now cast adapters to the activation dtype (GroupedExpertsDeepEP allocates its base, hence adapter sizing, as fp32 when no backend dtype is set), fixing "mat1 and mat2 have the same dtype, BFloat16 != float". - Gate NaN guard: clamp_min(1e-12) under the sqrtsoftplus gate sqrt (both the generic Gate and the DSV4 hash gate). softplus underflows to 0.0 for very negative logits and sqrt'(0)=inf makes the backward NaN; the clamp bounds the gradient with a negligible forward change. - Enable expert LoRA in both deepseek_v4 mxfp4 recipes (target *mlp.experts and *shared_experts.*proj). Validated on 8xH200 EP=8 (packed 4096): routed+shared expert LoRA trains 15 steps, loss 8.96->5.72 monotonic, grad_norm finite, no NaN, 48.9 GiB/rank, val loss 2.32. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Daniel Shen <dshen@crusoe.ai>
…inear-CE Fused linear-CE for DSV4 is upstream (NVIDIA-NeMo#2397); only the DSV4 dtype-cast remains. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Daniel Shen <dshen@crusoe.ai>
…sEntropy Both DSV4 mxfp4 LoRA recipes now use the fused linear cross-entropy so the [seq, 129280] logits are never materialized, removing the ~16 GiB fp32 logits spike (single-node context ceiling ~30k -> ~36-38k tokens on one 8xH200). Measured fit: packed 32768 at 115 GiB/rank (vs OOM ~137 GiB with MaskedCrossEntropy). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Daniel Shen <dshen@crusoe.ai>
Set peft.lora_dtype=float32 in both DSV4 mxfp4 LoRA recipes. With FSDP2's default MixedPrecisionPolicy(param_dtype=bf16), this keeps fp32 master weights + fp32 AdamW state for the adapters (stability against small-update swamping, matching HF PEFT's autocast_adapter_dtype default) while the adapter matmuls still run in bf16 (FSDP casts the all-gathered params to param_dtype). Adapters are tiny so the fp32 optimizer-state cost is negligible. Validated 8xH200 EP=8 packed-4096 (fused CE + routed/shared expert LoRA + fp32 adapters): 2 steps, loss 8.96->8.77, finite grad_norm, no dtype mismatch in the triton/grouped-GEMM LoRA paths. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Daniel Shen <dshen@crusoe.ai>
…-param requires_grad Two L0 unit tests were stale relative to earlier branch code changes: - test_grouped_experts_deepep_token_dispatcher_init asserted init_token_dispatcher eagerly calls _init_deepep_buffer, but buffer allocation is now lazy (deferred to FusedDispatch.forward) — the revert that fixed the single-node load-time OOM. Assert it is NOT called. - ExpertParallel._partition_fn now constructs nn.Parameter(..., requires_grad=...) so non-floating packed mxfp4 params (int8 / e8m0) don't trip the default requires_grad=True. The test's stub Parameter didn't accept/store requires_grad; add it (also unblocks the requires_grad-preservation test). Both fixes verified: tests/unit_tests/moe now 450 passed, 0 failed. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Daniel Shen <dshen@crusoe.ai>
9822d2e to
f0c9322
Compare
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> Signed-off-by: Daniel <dshen@crusoeenergy.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do ?
What
Adds an mxfp4-resident storage path for the base experts of DeepSeek-V4-Flash during LoRA training. Experts are kept packed as fp4-e2m1 + e8m0 block scales (the format they already ship in) at steady state and dequantized on the fly inside the grouped-GEMM forward/backward, instead of being materialized in bf16. This lets DSV4-Flash LoRA fit and run on a single 8×H100 node.
Why
The experts are ~90% of DSV4-Flash's parameters (~277B → ~69 GiB/rank bf16 at ep_size=8). Dequantizing them to bf16 at load both balloons memory and makes load slow. Keeping them in their native mxfp4 form — only legal because LoRA freezes the base (no master copy / optimizer state needed for experts) — removes both costs.
How it works
Validation (8×H200, full 43-layer DSV4-Flash, ep_size=8, tilelang)
Known limitations / scope
ep_shard > 1 (multi-node where experts are FSDP-wrapped) still hits a loud FSDP-on-int8 error — needs packed experts excluded from FSDP before that config; single-node ep_size=8/ep_shard=1 avoids it.
Changelog
Before your PR is "Ready for review"
Pre checks: