fix(loss): support THD/packed layout in FusedLinearCrossEntropy #2615
Open
akoumpa wants to merge 4 commits into
FusedLinearCrossEntropy passed hidden_states and labels straight into cut_cross_entropy, which asserts `hidden_states.shape[:-1] == labels.shape`. That holds for the dense [B, S, H] / [B, S] case but breaks for the THD / packed-sequence layout (the model emits hidden states as [1, T, H] while labels stay [B, S]) and for context-parallel shards, raising an AssertionError in cce.py before the first step. One example is the qwen3_moe_30b_te_chat_thd recipe (AM-460), which uses the THD collater + CP for TE context parallelism.

Fix: flatten hidden_states to [N, H] and labels to [N] before the call, mirroring MaskedCrossEntropy. This satisfies the assert for THD/CP and is a no-op for the dense case (the CE sum is invariant to token ordering).

Verified on 8xH100 (cw-dfw) with the qwen3_moe_30b THD recipe: the AssertionError is gone and step 0 produces a loss of 1.2350, close to the MaskedCrossEntropy baseline's 1.2705.

Adds a CPU unit test that reproduces the THD shape mismatch and checks that the inputs are reconciled to a single token axis.

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
/ok to test 29f1b51
/ok to test 26863d5
Skip test if GPU is not available to prevent failures.
/ok to test 35af590
…rvives the flatten

The L2 L2_FusedLinearCrossEntropy_HiddenStates_Minified driver expects the naive `hidden_states[-1]` extraction to raise AssertionError. With the THD/CP flatten now in FusedLinearCrossEntropy, batch_size=1 made the dropped-batch shape coincidentally reconcile ([8] vs [1,8] -> [8] vs [8]), so the assert no longer fired and CPU tensors reached the cut_cross_entropy Triton kernel:

    ValueError: Pointer argument cannot be accessed from Triton (cpu tensor?)

Use batch_size=2 so `hidden_states[-1]` genuinely drops batch elements (T vs B*T tokens), restoring a real token-count mismatch that cut_cross_entropy rejects before any kernel launch.

Verified the L2 script passes locally on GPU.

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
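The batch_size=1 coincidence described in this commit can be reproduced with shape arithmetic alone. The sketch below is illustrative only: `naive_last`, `flattened_token_counts`, and the sizes T=8, H=4 are hypothetical and are not taken from the L2 driver. It assumes `hidden_states[-1]` indexes away the leading batch dimension, as the commit message describes.

```python
from math import prod


def naive_last(hidden_shape):
    """Shape of hidden_states[-1]: integer indexing drops the leading dim."""
    return tuple(hidden_shape[1:])


def flattened_token_counts(hidden_shape, labels_shape):
    """Token counts once all leading dims are collapsed to a single axis."""
    return prod(hidden_shape[:-1]), prod(labels_shape)


def mismatch_survives_flatten(batch_size, T=8, H=4):
    """True if the token counts still disagree after flattening."""
    hidden = naive_last((batch_size, T, H))  # (T, H): batch dim is gone
    labels = (batch_size, T)                 # labels keep the batch dim
    n_hidden, n_labels = flattened_token_counts(hidden, labels)
    return n_hidden != n_labels


# batch_size=1: 8 vs 1*8 tokens, so the shapes reconcile by accident.
# batch_size=2: 8 vs 2*8 tokens, so the mismatch is real.
print(mismatch_survives_flatten(1), mismatch_survives_flatten(2))
```

With one batch element the flatten hides the dropped dimension, which is why the driver needs at least two.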
Contributor
Author
|
/ok to test 414e99e |
Summary
`FusedLinearCrossEntropy` (cut_cross_entropy / CCE) fails with an `AssertionError` on the THD / packed-sequence layout, before the first training step.

CCE asserts `hidden_states.shape[:-1] == labels.shape` (`cce.py`). That holds for the dense `[B, S, H]` / `[B, S]` case, but the THD path emits hidden states as `[1, T, H]` while labels stay `[B, S]` (and context-parallel shards skew the leading dims further), so the assert fires.

Surfaced by `examples/llm_finetune/qwen/qwen3_moe_30b_te_chat_thd.yaml` (uses the THD collater + CP for TE context parallelism), tracked in AM-460.
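To make the mismatch concrete, here is a minimal pure-Python sketch of the shape condition quoted above. The helper `shapes_agree` and the sizes are illustrative assumptions, not code from cut_cross_entropy. Only the condition `hidden_states.shape[:-1] == labels.shape` comes from the description.

```python
def shapes_agree(hidden_shape, labels_shape):
    """Hypothetical stand-in for: hidden_states.shape[:-1] == labels.shape."""
    return tuple(hidden_shape[:-1]) == tuple(labels_shape)


# Illustrative sizes (assumptions, not taken from the recipe).
B, S, H = 2, 8, 16
T = B * S  # packed token count if every sequence is full length

dense_ok = shapes_agree((B, S, H), (B, S))  # leading dims (B, S) match
thd_ok = shapes_agree((1, T, H), (B, S))    # (1, T) != (B, S): assert would fire

print(dense_ok, thd_ok)
```

The two layouts carry the same number of tokens here, so the failure comes from how the leading dimensions are arranged, not from missing data.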
Fix
Flatten `hidden_states` to `[N, H]` and `labels` to `[N]` before the CCE call, mirroring what `MaskedCrossEntropy` already does. This satisfies the assert for the THD/CP layouts and is a no-op for the dense case: the cross-entropy sum is invariant to token ordering, and `num_label_tokens` normalization is unchanged.
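The no-op claim can be checked on toy data. The following is a sketch in plain Python with nested lists standing in for tensors. It is not the PR's implementation, and all helper names, the toy weights, and the `-100` ignore label are assumptions made for illustration. It flattens both inputs to a single token axis, computes a summed linear + cross-entropy loss, and compares the dense layout, the `[1, T, H]` layout, and a reordered token stream.

```python
import math

IGNORE_INDEX = -100  # assumption: the conventional PyTorch ignore label


def flatten_hidden(x):
    """Collapse leading dims of a nested list: [..., H] -> [N, H]."""
    if isinstance(x[0], list):
        return [row for item in x for row in flatten_hidden(item)]
    return [x]  # x is a single H-dimensional vector


def flatten_labels(x):
    """Collapse a nested list of labels to [N]."""
    if isinstance(x, list):
        return [v for item in x for v in flatten_labels(item)]
    return [x]


def ce_sum(hidden, labels, weight):
    """Summed cross-entropy of (hidden @ weight^T) over non-ignored tokens."""
    rows, ys = flatten_hidden(hidden), flatten_labels(labels)
    assert len(rows) == len(ys), "token counts must agree after flattening"
    total = 0.0
    for h, y in zip(rows, ys):
        if y == IGNORE_INDEX:
            continue
        logits = [sum(w_i * h_i for w_i, h_i in zip(w, h)) for w in weight]
        m = max(logits)  # subtract the max for a stable log-sum-exp
        lse = m + math.log(sum(math.exp(z - m) for z in logits))
        total += lse - logits[y]
    return total


# Toy data: B=2, S=3, H=2, vocabulary size 3.
weight = [[0.5, -1.0], [1.5, 0.25], [-0.75, 2.0]]            # [V, H]
dense_hidden = [[[0.1, 0.2], [0.3, -0.4], [1.0, 0.5]],
                [[-0.6, 0.7], [0.8, 0.9], [0.0, -1.0]]]       # [B, S, H]
thd_hidden = [[row for seq in dense_hidden for row in seq]]   # [1, T, H]
labels = [[0, 2, IGNORE_INDEX], [1, 1, 0]]                    # [B, S]

loss_dense = ce_sum(dense_hidden, labels, weight)
loss_thd = ce_sum(thd_hidden, labels, weight)

# Reversing tokens and labels together leaves the sum unchanged.
rev_hidden = flatten_hidden(dense_hidden)[::-1]
rev_labels = flatten_labels(labels)[::-1]
loss_rev = ce_sum(rev_hidden, rev_labels, weight)

print(math.isclose(loss_dense, loss_thd), math.isclose(loss_dense, loss_rev))
```

On real tensors the same collapse would typically be a reshape of the leading dimensions. The count of non-ignored labels does not depend on layout either, which is consistent with the statement that `num_label_tokens` normalization is unchanged.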
Testing
- `test_fused_cross_entropy_thd_shape_reconciliation` feeds THD-shaped inputs (`hidden [1, T, H]`, `labels [B, S]`) and asserts the forward reconciles them to a single token axis (reproduces the bug without the fix).
- `qwen3_moe_30b_te_chat_thd` recipe in the CI container: the `AssertionError` is gone and step 0 produces a loss of `1.2350`, close to the `MaskedCrossEntropy` baseline's `1.2705`.