Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions unsloth_zoo/gated_delta_vjp.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,17 @@ def _chunked_vjp(primals, cotangents, outputs):
# Backpropagation through time (BPTT): `d_state` is the cotangent
# w.r.t. the RETURNED state at the current step (= input to step t+1).
# It starts at d_state_out; walking t backwards, each step replaces it
# with the recurrence-derived gradient plus the mask passthrough.
d_q = mx.zeros_like(q_c)
d_k = mx.zeros_like(k_c)
d_v = mx.zeros_like(v_c)
d_g = mx.zeros_like(g_c)
d_beta = mx.zeros_like(beta_c)
# Per-step grads are collected in lists and stacked after the loop
# instead of being scatter-added into preallocated zero buffers, for
# two reasons: each t is produced exactly once, so no accumulation is
# needed; and mx `.at[:, t].add` scatter-add reads the update tensor
# with a wrong batch stride (wrong grads for every batch row past the
# first; verified against plain autodiff on mlx 0.31).
# Fixed upstream in ml-explore/mlx#3483, not yet in any release.
d_q_steps = []
d_k_steps = []
d_v_steps = []
d_g_steps = []
d_beta_steps = []
d_state = d_state_out

for t in range(chunk_T - 1, -1, -1):
Expand Down Expand Up @@ -171,7 +177,7 @@ def _chunked_vjp(primals, cotangents, outputs):
+ dy_t[..., None].astype(mx.float32) * q_t[..., None, :].astype(mx.float32)
)
d_q_t = (dy_t[..., None].astype(mx.float32) * state_new).sum(axis=-2)
d_q = d_q.at[:, t].add(d_q_t.astype(d_q.dtype))
d_q_steps.append(d_q_t.astype(q_c.dtype))

# state_new = state_decayed + k[..., None, :] * delta[..., None]
d_kd = d_state_new
Expand Down Expand Up @@ -203,14 +209,21 @@ def _chunked_vjp(primals, cotangents, outputs):
d_g_t = d_decay

d_k_t = d_k_t_from_update + d_k_t_from_kv
d_k = d_k.at[:, t].add(d_k_t.astype(d_k.dtype))
d_v = d_v.at[:, t].add(d_v_t.astype(d_v.dtype))
d_g = d_g.at[:, t].add(d_g_t.astype(d_g.dtype))
d_beta = d_beta.at[:, t].add(d_beta_t.astype(d_beta.dtype))
d_k_steps.append(d_k_t.astype(k_c.dtype))
d_v_steps.append(d_v_t.astype(v_c.dtype))
d_g_steps.append(d_g_t.astype(g_c.dtype))
d_beta_steps.append(d_beta_t.astype(beta_c.dtype))

# d_state_prev = recurrence-derived gradient + mask passthrough.
d_state = d_state_prev_via_recurrence + d_state_prev_passthrough

for steps in (d_q_steps, d_k_steps, d_v_steps, d_g_steps, d_beta_steps):
steps.reverse()
d_q = mx.stack(d_q_steps, axis=1)
d_k = mx.stack(d_k_steps, axis=1)
d_v = mx.stack(d_v_steps, axis=1)
d_g = mx.stack(d_g_steps, axis=1)
d_beta = mx.stack(d_beta_steps, axis=1)
d_mask = mx.zeros_like(mask_chunk) if mask_chunk is not None else None
return d_q, d_k, d_v, d_g, d_beta, d_state, d_mask

Expand Down
Loading