Skip to content

feat(moe): mxfp4-resident MoE experts for DeepSeek-V4-Flash LoRA#2548

Open
excepshenal wants to merge 24 commits into
NVIDIA-NeMo:mainfrom
excepshenal:dshen/feat/mxfp4-expert-lora
Open

feat(moe): mxfp4-resident MoE experts for DeepSeek-V4-Flash LoRA#2548
excepshenal wants to merge 24 commits into
NVIDIA-NeMo:mainfrom
excepshenal:dshen/feat/mxfp4-expert-lora

Conversation

@excepshenal

@excepshenal excepshenal commented Jun 12, 2026

Copy link
Copy Markdown

What does this PR do ?

What
Adds an mxfp4-resident storage path for the base experts of DeepSeek-V4-Flash during LoRA training. Experts are kept packed as fp4-e2m1 + e8m0 block scales (the format they already ship in) at steady state and dequantized on the fly inside the grouped-GEMM forward/backward, instead of being materialized in bf16. This lets DSV4-Flash LoRA fit and run on a single 8×H100 node.

Why
The experts are ~90% of DSV4-Flash's parameters (~277B → ~69 GiB/rank bf16 at ep_size=8). Dequantizing them to bf16 at load both balloons memory and makes load slow. Keeping them in their native mxfp4 form — only legal because LoRA freezes the base (no master copy / optimizer state needed for experts) — removes both costs.

How it works

  • Packed-resident storage. quantization/mxfp4.py provides pack/unpack + MXFP4GroupedMM, a grouped-GEMM autograd Function that saves only the packed weights and re-dequantizes in backward (no bf16 weight kept alive; no weight grad).
  • Passthrough load (no bf16 ever). The DSV4 state-dict adapter gains expert_storage_format="mxfp4": from_hf aggregates per-expert packed int8 + e8m0 scales straight into _packed/_scales params, and to_hf splits them back for the DCP planner. Packing is orthogonal to the expert-stack and gate‖up concat (it's along the contraction dim), so no unpacking is needed at load. Modules are built packed-at-init so the checkpoint loads directly into packed params.
  • EP integration. ExpertParallel shards the packed params on the expert dim; works on the torch_mm GroupedExperts path with dispatcher: torch.

Validation (8×H200, full 43-layer DSV4-Flash, ep_size=8, tilelang)

  • Numerical parity: step-0 LoRA loss matches a bf16-experts baseline (bit-identical at short seq; within bf16 noise at 4k packed seq). Unit tests assert forward/backward equivalence vs bf16 and bit-exact decode vs the existing fp4 dequant, incl. a real-checkpoint-gated test.
  • Memory: ~38 GiB/rank (mxfp4) vs ~84 GiB/rank (bf16) — ~2.2× less.
  • Load time: ~63 s (mxfp4) vs ~16.5 min (bf16, warm cache) — passthrough skips the load-time dequant+materialize.
  • Throughput (4096-token packed seq): ~6,567 tps (mxfp4) vs ~7,740 tps (bf16) — mxfp4 ~1.18× slower; the dequant overhead shrinks with longer seq (was ~2.2× at short seq) since dequant is a fixed cost amortized over a larger GEMM.

Known limitations / scope

  • EP required. mxfp4 needs the torch_mm GroupedExperts path with ep_size > 1; single-GPU is guarded with a clear error. DeepEP experts are not yet supported (the conversion skips them with a warning).
    ep_shard > 1 (multi-node where experts are FSDP-wrapped) still hits a loud FSDP-on-int8 error — needs packed experts excluded from FSDP before that config; single-node ep_size=8/ep_shard=1 avoids it.
  • Dequant is an unfused pure-torch kernel; further speedups (dequant-once-per-grad-accum, side-stream overlap, fused/Blackwell-native GEMM) are follow-ups.

Changelog

  • New: expert_storage_format-driven mxfp4 storage for base MoE experts (quantization/mxfp4.py, moe/quantized_experts.py, _peft/lora* for LoRA-targeted experts).
  • New: DSV4 adapter mxfp4 passthrough load/save (no bf16 expert materialization).
  • New: example config examples/llm_finetune/deepseek_v4/deepseek_v4_flash_hellaswag_lora_mxfp4.yaml (8-GPU, ep_size=8, tilelang, mxfp4).
  • New PEFT config field peft.expert_weight_format: "bf16" | "mxfp4" (default bf16; mxfp4 requires experts: torch_mm).
  • Fix: skip random init for packed experts; set requires_grad at nn.Parameter construction in ExpertParallel (int8 EP-sharding blocker); guard single-GPU (mxfp4 requires EP).
  • Fix (general): clean process exit (destroy_process_group) and deadlock-safe distributed validation (lockstep across ranks) in recipes/llm/train_ft.py.
  • Perf: drop a redundant .contiguous() in the expert grouped GEMM (cuBLAS consumes the transposed view).

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

@excepshenal excepshenal requested a review from a team as a code owner June 12, 2026 22:27
@copy-pr-bot

copy-pr-bot Bot commented Jun 12, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@excepshenal excepshenal marked this pull request as draft June 12, 2026 22:27
Daniel and others added 3 commits June 12, 2026 22:57
Signed-off-by: Daniel <dshen@crusoeenergy.com>
Signed-off-by: Daniel <dshen@crusoeenergy.com>
Keep frozen routed-expert base weights packed as fp4-e2m1 + e8m0 block
scales (the DeepSeek V4 Flash checkpoint format) at steady state during
LoRA training, dequantizing on the fly in forward and backward via a
custom grouped-GEMM autograd function that saves only the packed tensors.
Opt in with peft.expert_weight_format: mxfp4 (torch_mm experts backend
only). Packing happens at PEFT-swap time when weights are materialized,
or after checkpoint load for meta-initialized models.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Signed-off-by: Daniel <dshen@crusoeenergy.com>
@excepshenal excepshenal force-pushed the dshen/feat/mxfp4-expert-lora branch from 7ebb84e to 2c7ff04 Compare June 12, 2026 22:57
@excepshenal excepshenal changed the title [model] feat: mxfp4-resident MoE experts for DeepSeek-V4-Flash LoRA feat(moe): mxfp4-resident MoE experts for DeepSeek-V4-Flash LoRA Jun 12, 2026
@akoumpa

akoumpa commented Jun 12, 2026

Copy link
Copy Markdown
Contributor

/ok to test 2c7ff04

Daniel and others added 10 commits June 13, 2026 00:36
Extend mxfp4-resident expert storage to FROZEN (non-LoRA-targeted) routed
experts, which is where the memory win actually lands for LoRA-on-attention
recipes. Adds GroupedExpertsMXFP4 (frozen, packed base, dequant-on-the-fly)
and convert_frozen_experts_to_mxfp4(), invoked from the PEFT-application step
when peft.expert_weight_format=mxfp4. Packing of both frozen and LoRA-targeted
experts is deferred until after checkpoint load.

Format-specific pack/unpack/GEMM logic is factored into MXFP4ExpertStorageMixin
(shared by the frozen and LoRA variants) as the seam for a future int4 codec.

Only experts are quantized; all other weights stay bf16. v1 supports the
torch_mm GroupedExperts backend (DeepEP experts are skipped with a warning).
Adds an example DSV4-Flash LoRA+mxfp4 config.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Signed-off-by: Daniel <dshen@crusoeenergy.com>
Make MXFP4ExpertStorageMixin register its packed storage parameters at module
init (not only post-load), driven by a _PACKED_SUFFIXES list so the helper is
format-driven rather than hardcoding the two tensor names. Enables building the
model packed-at-init for the passthrough load path.

Signed-off-by: Daniel <dshen@crusoeenergy.com>
Phase B: load experts directly as packed fp4 so they are never materialized in
bf16, capping the load-time peak (the actual blocker for one-node DSV4-Flash).

- DeepSeekV4StateDictAdapter gains expert_storage_format='mxfp4': from_hf skips
  expert dequant and aggregates per-expert packed int8 + e8m0 scales into
  *_packed/*_scales keys. Concatenating gate||up along the output dim and
  stacking experts on dim 0 is layout-preserving (packing is along the
  contraction dim), so no unpacking is needed.
- GroupedExpertsMXFP4(passthrough=True) registers meta packed placeholders from
  config at init (via the init-capable register_packed_base_weight), so the
  packed checkpoint loads straight in with no bf16 storage.
- convert_frozen_experts_to_mxfp4(passthrough=...) plumbs the mode; kept opt-in
  so the validated bf16-then-pack path stays the default until the real
  checkpoint confirms Phase B end to end.

Tests: synthetic fp4 checkpoint proves passthrough emits packed keys (no bf16,
no orphaned scales) and decodes bit-identically to the bf16 dequant+aggregate
path; packed-at-init registers correct meta shapes/dtypes.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Signed-off-by: Daniel <dshen@crusoeenergy.com>
…ation

Address organization: the quantization primitives belong alongside fp8/qat/qlora,
not under moe/. Mirrors how fp8.py holds the format/config while the layers that
use it live elsewhere.

- Move fp4_utils.py -> quantization/mxfp4.py.
- Export mxfp4 from quantization/__init__.
- GroupedExpertsMXFP4 / MXFP4ExpertStorageMixin stay in moe/ (GroupedExperts
  subclasses = model layers); GroupedExpertsLoRAMXFP4 stays in _peft (LoRA module,
  delegates all precision logic to the mixin).
- Add a real-checkpoint-gated test validating passthrough against
  /raid0/data/models/DeepSeek-V4-Flash: confirms int8+e8m0 layout and bit-exact
  decode vs the bf16 path on actual checkpoint bytes (self-skips when absent).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Signed-off-by: Daniel <dshen@crusoeenergy.com>
- DeepSeekV4StateDictAdapter.to_hf now splits *_packed/*_scales params back into
  per-expert checkpoint keys (mirror of the from_hf aggregation), so the DCP
  loader can enumerate destination tensors; quantization placeholder step is
  bypassed for already-packed experts.
- infrastructure: when peft.expert_weight_format=mxfp4, set the adapter to
  passthrough mode and convert experts packed-at-init before load, so the fp4
  checkpoint loads straight into packed params (no bf16 experts).
- Tests: from_hf->to_hf round-trip recovers per-expert packed keys bit-exactly.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Signed-off-by: Daniel <dshen@crusoeenergy.com>
_init_weights touched module.gate_and_up_projs, which passthrough
GroupedExpertsMXFP4 modules don't have (only *_packed/*_scales, filled from
the checkpoint). Skip init for mxfp4-resident experts. Found by an end-to-end
single-GPU run on real DeepSeek-V4-Flash weights (2 train steps, finite loss).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Signed-off-by: Daniel <dshen@crusoeenergy.com>
nn.Parameter() defaults requires_grad=True, which raises 'only Tensors of
floating point dtype can require gradients' for the int8/e8m0 packed tensors of
mxfp4-resident experts before the subsequent requires_grad assignment runs.
Pass requires_grad at construction. Identical behavior for float params;
unblocks EP sharding of packed experts. Validated by a 2-GPU ep_size=2 run on
real DeepSeek-V4-Flash (both steps, finite loss, 55 GiB/rank vs 124 single-GPU).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Signed-off-by: Daniel <dshen@crusoeenergy.com>
Root-caused the single-GPU correctness bug: at world_size=1 the MoE parallelizer
is skipped, the packed e8m0 expert scales are never applied, and experts decode
to unscaled fp4 (~100x too large) — silently wrong (loss 19.60 vs bf16 17.28).
Diagnosed via dequant-norm debug: ep-sharded experts std=0.025 (correct) vs
world=1 std=2.48 (raw fp4 grid, scale==1).

mxfp4 is only correct when experts are EP-sharded (validated: ep_size=2 matches
the bf16 baseline within 0.04%). Guard world_size=1 with a clear error instead
of training on garbage. Multi-GPU ep_size>1 is the supported (and only sensible,
given model size) path. Removed debug instrumentation.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Signed-off-by: Daniel <dshen@crusoeenergy.com>
Two issues made training appear to hang at the end:

1. No process-group teardown. run_train_validation_loop returned but the
   distributed process group was never destroyed, so NCCL / the elastic agent
   waited and the process never exited. Add a best-effort barrier +
   destroy_process_group in main()'s finally.

2. End-of-training validation could deadlock under expert parallelism. The MoE
   expert forward issues EP collectives; if DP ranks see uneven validation
   shard sizes they call those collectives a different number of times and
   hang. Drive the validation loop by a global-MIN "does every rank still have
   a batch?" all-reduce so all ranks run the same number of forwards.

Validated on 2xGPU DSV4-Flash mxfp4 LoRA: validation completes and the process
exits cleanly (exit 0). Note: validation still runs over the full val set at the
last step (is_ckpt_step -> is_last_step); bounding it (eval_iters cap) is a
possible follow-up for snappier end-of-training.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Signed-off-by: Daniel <dshen@crusoeenergy.com>
MXFP4GroupedMM.forward dequantized to [E,N,K] then did
.transpose(-2,-1).contiguous() to feed torch._grouped_mm — a full bf16 weight
copy per forward. torch._grouped_mm accepts the transposed view directly
(cuBLAS transB, verified bit-identical), so pass the view and skip the copy.

Microbench (E=32,N=K=4096): the op goes 4.1ms -> 0.4ms (copy was pure waste).
End-to-end DSV4-Flash 8-GPU ep8, 4096-token packed seq: ~4630 -> ~4970 tps
(+7%); mxfp4 slowdown vs bf16 improves from ~1.67x to ~1.56x. Loss unchanged.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Signed-off-by: Daniel <dshen@crusoeenergy.com>
Signed-off-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com>
@akoumpa

akoumpa commented Jun 14, 2026

Copy link
Copy Markdown
Contributor

/ok to test 6463d77

Daniel Shen and others added 6 commits June 15, 2026 05:14
Add GroupedExpertsDeepEPMXFP4 (frozen) and GroupedExpertsDeepEPLoRAMXFP4
(LoRA-on-experts) so DeepSeek V4 routed experts can stay packed as fp4-e2m1 +
e8m0 block scales while using the DeepEP fused all-to-all token dispatch. mxfp4
only changes the two post-dispatch grouped GEMMs (dequant on the fly via
MXFP4GroupedMM); dispatch/combine are unchanged. Requires backend.experts=torch_mm.

- Lift the mxfp4+DeepEP guards in patch_moe_module and
  convert_frozen_experts_to_mxfp4 (TE experts still unsupported).
- Hoist _init_packed_placeholders into MXFP4ExpertStorageMixin so the torch and
  DeepEP frozen variants share the passthrough placeholder path.
- Add example recipe deepseek_v4_flash_hellaswag_lora_mxfp4_deepep.yaml.
- Add unit tests: guard lifting + GEMM-substitution numerics via a mock dispatcher.

Validated on 8xH200 EP=8: 12-step train + validation + checkpoint, finite
loss/grad_norm, ~40 GiB/rank.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Daniel Shen <dshen@crusoe.ai>
Makes routed-expert (and shared-expert) LoRA actually work with mxfp4-resident
base weights on the full DeepSeek-V4-Flash, fixing three issues that blocked it:

- Load wiring: GroupedExperts*LoRAMXFP4 now support a passthrough mode that
  registers packed base placeholders at init (instead of deferred bf16-load-then-pack),
  so the packed fp4 checkpoint loads straight in via _aggregate_experts_packed. This
  avoids the bf16 _aggregate_experts re-stack that OOM'd at load (~137 GiB/rank) and
  keeps the steady-state footprint at the packed ~49 GiB/rank. Threaded through
  patch_moe_module / apply_lora_to_linear_modules; infra keeps the adapter in packed
  mode for both frozen and LoRA experts.
- Adapter dtype: the LoRA grouped GEMMs now cast adapters to the activation dtype
  (GroupedExpertsDeepEP allocates its base, hence adapter sizing, as fp32 when no
  backend dtype is set), fixing "mat1 and mat2 have the same dtype, BFloat16 != float".
- Gate NaN guard: clamp_min(1e-12) under the sqrtsoftplus gate sqrt (both the generic
  Gate and the DSV4 hash gate). softplus underflows to 0.0 for very negative logits and
  sqrt'(0)=inf makes the backward NaN; the clamp bounds the gradient with a negligible
  forward change.
- Enable expert LoRA in both deepseek_v4 mxfp4 recipes (target *mlp.experts and
  *shared_experts.*proj).

Validated on 8xH200 EP=8 (packed 4096): routed+shared expert LoRA trains 15 steps,
loss 8.96->5.72 monotonic, grad_norm finite, no NaN, 48.9 GiB/rank, val loss 2.32.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Daniel Shen <dshen@crusoe.ai>
…inear-CE

Fused linear-CE for DSV4 is upstream (NVIDIA-NeMo#2397); only the DSV4 dtype-cast remains.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Daniel Shen <dshen@crusoe.ai>
…sEntropy

Both DSV4 mxfp4 LoRA recipes now use the fused linear cross-entropy so the
[seq, 129280] logits are never materialized, removing the ~16 GiB fp32 logits
spike (single-node context ceiling ~30k -> ~36-38k tokens on one 8xH200).
Measured fit: packed 32768 at 115 GiB/rank (vs OOM ~137 GiB with MaskedCrossEntropy).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Daniel Shen <dshen@crusoe.ai>
Set peft.lora_dtype=float32 in both DSV4 mxfp4 LoRA recipes. With FSDP2's default
MixedPrecisionPolicy(param_dtype=bf16), this keeps fp32 master weights + fp32 AdamW
state for the adapters (stability against small-update swamping, matching HF PEFT's
autocast_adapter_dtype default) while the adapter matmuls still run in bf16 (FSDP casts
the all-gathered params to param_dtype). Adapters are tiny so the fp32 optimizer-state
cost is negligible.

Validated 8xH200 EP=8 packed-4096 (fused CE + routed/shared expert LoRA + fp32 adapters):
2 steps, loss 8.96->8.77, finite grad_norm, no dtype mismatch in the triton/grouped-GEMM
LoRA paths.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Daniel Shen <dshen@crusoe.ai>
…-param requires_grad

Two L0 unit tests were stale relative to earlier branch code changes:

- test_grouped_experts_deepep_token_dispatcher_init asserted init_token_dispatcher
  eagerly calls _init_deepep_buffer, but buffer allocation is now lazy (deferred to
  FusedDispatch.forward) — the revert that fixed the single-node load-time OOM. Assert
  it is NOT called.
- ExpertParallel._partition_fn now constructs nn.Parameter(..., requires_grad=...) so
  non-floating packed mxfp4 params (int8 / e8m0) don't trip the default requires_grad=True.
  The test's stub Parameter didn't accept/store requires_grad; add it (also unblocks the
  requires_grad-preservation test).

Both fixes verified: tests/unit_tests/moe now 450 passed, 0 failed.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Daniel Shen <dshen@crusoe.ai>
@excepshenal excepshenal force-pushed the dshen/feat/mxfp4-expert-lora branch 2 times, most recently from 9822d2e to f0c9322 Compare June 15, 2026 06:51
@excepshenal excepshenal marked this pull request as ready for review June 15, 2026 06:51
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-customer Waiting on the original author to respond label Jun 15, 2026
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Signed-off-by: Daniel <dshen@crusoeenergy.com>
@excepshenal excepshenal marked this pull request as draft June 16, 2026 23:15
@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the waiting-on-customer Waiting on the original author to respond label Jun 17, 2026
@excepshenal excepshenal marked this pull request as ready for review June 17, 2026 22:24
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-customer Waiting on the original author to respond label Jun 18, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-request waiting-on-customer Waiting on the original author to respond

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants