Masked fused CE loss #616
Code Review
This pull request introduces a pre-filtering mechanism for chunks in the fused cross-entropy loss to optimize processing, along with a new suffix masking feature in the test suite. The review identified a critical issue where the test parameter lists were mismatched, causing the new test case to be skipped. Additionally, the implementation was flagged for potential performance bottlenecks due to excessive GPU-CPU synchronization within the chunk loop, and a logic error regarding the compilation status flag when processing empty chunks was noted.
    lm_head_requires_grads = [False, False, True, True, False,],
    lm_bias_requires_grads = [False, True, False, True, False,],
The lists lm_head_requires_grads and lm_bias_requires_grads only have 4 elements, while the other parameter lists (like bszs, qlens, suffix_masks) have 5. Because of the zip in the test loop, the 5th test case (which specifically tests the new suffix_mask feature) will never be executed.
-    lm_head_requires_grads = [False, False, True, True, False,],
-    lm_bias_requires_grads = [False, True, False, True, False,],
+    lm_head_requires_grads = [False, False, True, True, True,],
+    lm_bias_requires_grads = [False, True, False, True, True,],
    chunks = []
    for grad_inputs_j, hidden_states_j, labels_j in zip(__grad_inputs, __shift_states, __shift_labels):
        if bool((labels_j != -100).any().item()):
            chunks.append((grad_inputs_j, hidden_states_j, labels_j,))
        else:
            grad_inputs_j.zero_()
        pass
The current implementation performs a GPU-CPU synchronization (.item()) for every chunk in the loop. If n_chunks is large, this can lead to significant performance degradation due to multiple stalls. It is more efficient to compute the mask status for all chunks in a single operation and perform one synchronization.
-    chunks = []
-    for grad_inputs_j, hidden_states_j, labels_j in zip(__grad_inputs, __shift_states, __shift_labels):
-        if bool((labels_j != -100).any().item()):
-            chunks.append((grad_inputs_j, hidden_states_j, labels_j,))
-        else:
-            grad_inputs_j.zero_()
-        pass
+    chunks = []
+    if __shift_labels:
+        # Check all chunks for non-masked labels in a single sync to avoid multiple GPU-CPU stalls
+        has_labels = torch.stack([(l != -100).any() for l in __shift_labels]).cpu().tolist()
+        for i, (grad_inputs_j, hidden_states_j, labels_j) in enumerate(zip(__grad_inputs, __shift_states, __shift_labels)):
+            if has_labels[i]:
+                chunks.append((grad_inputs_j, hidden_states_j, labels_j,))
+            else:
+                grad_inputs_j.zero_()
    accumulate_chunk is not uncompiled_accumulate_chunk:
Setting _FUSED_CE_COMPILE_SUPPORTED = True when chunks is empty is premature. Since no chunks were processed, the compiled version of accumulate_chunk hasn't actually been executed/probed. If a subsequent call with non-empty chunks fails during compilation or execution, it will bypass the fallback logic and potentially crash. It's safer to leave the status as None so the next call with data can perform the probe.
    accumulate_chunk is not uncompiled_accumulate_chunk:
    except StopIteration:
        pass
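The tri-state handling recommended in this comment can be sketched in plain Python. This is an illustration only, not the PR's code: `CompileProbe`, `compiled_fn`, and `eager_fn` are hypothetical names standing in for the compile-status flag, the compiled `accumulate_chunk`, and its uncompiled fallback. The point it shows is that the status moves from `None` to `True` only after a real call has succeeded, so an empty chunk list leaves it unprobed.

```python
class CompileProbe:
    """Tracks whether a compiled function is known to work.

    supported is None (not probed yet), True (probe succeeded) or
    False (probe failed, always use the eager fallback).
    """

    def __init__(self, compiled_fn, eager_fn):
        self.compiled_fn = compiled_fn
        self.eager_fn = eager_fn
        self.supported = None

    def run(self, chunks):
        results = []
        for chunk in chunks:
            if self.supported is False:
                # A previous probe failed: skip the compiled path entirely.
                results.append(self.eager_fn(chunk))
                continue
            try:
                results.append(self.compiled_fn(chunk))
                self.supported = True  # set only after a real call succeeded
            except Exception:
                self.supported = False
                results.append(self.eager_fn(chunk))
        return results
```

Calling `run([])` returns an empty list and leaves `supported` as `None`, so the next call that carries data still performs the probe and can fall back if the compiled path fails.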
Fixes: unslothai/unsloth#5230
Respect masking when calculating the chunked loss: only accumulate a chunk if it contains non-masked tokens.
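The idea of skipping fully masked chunks can be illustrated with a small, self-contained sketch. It assumes PyTorch is installed and operates on precomputed logits, whereas the fused kernel in this PR works on hidden states and the LM head; `chunked_masked_ce` is a hypothetical helper written for this illustration, not a function from the repository. It checks all chunks for non-masked labels with a single host transfer, skips chunks whose labels are all `-100`, and normalises by the number of non-masked tokens.

```python
import torch
import torch.nn.functional as F


def chunked_masked_ce(logits, labels, n_chunks, ignore_index=-100):
    """Mean cross-entropy over non-masked tokens, accumulated chunk by chunk."""
    logit_chunks = torch.chunk(logits, n_chunks, dim=0)
    label_chunks = torch.chunk(labels, n_chunks, dim=0)
    # One transfer to the host for all chunks instead of one .item() per chunk.
    has_labels = torch.stack(
        [(lb != ignore_index).any() for lb in label_chunks]
    ).tolist()
    total = logits.new_zeros(())
    for keep, lg, lb in zip(has_labels, logit_chunks, label_chunks):
        if not keep:
            continue  # fully masked chunk: contributes nothing to the loss
        total = total + F.cross_entropy(
            lg, lb, ignore_index=ignore_index, reduction="sum"
        )
    # Guard against division by zero when every label is masked.
    n_valid = (labels != ignore_index).sum().clamp(min=1)
    return total / n_valid


torch.manual_seed(0)
logits = torch.randn(12, 7)
labels = torch.randint(0, 7, (12,))
labels[6:] = -100  # suffix mask: the last two of four chunks are fully masked
loss = chunked_masked_ce(logits, labels, n_chunks=4)
reference = F.cross_entropy(logits, labels, ignore_index=-100)
print(torch.allclose(loss, reference))
```

The comparison against the unchunked `F.cross_entropy` call is a quick sanity check that skipping masked chunks does not change the loss value.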