Skip to content

MLX: fix gated-delta (Qwen3.5) training and add fused backward kernels#754

Open
Lyxot wants to merge 15 commits into
unslothai:mainfrom
Lyxot:fix/mlx-gated-delta-stale-bindings
Open

MLX: fix gated-delta (Qwen3.5) training and add fused backward kernels#754
Lyxot wants to merge 15 commits into
unslothai:mainfrom
Lyxot:fix/mlx-gated-delta-stale-bindings

Conversation

@Lyxot

@Lyxot Lyxot commented Jun 11, 2026

Copy link
Copy Markdown
Contributor

Summary

Gated-delta (linear attention) models — qwen3_5, qwen3_5_moe, qwen3_next, kimi_linear — could not be trained properly on MLX: the memory-efficient custom VJP never reached any call site, and the training recurrence ran as an unrolled per-token Python loop that is both slow and hostile to mx.compile. This PR fixes the two bugs and replaces the recurrence backward with fused Metal kernels: Qwen3.5-0.8B LoRA training goes from ~140 tok/s with corrupt gradients (current main) to ~1,100 tok/s with correct gradients at 5x less memory.

What changed

  • Rebind stale gated_delta_update consumer importspatch_gated_delta() only rebound mlx_lm.models.gated_delta.gated_delta_update, but every consumer (mlx_lm qwen3_5 / qwen3_next / kimi_linear, mlx_vlm qwen3_5) binds the symbol via from-import at module load, so the custom VJP had been dead code since it was introduced. An identity-matched sys.modules sweep rebinds stale references; modules shipping their own implementation are never touched.

  • Install the VJP for qwen3_next and kimi_linear too — the patch trigger was gated on "qwen3_5" in model_type; it now detects gated-delta layers structurally (class name + A_log/dt_bias pair, deliberately excluding Mamba/SSM mixers which share the parameters but not the recurrence).

  • Fused Metal backward kernels — forward reuses mlx-lm's fused gated_delta_kernel per 64-step chunk (chunk-boundary states fall out as the chunk outputs); backward replays chunk states into scratch (K1), then reverse-scans with atomic gradient accumulation (K2). The per-layer graph drops from ~25k unrolled ops to ~100 kernel launches, which also makes whole-step mx.compile viable. Calls outside kernel support (mask present, head dim not a multiple of 32, no Metal) fall back to the ops VJP automatically. Kernel pointer arithmetic is bounds-clean (no out-of-range pointer formation at chunk edges), and the reverse scan pre-reduces threadgroup rows in shared memory before issuing global atomics for d_q/d_k, cutting atomic traffic 4x (dispatch falls back to the ops VJP for head dims that would leave a partial threadgroup).

  • Vectorized-gating support — extends the kernels to kimi_linear's per-column gating (g: [B, T, Hv, Dk]) via the same template-flag pattern mlx-lm's forward kernel uses. Every gated-delta architecture in mlx-lm now trains on the kernel path.

  • Unify with fix(mlx): Qwen3.5/3.6 VLM training — pass through new mlx-vlm attention kwargs, patch non-differentiable Metal kernels #738's VLM patchpatch_gated_delta_vlm's mlx-vlm >= 0.6 branch now routes training through the fused-kernel dispatch (its previous ops-VJP routing under whole-step mx.compile would hit the compile_fuse blow-up this PR eliminates); its legacy-layout branch is dropped (the sweep already rebinds mlx-vlm 0.4-0.5 from-imports); the sweep recognizes the sibling patch instead of warning about a foreign implementation.

  • Tests — consolidated into tests/test_mlx_gated_delta_vjp.py: binding-sweep and detection tests run on CI via the torch shim; gradient-parity tests (Metal-only) pin both the ops and kernel VJPs against plain mx autodiff at batch sizes 2-4, across GQA, multi-chunk, uneven tails, and both gating forms.

Benchmark

Qwen3.5-0.8B, bf16 LoRA r=8, batch 2, seq <= 2048, gradient checkpointing on, 10 steps (~37.7k trained tokens) over a synthetic dataset of 16 long prompts (repeated reasoning filler in 5 length buckets up to the seq cap; identical rows for all four configurations), M-series 128 GB, mlx 0.31.2 / mlx-lm 0.31.3 / mlx-vlm 0.5.0:

configuration avg tok/s wall time peak GPU
main 141 272 s 37.5 GB
main, compile=False 136 277 s 12.8 GB
upstream mlx_vlm.lora 89 ~425 s 38.9 GB
this PR 1,244 32 s 7.4 GB

(Upstream mlx_vlm.lora measured on 0.5.0; on 0.6.2 it crashes with Primitive::vjp not implemented for CustomKernel before the first step.)

Validation

  • Gradient parity vs plain mx autodiff (not vs the previous implementation — which is how the batch-stride bug evaded its original verification): <= 1e-7 fp32 / bf16-rounding bounded across GQA, multi-chunk, uneven tails, scalar + vectorized gating, batch 2-4, for both the ops and kernel VJPs.
  • Cross-framework parity vs CUDA: the trainer-owned LoRA harness (fingerprint-verified identical token batches, trainer-owned AdamW + matched LR schedule, TF32 off) shows mean per-step loss difference < 0.005 against UnslothTrainer on an RTX 4070 SUPER over 30 steps of OpenMathReasoning-mini, for qwen3_5-0.8b through these kernels as well as qwen3-0.6b / gemma3-270m controls.
  • Real-model runs: Qwen3.5-0.8B end-to-end (training, save, cold-process reload with loss parity); Kimi-Linear-48B-A3B 4-bit QLoRA (binding sweep connects, vectorized kernels under quantized projections, finite decreasing loss at 29.2 GB peak).
  • Atomic accumulation makes low-order gradient bits nondeterministic across runs, matching Metal reduction-order behavior elsewhere in the backend.

🤖 Generated with Claude Code

@gemini-code-assist

Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@Lyxot Lyxot force-pushed the fix/mlx-gated-delta-stale-bindings branch from 0d77b5d to fa92b22 Compare June 11, 2026 05:57
@Lyxot Lyxot changed the title MLX: fix gated-delta (Qwen3.5) training and add fused backward kernels (~1000 tok/s) MLX: fix gated-delta (Qwen3.5) training and add fused backward kernels Jun 11, 2026
@Lyxot Lyxot force-pushed the fix/mlx-gated-delta-stale-bindings branch 3 times, most recently from 6b5c0e6 to 72f980b Compare June 12, 2026 08:40
Lyxot added 3 commits June 12, 2026 16:52
patch_gated_delta only rebound mlx_lm.models.gated_delta, but every
consumer (mlx_lm qwen3_5 / qwen3_next / kimi_linear, mlx_vlm qwen3_5)
binds gated_delta_update via from-import at module load, so the
memory-efficient VJP never reached any call site and training ran the
raw sequential ops loop. Sweep already-imported consumer modules and
rebind by identity; warn once about foreign implementations.
mx .at[:, t].add scatters read the update tensor with a wrong batch
stride on the Metal stream, silently corrupting d_q/d_k/d_v/d_g/d_beta
for every batch element except index 0. Collect per-step grads in lists
and stack instead; verified against plain autodiff at B in {2,3,4}.

The underlying Metal scatter indexing bug is fixed upstream in
ml-explore/mlx#3483 but is not in any release yet (latest: 0.31.2).
The list+stack form could revert to the scatter once a fixed mlx is
the minimum supported version, though it is also the cheaper op here.
The patch trigger was gated on 'qwen3_5' in model_type, but qwen3_next
and kimi_linear share the same gated_delta_update call sites and hit the
identical slow path. Detect gated-delta layers structurally (class name
contains 'delta' + the A_log/dt_bias pair; Mamba/SSM mixers carry the
parameters but never the name) and trigger on that instead.
@Lyxot Lyxot force-pushed the fix/mlx-gated-delta-stale-bindings branch from 72f980b to 17f2a28 Compare June 12, 2026 08:56
Lyxot added 5 commits June 12, 2026 17:23
Forward reuses mlx-lm's fused gated_delta_kernel per 64-step chunk
(boundary states fall out as chunk outputs); backward replays chunk
states into scratch (K1) then reverse-scans with atomic gradient
accumulation (K2), translating the validated ops BPTT into Metal.
Graph per layer drops from ~25k unrolled ops to ~100 kernel launches:
Qwen3.5-0.8B LoRA seq-2048 goes 151 -> ~900 tok/s at 4.2 GB peak (was
12.8 GB). Scalar gating, no mask, Dk%32==0; anything else falls back to
the ops VJP automatically. Gradient parity vs plain autodiff: <=1e-7
fp32 across GQA, multi-chunk, uneven tails, and B in {2,3}.
Extends K1/K2 to kimi_linear's per-column gating (g: [B, T, Hv, Dk])
via the same template-flag pattern mlx-lm uses for its forward kernel.
The vectorized d_g is per-column atomic accumulation (simpler than the
scalar simdgroup reduction); compute_g stays outside the kernel so
d_A_log/d_dt_bias flow through autodiff unchanged. With this, every
gated-delta architecture in mlx-lm trains on the kernel path; the ops
VJP remains only as a guard fallback (no Metal, masked calls, unaligned
head dims, or UNSLOTH_DISABLE_GD_KERNEL_VJP=1 — pass compile=False with
it, mx.compile cannot fuse the unrolled ops graph).
Gradient parity vs plain autodiff: <=2.4e-7 fp32 incl. vectorized+GQA;
scalar path regression-checked bit-identical on Qwen3.5-0.8B.
Merges the binding-sweep/detection suite and the gradient-parity suite
into tests/test_mlx_gated_delta_vjp.py. The torch shim now installs only
when real mlx is absent (CI), parity tests skip without Metal, and the
module-restore fixture keeps the file order-independent against
shim-based sibling files. Folds source-patched/foreign-untouched asserts
into the sweep test, merges detection cases, and parametrizes parity
over both VJP implementations.
After rebasing onto unslothai#738: drop its legacy-layout branch (the binding
sweep already rebinds mlx-vlm 0.4-0.5 from-imports), route its
mlx-vlm >= 0.6 training branch through the fused-kernel dispatch (ops
VJP under whole-step mx.compile would reintroduce the compile_fuse
wedge the kernels eliminated), and teach the sweep to recognize the
sibling patch instead of warning about a foreign implementation.
@Lyxot Lyxot force-pushed the fix/mlx-gated-delta-stale-bindings branch from 17f2a28 to f3aa759 Compare June 12, 2026 09:30
Lyxot added 3 commits June 12, 2026 20:01
…iple

The backward kernel's shared-memory pre-reduction has row 0 read every
threadgroup row slot; a partial trailing threadgroup (Dv % 4 != 0)
would read uninitialized slots. No current architecture hits this
(Dv is 64/128 everywhere), so route such shapes to the ops VJP.
@Lyxot Lyxot marked this pull request as ready for review June 12, 2026 15:00
Copilot AI review requested due to automatic review settings June 12, 2026 15:00

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds structural detection and improved patching for gated-delta (GDN) layers so training routes through memory-/compute-efficient VJP implementations, including a new fused Metal-kernel backward path and accompanying tests.

Changes:

  • Trigger patch_gated_delta() based on structural detection (model_has_gated_delta_layers) instead of model_type heuristics.
  • Fix incorrect per-timestep gradient scattering in the ops-based custom VJP by stacking per-step gradients.
  • Expand patch_gated_delta() to rebind stale from ... import gated_delta_update consumer bindings and add fused Metal-kernel VJP path + tests.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

File Description
unsloth_zoo/mlx/trainer.py Switches patch trigger to structural gated-delta detection and refines patch ordering.
unsloth_zoo/mlx/compile.py Adds structural gated-delta layer detection helpers used by trainer.
unsloth_zoo/gated_delta_vjp.py Fixes VJP gradient accumulation, adds consumer rebinding sweep, and introduces fused Metal-kernel training VJP.
tests/test_mlx_gated_delta_vjp.py Adds tests for structural detection, rebinding sweep, and gradient parity (incl. Metal-only kernel path).

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread unsloth_zoo/gated_delta_vjp.py Outdated
Comment thread unsloth_zoo/gated_delta_vjp.py
Comment thread unsloth_zoo/gated_delta_vjp.py Outdated
@Lyxot

Lyxot commented Jun 12, 2026

Copy link
Copy Markdown
Contributor Author

@codex review

Lyxot added 2 commits June 12, 2026 23:17
Review feedback: tolerate a pre-existing patch flag without a recorded
original (skip the identity sweep instead of AttributeError), raise a
clear ValueError when gated_delta_kernel_efficient is called outside
kernel support instead of failing on a None kernel, and parameterize
the warned-foreign set annotation.
…x/mlx-gated-delta-stale-bindings

# Conflicts:
#	unsloth_zoo/mlx/trainer.py

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: cf29caf8bd

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread unsloth_zoo/gated_delta_vjp.py
…tale-bindings

# Conflicts:
#	unsloth_zoo/gated_delta_vjp.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants