fix(mlx): wrong gated-delta grads for batch rows past the first #776
Pull request overview
This PR addresses incorrect per-step gradient accumulation in _chunked_vjp by avoiding MLX .at[:, t].add(...) scatter updates and instead collecting step gradients in Python lists and stacking at the end.
Changes:
- Replace .at[:, t].add(...) gradient accumulation with per-step list collection.
- Stack collected per-step gradients into (B, T, ...) tensors after the backward loop.
- Add an inline comment explaining the MLX scatter bug and linking the upstream fix.
Code Review
This pull request modifies unsloth_zoo/gated_delta_vjp.py to collect per-step gradients in lists and stack them after the loop, rather than using in-place scatter additions (.at[:, t].add). This change works around an upstream MLX bug where scatter additions produce incorrect gradients for batch rows past the first. I have no feedback to provide as there are no review comments.
308051c to bf0744e
MLX 0.31.2 scatter-add can read the update tensor with a wrong batch stride for mx .at[:, t].add, silently corrupting d_q/d_k/d_v/d_g/d_beta for every batch element except index 0. Collect per-step grads in lists and stack instead. Verified against a plain autodiff recurrence: the old scatter path matches on mlx 0.31.1, diverges on mlx 0.31.2 for B>1, and the list+stack path matches on both. The underlying MLX regression is fixed upstream in ml-explore/mlx#3483, but our release venv is still on mlx 0.31.2.
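The workaround described in this commit message can be sketched roughly as follows. This is an illustration only, not the code in unsloth_zoo/gated_delta_vjp.py: NumPy stands in for MLX so the snippet runs without an MLX install, and the names `stack_step_grads`, `accumulate_step_grads`, and `step_grad` are hypothetical. In NumPy, `out[:, t] += g` plays the role that `.at[:, t].add(g)` plays in the old MLX path.

```python
import numpy as np

def stack_step_grads(step_grad, T):
    """List+stack pattern: collect one (B, D) gradient per step, then stack."""
    collected = []
    for t in reversed(range(T)):           # reverse-time loop, as in a backward pass
        collected.append(step_grad(t))
    collected.reverse()                    # restore forward time order before stacking
    return np.stack(collected, axis=1)     # shape (B, T, D)

def accumulate_step_grads(step_grad, B, T, D):
    """Reference pattern: write each step's gradient into a preallocated buffer."""
    out = np.zeros((B, T, D))
    for t in reversed(range(T)):
        out[:, t] += step_grad(t)          # NumPy stand-in for .at[:, t].add(...)
    return out

# Hypothetical per-step gradients, precomputed so both patterns can be compared.
B, T, D = 3, 4, 2
table = np.random.default_rng(0).standard_normal((B, T, D))
stacked = stack_step_grads(lambda t: table[:, t], T)
accumulated = accumulate_step_grads(lambda t: table[:, t], B, T, D)
```

Both patterns should produce the same `(B, T, D)` array when the backend is correct, which is why the list+stack form can replace the scatter form without changing the math. Because the loop runs backward in time, the list has to be reversed (or built in forward order) before stacking; skipping that step yields time-reversed gradients.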
bf0744e to 2c748ab
* test(mlx): batched gradient-parity regression for gated-delta VJP

  The existing test_mlx_gated_delta.py only checks finite gradients at B=1, so it passes on both the pre-#776 scatter-add code and the fixed list+stack code. Add a regression that pins the custom VJP against plain mx.value_and_grad of the unrolled recurrence at B>1, across single/multi-chunk, GQA, scalar and vectorized gating, and masked inputs, plus an identical-rows check that the per-row gradients match.

  On Metal this fails on the pre-#776 wrong-batch-stride scatter-add (rows past index 0 corrupt); on any backend a misordered stack or a reintroduced scatter-add also fails it.

  Requires a real mlx install and is skipped under the torch-shim CI.

* test(mlx): address review on gated-delta batch-grad test

  - Reject the torch shim (tests/mlx_simulation) so these real-MLX tests skip cleanly instead of running under the shim's custom-function path.
  - Import gated_delta_vjp outside the availability guard so an import-time regression in the module under test fails loudly instead of skipping.
  - Remove unused mlx.nn import.
  - _rel_l2: keep the reduction in MLX and sync once instead of per-tensor.
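The two checks named in these commit messages, a relative L2 comparison against a reference gradient and an identical-rows check, might look something like the sketch below. It is written in NumPy so it runs without MLX, and it is not the test code from this PR: the function bodies, the name `rows_identical`, and the tolerances are assumptions, and `rel_l2` only borrows its name from the `_rel_l2` helper mentioned above.

```python
import numpy as np

def rel_l2(actual, reference, eps=1e-12):
    """Relative L2 error over the whole tensor: ||actual - reference|| / ||reference||."""
    actual = np.asarray(actual, dtype=np.float64)
    reference = np.asarray(reference, dtype=np.float64)
    return float(np.linalg.norm(actual - reference) / (np.linalg.norm(reference) + eps))

def rows_identical(grad, atol=1e-6):
    """True when every batch row of grad matches batch row 0."""
    grad = np.asarray(grad)
    return bool(np.allclose(grad, grad[:1], atol=atol))

# Identical inputs in every batch row should give identical per-row gradients.
row = np.random.default_rng(0).standard_normal((1, 4, 2))
reference = np.tile(row, (3, 1, 1))        # shape (B=3, T=4, D=2)

# Simulate a result where only batch row 0 is correct.
corrupted = reference.copy()
corrupted[1:] *= 0.5

good_error = rel_l2(reference, reference)
bad_error = rel_l2(corrupted, reference)
```

A B=1 input cannot distinguish the two cases, since row 0 is unaffected; the comparison only becomes meaningful once the batch has more than one row.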
What's changed
This is the narrow MLX gated-delta gradient workaround split out from #754.
MLX 0.31.2 scatter-add can read the update tensor with a wrong batch stride for mx .at[:, t].add, silently corrupting d_q/d_k/d_v/d_g/d_beta for every batch element except index 0. This PR avoids that path by collecting per-step grads in lists and stacking them after the reverse-time loop.

The MLX core issue only affects this gated-delta VJP path when B > 1. B = 1 is included in validation and performance checks as a control case.

The underlying MLX regression is fixed upstream in ml-explore/mlx#3483, but that fix is not included in the latest MLX release, v0.31.2.

Validation
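Checked against a plain autodiff recurrence:
- The old scatter path matches on mlx 0.31.1.
- The old scatter path diverges on mlx 0.31.2 for B > 1.
- The list+stack path matches on both.

Local focused test:
python -m pytest tests/test_mlx_gated_delta.py -q -> 3 passed

Focused VJP performance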
Exact-import retest comparing the actual parent implementation (d5224f48, old scatter) against the actual PR branch (2c748ab1, fixed list+stack). Workload: full backward/VJP, T=64, Hk=Hv=2, Dk=Dv=64, no mask. Table reports median-of-round-medians.

(Results table comparing the old .at.add path with the list+stack path: values not recoverable.)

Performance interpretation: the PR is a correctness workaround for the mlx 0.31.2 scatter regression. In this focused VJP microbenchmark, the fixed path is also faster than the old scatter implementation, but this table should not be read as an end-to-end training throughput claim.
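For readers unfamiliar with the median-of-round-medians statistic, a minimal sketch of how such a number could be computed is shown below. It is not the benchmark script used for this PR: the helper names `time_rounds` and `median_of_round_medians` and the round and iteration counts are assumptions chosen for illustration.

```python
import statistics
import time

def time_rounds(fn, n_rounds=5, iters_per_round=10):
    """Time fn in several rounds; returns one list of per-call seconds per round."""
    rounds = []
    for _ in range(n_rounds):
        samples = []
        for _ in range(iters_per_round):
            start = time.perf_counter()
            fn()   # for a lazy framework such as MLX, fn should force evaluation (e.g. mx.eval)
            samples.append(time.perf_counter() - start)
        rounds.append(samples)
    return rounds

def median_of_round_medians(rounds):
    """Take the median within each round, then the median across rounds."""
    return statistics.median([statistics.median(samples) for samples in rounds])

# Fixed timings: one outlier in the first round does not move the result.
example_rounds = [[1.0, 2.0, 100.0], [3.0, 4.0, 5.0], [10.0, 11.0, 12.0]]
summary = median_of_round_medians(example_rounds)
```

Taking medians at both levels keeps a single slow call, or a single slow round, from dominating the reported figure, which suits short microbenchmarks like the one described here.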