MLX: fix gated-delta (Qwen3.5) training and add fused backward kernels#754
MLX: fix gated-delta (Qwen3.5) training and add fused backward kernels#754Lyxot wants to merge 15 commits into
Conversation
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
0d77b5d to
fa92b22
Compare
6b5c0e6 to
72f980b
Compare
patch_gated_delta only rebound mlx_lm.models.gated_delta, but every consumer (mlx_lm qwen3_5 / qwen3_next / kimi_linear, mlx_vlm qwen3_5) binds gated_delta_update via from-import at module load, so the memory-efficient VJP never reached any call site and training ran the raw sequential ops loop. Sweep already-imported consumer modules and rebind by identity; warn once about foreign implementations.
mx .at[:, t].add scatters read the update tensor with a wrong batch
stride on the Metal stream, silently corrupting d_q/d_k/d_v/d_g/d_beta
for every batch element except index 0. Collect per-step grads in lists
and stack instead; verified against plain autodiff at B in {2,3,4}.
The underlying Metal scatter indexing bug is fixed upstream in
ml-explore/mlx#3483 but is not in any release yet (latest: 0.31.2).
The list+stack form could revert to the scatter once a fixed mlx is
the minimum supported version, though it is also the cheaper op here.
The patch trigger was gated on 'qwen3_5' in model_type, but qwen3_next and kimi_linear share the same gated_delta_update call sites and hit the identical slow path. Detect gated-delta layers structurally (class name contains 'delta' + the A_log/dt_bias pair; Mamba/SSM mixers carry the parameters but never the name) and trigger on that instead.
72f980b to
17f2a28
Compare
Forward reuses mlx-lm's fused gated_delta_kernel per 64-step chunk
(boundary states fall out as chunk outputs); backward replays chunk
states into scratch (K1) then reverse-scans with atomic gradient
accumulation (K2), translating the validated ops BPTT into Metal.
Graph per layer drops from ~25k unrolled ops to ~100 kernel launches:
Qwen3.5-0.8B LoRA seq-2048 goes 151 -> ~900 tok/s at 4.2 GB peak (was
12.8 GB). Scalar gating, no mask, Dk%32==0; anything else falls back to
the ops VJP automatically. Gradient parity vs plain autodiff: <=1e-7
fp32 across GQA, multi-chunk, uneven tails, and B in {2,3}.
Extends K1/K2 to kimi_linear's per-column gating (g: [B, T, Hv, Dk]) via the same template-flag pattern mlx-lm uses for its forward kernel. The vectorized d_g is per-column atomic accumulation (simpler than the scalar simdgroup reduction); compute_g stays outside the kernel so d_A_log/d_dt_bias flow through autodiff unchanged. With this, every gated-delta architecture in mlx-lm trains on the kernel path; the ops VJP remains only as a guard fallback (no Metal, masked calls, unaligned head dims, or UNSLOTH_DISABLE_GD_KERNEL_VJP=1 — pass compile=False with it, mx.compile cannot fuse the unrolled ops graph). Gradient parity vs plain autodiff: <=2.4e-7 fp32 incl. vectorized+GQA; scalar path regression-checked bit-identical on Qwen3.5-0.8B.
Merges the binding-sweep/detection suite and the gradient-parity suite into tests/test_mlx_gated_delta_vjp.py. The torch shim now installs only when real mlx is absent (CI), parity tests skip without Metal, and the module-restore fixture keeps the file order-independent against shim-based sibling files. Folds source-patched/foreign-untouched asserts into the sweep test, merges detection cases, and parametrizes parity over both VJP implementations.
After rebasing onto unslothai#738: drop its legacy-layout branch (the binding sweep already rebinds mlx-vlm 0.4-0.5 from-imports), route its mlx-vlm >= 0.6 training branch through the fused-kernel dispatch (ops VJP under whole-step mx.compile would reintroduce the compile_fuse wedge the kernels eliminated), and teach the sweep to recognize the sibling patch instead of warning about a foreign implementation.
17f2a28 to
f3aa759
Compare
…iple The backward kernel's shared-memory pre-reduction has row 0 read every threadgroup row slot; a partial trailing threadgroup (Dv % 4 != 0) would read uninitialized slots. No current architecture hits this (Dv is 64/128 everywhere), so route such shapes to the ops VJP.
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds structural detection and improved patching for gated-delta (GDN) layers so training routes through memory-/compute-efficient VJP implementations, including a new fused Metal-kernel backward path and accompanying tests.
Changes:
- Trigger
patch_gated_delta()based on structural detection (model_has_gated_delta_layers) instead ofmodel_typeheuristics. - Fix incorrect per-timestep gradient scattering in the ops-based custom VJP by stacking per-step gradients.
- Expand
patch_gated_delta()to rebind stalefrom ... import gated_delta_updateconsumer bindings and add fused Metal-kernel VJP path + tests.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| unsloth_zoo/mlx/trainer.py | Switches patch trigger to structural gated-delta detection and refines patch ordering. |
| unsloth_zoo/mlx/compile.py | Adds structural gated-delta layer detection helpers used by trainer. |
| unsloth_zoo/gated_delta_vjp.py | Fixes VJP gradient accumulation, adds consumer rebinding sweep, and introduces fused Metal-kernel training VJP. |
| tests/test_mlx_gated_delta_vjp.py | Adds tests for structural detection, rebinding sweep, and gradient parity (incl. Metal-only kernel path). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
@codex review |
Review feedback: tolerate a pre-existing patch flag without a recorded original (skip the identity sweep instead of AttributeError), raise a clear ValueError when gated_delta_kernel_efficient is called outside kernel support instead of failing on a None kernel, and parameterize the warned-foreign set annotation.
…x/mlx-gated-delta-stale-bindings # Conflicts: # unsloth_zoo/mlx/trainer.py
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: cf29caf8bd
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
…tale-bindings # Conflicts: # unsloth_zoo/gated_delta_vjp.py
Summary
Gated-delta (linear attention) models — qwen3_5, qwen3_5_moe, qwen3_next, kimi_linear — could not be trained properly on MLX: the memory-efficient custom VJP never reached any call site, and the training recurrence ran as an unrolled per-token Python loop that is both slow and hostile to
mx.compile. This PR fixes the two bugs and replaces the recurrence backward with fused Metal kernels: Qwen3.5-0.8B LoRA training goes from ~140 tok/s with corrupt gradients (current main) to ~1,100 tok/s with correct gradients at 5x less memory.What changed
Rebind stale
gated_delta_updateconsumer imports —patch_gated_delta()only reboundmlx_lm.models.gated_delta.gated_delta_update, but every consumer (mlx_lm qwen3_5 / qwen3_next / kimi_linear, mlx_vlm qwen3_5) binds the symbol via from-import at module load, so the custom VJP had been dead code since it was introduced. An identity-matchedsys.modulessweep rebinds stale references; modules shipping their own implementation are never touched.Install the VJP for qwen3_next and kimi_linear too — the patch trigger was gated on
"qwen3_5" in model_type; it now detects gated-delta layers structurally (class name +A_log/dt_biaspair, deliberately excluding Mamba/SSM mixers which share the parameters but not the recurrence).Fused Metal backward kernels — forward reuses mlx-lm's fused
gated_delta_kernelper 64-step chunk (chunk-boundary states fall out as the chunk outputs); backward replays chunk states into scratch (K1), then reverse-scans with atomic gradient accumulation (K2). The per-layer graph drops from ~25k unrolled ops to ~100 kernel launches, which also makes whole-stepmx.compileviable. Calls outside kernel support (mask present, head dim not a multiple of 32, no Metal) fall back to the ops VJP automatically. Kernel pointer arithmetic is bounds-clean (no out-of-range pointer formation at chunk edges), and the reverse scan pre-reduces threadgroup rows in shared memory before issuing global atomics for d_q/d_k, cutting atomic traffic 4x (dispatch falls back to the ops VJP for head dims that would leave a partial threadgroup).Vectorized-gating support — extends the kernels to kimi_linear's per-column gating (
g: [B, T, Hv, Dk]) via the same template-flag pattern mlx-lm's forward kernel uses. Every gated-delta architecture in mlx-lm now trains on the kernel path.Unify with fix(mlx): Qwen3.5/3.6 VLM training — pass through new mlx-vlm attention kwargs, patch non-differentiable Metal kernels #738's VLM patch —
patch_gated_delta_vlm's mlx-vlm >= 0.6 branch now routes training through the fused-kernel dispatch (its previous ops-VJP routing under whole-stepmx.compilewould hit thecompile_fuseblow-up this PR eliminates); its legacy-layout branch is dropped (the sweep already rebinds mlx-vlm 0.4-0.5 from-imports); the sweep recognizes the sibling patch instead of warning about a foreign implementation.Tests — consolidated into
tests/test_mlx_gated_delta_vjp.py: binding-sweep and detection tests run on CI via the torch shim; gradient-parity tests (Metal-only) pin both the ops and kernel VJPs against plainmxautodiff at batch sizes 2-4, across GQA, multi-chunk, uneven tails, and both gating forms.Benchmark
Qwen3.5-0.8B, bf16 LoRA r=8, batch 2, seq <= 2048, gradient checkpointing on, 10 steps (~37.7k trained tokens) over a synthetic dataset of 16 long prompts (repeated reasoning filler in 5 length buckets up to the seq cap; identical rows for all four configurations), M-series 128 GB, mlx 0.31.2 / mlx-lm 0.31.3 / mlx-vlm 0.5.0:
compile=Falsemlx_vlm.lora(Upstream
mlx_vlm.lorameasured on 0.5.0; on 0.6.2 it crashes withPrimitive::vjp not implemented for CustomKernelbefore the first step.)Validation
mxautodiff (not vs the previous implementation — which is how the batch-stride bug evaded its original verification): <= 1e-7 fp32 / bf16-rounding bounded across GQA, multi-chunk, uneven tails, scalar + vectorized gating, batch 2-4, for both the ops and kernel VJPs.UnslothTraineron an RTX 4070 SUPER over 30 steps of OpenMathReasoning-mini, for qwen3_5-0.8b through these kernels as well as qwen3-0.6b / gemma3-270m controls.🤖 Generated with Claude Code