Skip to content

fix(mlx): wrong gated-delta grads for batch rows past the first#776

Merged
danielhanchen merged 1 commit into
unslothai:mainfrom
Lyxot:fix/mlx-gated-delta-batch-grad
Jun 15, 2026
Merged

fix(mlx): wrong gated-delta grads for batch rows past the first#776
danielhanchen merged 1 commit into
unslothai:mainfrom
Lyxot:fix/mlx-gated-delta-batch-grad

Conversation

@Lyxot

@Lyxot Lyxot commented Jun 15, 2026

Copy link
Copy Markdown
Contributor

What's changed

This is the narrow MLX gated-delta gradient workaround split out from #754.

MLX 0.31.2 scatter-add can read the update tensor with a wrong batch stride for mx .at[:, t].add, silently corrupting d_q/d_k/d_v/d_g/d_beta for every batch element except index 0. This PR avoids that path by collecting per-step grads in lists and stacking them after the reverse-time loop.

The MLX core issue only affects this gated-delta VJP path when B > 1. B = 1 is included in validation and performance checks as a control case.

The underlying MLX regression is fixed upstream in ml-explore/mlx#3483, but it's not included in latest mlx release v0.31.2.

Validation

Checked against a plain autodiff recurrence:

  • old scatter path matches on mlx 0.31.1
  • old scatter path diverges on mlx 0.31.2 for B > 1
  • list+stack path matches on both versions

Local focused test:

  • python -m pytest tests/test_mlx_gated_delta.py -q -> 3 passed

Focused VJP performance

Exact-import retest comparing the actual parent implementation (d5224f48, old scatter) against the actual PR branch (2c748ab1, fixed list+stack). Workload: full backward/VJP, T=64, Hk=Hv=2, Dk=Dv=64, no mask. Table reports median-of-round-medians.

MLX B old .at.add fixed branch fixed delta
0.31.2 1 12.93 ms 12.19 ms 5.7% faster
0.31.2 2 13.34 ms 12.33 ms 7.6% faster
0.31.2 4 13.15 ms 12.46 ms 5.3% faster
0.31.2 8 13.97 ms 13.10 ms 6.3% faster
0.31.1 1 13.12 ms 11.55 ms 12.0% faster
0.31.1 2 13.08 ms 11.62 ms 11.2% faster
0.31.1 4 13.23 ms 11.75 ms 11.2% faster
0.31.1 8 13.60 ms 12.36 ms 9.1% faster

Performance interpretation: the PR is a correctness workaround for the mlx 0.31.2 scatter regression. In this focused VJP microbenchmark, the fixed path is also faster than the old scatter implementation, but this table should not be read as an end-to-end training throughput claim.

Copilot AI review requested due to automatic review settings June 15, 2026 11:05

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

This PR addresses incorrect per-step gradient accumulation in _chunked_vjp by avoiding MLX .at[:, t].add(...) scatter updates and instead collecting step gradients in Python lists and stacking at the end.

Changes:

  • Replace .at[:, t].add(...) gradient accumulation with per-step list collection.
  • Stack collected per-step gradients into (B, T, ...) tensors after the backward loop.
  • Add an inline comment explaining the MLX scatter bug and linking the upstream fix.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread unsloth_zoo/gated_delta_vjp.py Outdated
Comment thread unsloth_zoo/gated_delta_vjp.py Outdated

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request modifies unsloth_zoo/gated_delta_vjp.py to collect per-step gradients in lists and stack them after the loop, rather than using in-place scatter additions (.at[:, t].add). This change works around an upstream MLX bug where scatter additions produce incorrect gradients for batch rows past the first. I have no feedback to provide as there are no review comments.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

@Lyxot Lyxot force-pushed the fix/mlx-gated-delta-batch-grad branch from 308051c to bf0744e Compare June 15, 2026 11:12
MLX 0.31.2 scatter-add can read the update tensor with a wrong batch stride for mx .at[:, t].add, silently corrupting d_q/d_k/d_v/d_g/d_beta for every batch element except index 0. Collect per-step grads in lists and stack instead.

Verified against a plain autodiff recurrence: the old scatter path matches on mlx 0.31.1, diverges on mlx 0.31.2 for B>1, and the list+stack path matches on both. The underlying MLX regression is fixed upstream in ml-explore/mlx#3483, but our release venv is still on mlx 0.31.2.
@Lyxot Lyxot force-pushed the fix/mlx-gated-delta-batch-grad branch from bf0744e to 2c748ab Compare June 15, 2026 11:12
@danielhanchen danielhanchen merged commit e71e5ba into unslothai:main Jun 15, 2026
11 checks passed
danielhanchen added a commit that referenced this pull request Jun 16, 2026
* test(mlx): batched gradient-parity regression for gated-delta VJP

The existing test_mlx_gated_delta.py only checks finite gradients at B=1, so
it passes on both the pre-#776 scatter-add code and the fixed list+stack code.
Add a regression that pins the custom VJP against plain mx.value_and_grad of
the unrolled recurrence at B>1, across single/multi-chunk, GQA, scalar and
vectorized gating, and masked inputs, plus an identical-rows check that the
per-row gradients match. On Metal this fails on the pre-#776 wrong-batch-stride
scatter-add (rows past index 0 corrupt); on any backend a misordered stack or a
reintroduced scatter-add also fails it. Requires a real mlx install and is
skipped under the torch-shim CI.

* test(mlx): address review on gated-delta batch-grad test

- Reject the torch shim (tests/mlx_simulation) so these real-MLX tests skip
  cleanly instead of running under the shim's custom-function path.
- Import gated_delta_vjp outside the availability guard so an import-time
  regression in the module under test fails loudly instead of skipping.
- Remove unused mlx.nn import.
- _rel_l2: keep the reduction in MLX and sync once instead of per-tensor.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants