Masked fused CE loss #616

Open
Datta0 wants to merge 2 commits into unslothai:main from Datta0:masked_chunked_loss

Datta0 (Collaborator) commented May 1, 2026

Fixes: unslothai/unsloth#5230

Respect masking when calculating the chunked loss. Only accumulate a chunk if it contains non-masked tokens.
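
To make the intent concrete, here is a minimal pure-Python sketch of a chunked cross-entropy that skips chunks whose labels are all -100. It is an illustration only, not the PR's kernel: the names `IGNORE_INDEX`, `token_nll`, and `chunked_masked_ce` are hypothetical, and the real implementation works on PyTorch tensors and also accumulates gradients.

```python
import math

IGNORE_INDEX = -100  # label value that marks a masked token

def token_nll(logits, label):
    """Negative log-likelihood of `label` under softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[label]

def chunked_masked_ce(logits, labels, n_chunks):
    """Mean CE over unmasked tokens, computed chunk by chunk.

    Returns (mean_loss, number_of_chunks_actually_processed).
    """
    size = max(1, math.ceil(len(labels) / n_chunks))
    total, n_tokens, processed = 0.0, 0, 0
    for start in range(0, len(labels), size):
        chunk_labels = labels[start:start + size]
        # Skip the chunk entirely when every label in it is masked.
        if all(label == IGNORE_INDEX for label in chunk_labels):
            continue
        processed += 1
        for offset, label in enumerate(chunk_labels):
            if label != IGNORE_INDEX:
                total += token_nll(logits[start + offset], label)
                n_tokens += 1
    mean = total / n_tokens if n_tokens else 0.0
    return mean, processed

# Six tokens, vocabulary of 3; only the last two tokens carry labels,
# similar to a "last assistant turn only" mask.
logits = [[0.1, 0.2, 0.3]] * 4 + [[2.0, 0.0, 0.0], [0.0, 3.0, 0.0]]
labels = [-100, -100, -100, -100, 0, 1]
loss, processed = chunked_masked_ce(logits, labels, n_chunks=3)
unchunked_loss, _ = chunked_masked_ce(logits, labels, n_chunks=1)
```

In this sketch the first two chunks are skipped and only the final chunk contributes, so the chunked loss matches the unchunked one.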

gemini-code-assist Bot (Contributor) left a comment

Code Review

This pull request introduces a pre-filtering mechanism for chunks in the fused cross-entropy loss to optimize processing, along with a new suffix masking feature in the test suite. The review identified a critical issue where the test parameter lists were mismatched, causing the new test case to be skipped. Additionally, the implementation was flagged for potential performance bottlenecks due to excessive GPU-CPU synchronization within the chunk loop, and a logic error regarding the compilation status flag when processing empty chunks was noted.

Comment on lines +35 to +36
lm_head_requires_grads = [False, False, True, True, False,],
lm_bias_requires_grads = [False, True, False, True, False,],

high

The lists lm_head_requires_grads and lm_bias_requires_grads only have 4 elements, while the other parameter lists (like bszs, qlens, suffix_masks) have 5. Because of the zip in the test loop, the 5th test case (which specifically tests the new suffix_mask feature) will never be executed.

Suggested change
- lm_head_requires_grads = [False, False, True, True, False,],
- lm_bias_requires_grads = [False, True, False, True, False,],
+ lm_head_requires_grads = [False, False, True, True, True,],
+ lm_bias_requires_grads = [False, True, False, True, True,],

Comment on lines +328 to +334
chunks = []
for grad_inputs_j, hidden_states_j, labels_j in zip(__grad_inputs, __shift_states, __shift_labels):
    if bool((labels_j != -100).any().item()):
        chunks.append((grad_inputs_j, hidden_states_j, labels_j,))
    else:
        grad_inputs_j.zero_()
    pass

medium

The current implementation performs a GPU-CPU synchronization (.item()) for every chunk in the loop. If n_chunks is large, this can lead to significant performance degradation due to multiple stalls. It is more efficient to compute the mask status for all chunks in a single operation and perform one synchronization.

Suggested change
- chunks = []
- for grad_inputs_j, hidden_states_j, labels_j in zip(__grad_inputs, __shift_states, __shift_labels):
-     if bool((labels_j != -100).any().item()):
-         chunks.append((grad_inputs_j, hidden_states_j, labels_j,))
-     else:
-         grad_inputs_j.zero_()
-     pass
+ chunks = []
+ if __shift_labels:
+     # Check all chunks for non-masked labels in a single sync to avoid multiple GPU-CPU stalls
+     has_labels = torch.stack([(l != -100).any() for l in __shift_labels]).cpu().tolist()
+     for i, (grad_inputs_j, hidden_states_j, labels_j) in enumerate(zip(__grad_inputs, __shift_states, __shift_labels)):
+         if has_labels[i]:
+             chunks.append((grad_inputs_j, hidden_states_j, labels_j,))
+         else:
+             grad_inputs_j.zero_()

Comment on lines 355 to 356
accumulate_chunk is not uncompiled_accumulate_chunk:

medium

Setting _FUSED_CE_COMPILE_SUPPORTED = True when chunks is empty is premature. Since no chunks were processed, the compiled version of accumulate_chunk hasn't actually been executed/probed. If a subsequent call with non-empty chunks fails during compilation or execution, it will bypass the fallback logic and potentially crash. It's safer to leave the status as None so the next call with data can perform the probe.

Suggested change
- accumulate_chunk is not uncompiled_accumulate_chunk:
+ except StopIteration:
+     pass
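
The tri-state idea behind this comment (leave the flag as `None` until a real chunk has exercised the compiled path) can be sketched in plain Python. This is a simplified illustration under stated assumptions, not the PR's code: `_COMPILE_SUPPORTED`, `run_chunks`, `broken_compiled`, and `eager` are hypothetical names, and `RuntimeError` stands in for whatever a failed compilation would raise.

```python
_COMPILE_SUPPORTED = None  # None = not probed yet; True/False = probe result

def run_chunks(chunks, compiled_fn, fallback_fn):
    """Process chunks, probing the compiled path only when there is real work."""
    global _COMPILE_SUPPORTED
    results = []
    for chunk in chunks:
        if _COMPILE_SUPPORTED is False:
            # A previous probe failed: go straight to the fallback.
            results.append(fallback_fn(chunk))
            continue
        try:
            results.append(compiled_fn(chunk))
            _COMPILE_SUPPORTED = True   # set only after a successful call
        except RuntimeError:
            _COMPILE_SUPPORTED = False  # remember the failure and fall back
            results.append(fallback_fn(chunk))
    return results

def broken_compiled(chunk):
    # Stand-in for a compiled function that fails on first use.
    raise RuntimeError("compilation failed")

def eager(chunk):
    return sum(chunk)

# An all-masked batch yields no chunks, so the flag must stay unprobed.
empty_result = run_chunks([], broken_compiled, eager)
flag_after_empty = _COMPILE_SUPPORTED

# The next batch with real chunks performs the probe and falls back safely.
real_result = run_chunks([[1, 2], [3]], broken_compiled, eager)
flag_after_real = _COMPILE_SUPPORTED
```

Because the empty call leaves the flag at `None`, the later call still gets a chance to detect the failure and switch to the fallback.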

Datta0 marked this pull request as ready for review May 21, 2026 06:20
Datta0 requested a review from danielhanchen as a code owner May 21, 2026 06:20

Development

Successfully merging this pull request may close these issues.

[Bug] Gemma 4 26B-A4B-it: fused-loss kernel produces zero gradients with last-assistant-only label mask