Skip to content

fix(loss): support THD/packed layout in FusedLinearCrossEntropy#2615

Open
akoumpa wants to merge 4 commits into
mainfrom
akoumpa/fix/linear-ce-thd
Open

fix(loss): support THD/packed layout in FusedLinearCrossEntropy#2615
akoumpa wants to merge 4 commits into
mainfrom
akoumpa/fix/linear-ce-thd

Conversation

@akoumpa

@akoumpa akoumpa commented Jun 17, 2026

Copy link
Copy Markdown
Contributor

Summary

FusedLinearCrossEntropy (cut_cross_entropy / CCE) fails with an
AssertionError on the THD / packed-sequence layout, before the first
training step.

CCE asserts hidden_states.shape[:-1] == labels.shape (cce.py). That holds for
the dense [B, S, H] / [B, S] case, but the THD path emits hidden states as
[1, T, H] while labels stay [B, S] (and context-parallel shards skew the
leading dims further), so the assert fires:

File ".../cut_cross_entropy/cce.py", line 172, in cce_linear_cross_entropy
    assert e.size()[0:-1] == targets.size()
AssertionError

Surfaced by examples/llm_finetune/qwen/qwen3_moe_30b_te_chat_thd.yaml
(uses the THD collater + CP for TE context parallelism) — tracked in AM-460.

Fix

Flatten hidden_states to [N, H] and labels to [N] before the CCE call,
mirroring what MaskedCrossEntropy already does. This satisfies the assert for
the THD/CP layouts and is a no-op for the dense case — the cross-entropy sum
is invariant to token ordering, and num_label_tokens normalization is
unchanged.

Testing

  • New CPU unit test (test_fused_cross_entropy_thd_shape_reconciliation) feeds
    THD-shaped inputs (hidden [1, T, H], labels [B, S]) and asserts the forward
    reconciles them to a single token axis (reproduces the bug without the fix).
  • Verified on 8×H100 (cw-dfw) with the qwen3_moe_30b_te_chat_thd recipe in
    the CI container: the AssertionError is gone and step 0 produces a correct
    loss (1.2350, matching the MaskedCrossEntropy baseline's 1.2705).

Note: this unblocks fused linear CE for THD recipes; it is necessary but not by
itself sufficient to resolve the AM-460 OOM (that is a separate full-FT
optimizer-state capacity issue, tracked separately).

FusedLinearCrossEntropy passed hidden_states and labels straight into
cut_cross_entropy, which asserts `hidden_states.shape[:-1] == labels.shape`.
That holds for the dense [B, S, H] / [B, S] case but breaks for the THD /
packed-sequence layout (model emits hidden [1, T, H] while labels stay [B, S])
and for context-parallel shards, raising an AssertionError in cce.py before the
first step — e.g. the qwen3_moe_30b_te_chat_thd recipe (AM-460), which uses the
THD collater + CP for TE context parallelism.

Fix: flatten hidden_states to [N, H] and labels to [N] before the call,
mirroring MaskedCrossEntropy. This satisfies the assert for THD/CP and is a
no-op for the dense case (the CE sum is invariant to token ordering).

Verified on 8xH100 (cw-dfw) with the qwen3_moe_30b THD recipe: the
AssertionError is gone and step 0 produces a correct loss (1.2350, matching the
MaskedCrossEntropy baseline's 1.2705).

Adds a CPU unit test that reproduces the THD shape mismatch and checks the
inputs are reconciled to a single token axis.

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
@akoumpa akoumpa requested a review from a team as a code owner June 17, 2026 02:59
@copy-pr-bot

copy-pr-bot Bot commented Jun 17, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@akoumpa akoumpa added the r0.5.0 Auto-cherrypick to release branch. Apply before merge; cherrypick happens after merge. label Jun 17, 2026
@akoumpa

akoumpa commented Jun 17, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 29f1b51

@akoumpa

akoumpa commented Jun 17, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 26863d5

Skip test if GPU is not available to prevent failures.
@akoumpa

akoumpa commented Jun 18, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 35af590

…rvives the flatten

The L2 L2_FusedLinearCrossEntropy_HiddenStates_Minified driver expects the naive
`hidden_states[-1]` extraction to raise AssertionError. With the THD/CP flatten
now in FusedLinearCrossEntropy, batch_size=1 made the dropped-batch shape
coincidentally reconcile ([8] vs [1,8] -> [8] vs [8]), so the assert no longer
fired and CPU tensors reached the cut_cross_entropy Triton kernel:
ValueError: Pointer argument cannot be accessed from Triton (cpu tensor?).

Use batch_size=2 so `hidden_states[-1]` genuinely drops batch elements (T vs B*T
tokens), restoring a real token-count mismatch that cut_cross_entropy rejects
before any kernel launch. Verified the L2 script passes locally on GPU.

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
@akoumpa

akoumpa commented Jun 18, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 414e99e

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

r0.5.0 Auto-cherrypick to release branch. Apply before merge; cherrypick happens after merge.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant