
Add --loss-aggregation for the four ScaleRL pg_loss aggregation modes#2090

Open
EazyReal wants to merge 1 commit into THUDM:main from EazyReal:upstream-pr/loss-aggregation-modes

EazyReal commented Jun 16, 2026

Contributor

Problem

slime aggregates pg_loss across a training step as a per-rollout token-weighted sample mean. Recent RL recipes deliberately choose a different aggregation: DAPO averages per prompt group, Dr.GRPO divides by a constant, and some recipes use a global per-token mean. ScaleRL (arXiv:2510.13786 §3.2) catalogs these as one knob and reports the choice materially affects stability and final reward. Today the only escape hatch is --custom-pg-loss-reducer-function-path (write your own reducer per recipe); there is no first-class flag, and the prior #2060 only adds a single Dr.GRPO --pg-loss-divisor.

What this adds

A single --loss-aggregation {sample_mean,prompt_mean,token_mean,constant} flag (plus --loss-aggregation-divisor L for constant) selecting how pg_loss is aggregated. Modes follow the ScaleRL taxonomy:

| Mode | Paper | pg_loss denominator |
| --- | --- | --- |
| sample_mean (default) | GRPO sample average | Per-rollout token-weighted mean (each rollout contributes equally regardless of fan-out). Byte-identical to slime's prior default. |
| prompt_mean | DAPO prompt average | Per-prompt-group token-weighted mean (all rollouts sharing a Sample.group_index share one denominator). ScaleRL's recommended default for new recipes. |
| token_mean | token average | Global per-token mean. This is exactly what the legacy --calculate-per-token-loss flag does; the two are reconciled at startup (see below). |
| constant | Dr.GRPO (arXiv:2503.20783) | sum(token_loss * loss_mask) / L, L = --loss-aggregation-divisor (e.g. the max context length). |
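
For reviewers who want to see the four denominators side by side, here is a minimal pure-Python sketch on toy data. It is not slime's reducer: `step_sum` and the toy inputs are hypothetical, it works on plain lists rather than tensors, it ignores the trailing `/ step_global_batch_size` step average, and it does not handle rollouts whose mask is all zeros.

```python
from collections import defaultdict

def step_sum(token_losses, loss_masks, group_index, mode, divisor=None):
    """Illustrative step sum for one aggregation mode (hypothetical helper)."""
    # Per-rollout masked token sums and mask totals.
    masked_sums = [sum(x * m for x, m in zip(xs, ms))
                   for xs, ms in zip(token_losses, loss_masks)]
    mask_sums = [sum(ms) for ms in loss_masks]

    if mode == "sample_mean":
        # Each rollout is divided by its own mask total.
        return sum(s / d for s, d in zip(masked_sums, mask_sums))
    if mode == "prompt_mean":
        # Every rollout in a prompt group shares the group's mask total.
        group_totals = defaultdict(int)
        for g, d in zip(group_index, mask_sums):
            group_totals[g] += d
        return sum(s / group_totals[g] for s, g in zip(masked_sums, group_index))
    if mode == "token_mean":
        # One global denominator: all unmasked tokens in the step.
        return sum(masked_sums) / sum(mask_sums)
    if mode == "constant":
        # Data-independent denominator L.
        return sum(masked_sums) / divisor
    raise ValueError(f"unknown mode: {mode}")

# Two prompt groups with two rollouts each; the first rollout has one masked-out token.
token_losses = [[1.0, 1.0, 9.0], [2.0] * 6, [4.0], [0.0]]
loss_masks = [[1, 1, 0], [1] * 6, [1], [1]]
group_index = [0, 0, 1, 1]

for mode in ("sample_mean", "prompt_mean", "token_mean", "constant"):
    print(mode, step_sum(token_losses, loss_masks, group_index, mode, divisor=16))
```

On this fixture the four modes give 7.0, 3.75, 1.8 and 1.125 respectively, which shows that the choice of denominator changes the objective whenever rollout lengths are uneven.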

Before / after

  • Before: the pg_loss reducer is the per-rollout sum_of_sample_mean. No flag.
  • After: the pg_loss reducer is chosen by --loss-aggregation. The default sample_mean returns the same reducer object — so an existing run's pg_loss is byte-identical (verified by test_default_reduces_to_per_sample_mean and the validation-side test_loss_aggregation_default_leaves_per_token_loss_off). L is validated > 0 at startup, only for constant.

Why this shape

It rides the sample_denoms seam already added in #1933 to cp_utils.get_sum_of_sample_mean. That function takes pre-computed per-sample denominators that are CP-correct and remain correct when a rollout's samples are packed across micro-batches. The default path feeds rollout_mask_sums (per-rollout totals).

  • prompt_mean is the only new step-level computation: prompt_mask_sums (per-prompt-group mask totals grouped by Sample.group_index), computed in RolloutManager right beside rollout_mask_sums and plumbed through the identical path (data.py log filter, DP split, model.py pad list, actor GPU promotion). It is a per-sample broadcast of the whole-group total, so CP, the DP split, and micro-batch packing all sum partial (x·m)/D_group to (Σx·m)/D_group — same correctness as sample_mean, for free.
  • constant is a small constant_divisor branch in get_sum_of_sample_mean (sum_of_token(x) / L); being identical on every CP rank, Megatron's gradient sum-allreduce already yields the full-batch value (no extra all-reduce).
  • token_mean reuses the existing --calculate-per-token-loss path so the loss-scaling and reporting stays consistent rather than introducing a second per-token codepath.
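
The partial-sum argument in the prompt_mean bullet can be checked with a small sketch. It assumes a single prompt group whose rollouts are each cut into two chunks, standing in for CP ranks or micro-batches; `masked_sum`, `halves`, and the data are hypothetical, and exact fractions are used so the equality is not blurred by floating-point rounding.

```python
from fractions import Fraction

def masked_sum(losses, mask):
    """Sum of losses over unmasked tokens."""
    return sum(x * m for x, m in zip(losses, mask))

def halves(seq):
    """Split a sequence into two contiguous chunks."""
    mid = len(seq) // 2
    return seq[:mid], seq[mid:]

# One prompt group with two rollouts: (token losses, loss mask).
rollouts = [
    ([3, 1, 2, 5], [1, 1, 0, 1]),
    ([4, 4], [1, 1]),
]

# Whole-group mask total, computed once and broadcast to every chunk.
d_group = sum(sum(mask) for _, mask in rollouts)

# Reference: reduce each whole rollout against the group denominator.
whole = sum(Fraction(masked_sum(x, m), d_group) for x, m in rollouts)

# Chunked: each chunk divides its own partial sum by the same d_group.
partial = Fraction(0)
for x, m in rollouts:
    for x_chunk, m_chunk in zip(halves(x), halves(m)):
        partial += Fraction(masked_sum(x_chunk, m_chunk), d_group)

print(whole, partial)
```

Because the denominator is fixed before the split, the chunked partial sums add up to the unsplit value, which is the property the description relies on.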

prompt_mask_sums is computed only under --loss-aggregation prompt_mean (the mode that consumes it); the other three modes never read Sample.group_index and never build the key, so the default (and every non-prompt_mean) batch is unchanged — no extra batch key, no always-on group-aggregation compute/bandwidth.

This keeps the new code minimal and reuses the verified reducer rather than adding a parallel aggregation stack.

Reconciling --calculate-per-token-loss (one honest axis)

--calculate-per-token-loss is not a separate axis — it is the token_mean point on --loss-aggregation. The two spellings are reconciled in slime_validate_args so there is exactly one knob and the reported objective is honest:

  • Forward: --loss-aggregation=token_mean sets --calculate-per-token-loss=True (drives the per-token reducer / normalizer path).
  • Backward: a bare --calculate-per-token-loss on the default sample_mean is relabeled to token_mean. This is a pure label fix with no behavior change: sample_mean already returned the per-token reducer when the flag was set, so the loss is byte-identical; only the reported mode name changes. Existing --calculate-per-token-loss recipes keep working unchanged.
  • Fail loud: --calculate-per-token-loss combined with prompt_mean or constant (distinct objectives) raises at startup, rather than silently letting the per-token path override the prompt-group / constant denominator.
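
The three reconciliation rules above can be sketched as a small pure function. This is an illustration only: `reconcile` is a hypothetical name, and the actual logic lives in slime_validate_args and operates on the parsed args object rather than on two loose values.

```python
def reconcile(loss_aggregation="sample_mean", calculate_per_token_loss=False):
    """Return the reconciled (loss_aggregation, calculate_per_token_loss) pair."""
    if loss_aggregation == "token_mean":
        # Forward: token_mean always drives the per-token path.
        return "token_mean", True
    if calculate_per_token_loss:
        if loss_aggregation == "sample_mean":
            # Backward: the bare legacy flag is relabeled, behavior unchanged.
            return "token_mean", True
        # Fail loud: prompt_mean / constant are distinct objectives.
        raise ValueError(
            f"--calculate-per-token-loss cannot be combined with "
            f"--loss-aggregation={loss_aggregation}"
        )
    return loss_aggregation, False

print(reconcile())                     # default stays per-token-off
print(reconcile("token_mean"))         # forward direction
print(reconcile("sample_mean", True))  # backward direction
```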

Prefer --loss-aggregation=token_mean in new recipes.

Effective normalization for constant

The constant divisor L is per-token, not per-step. Each mode's reducer returns a step sum that is then averaged by the usual / step_global_batch_size (identical structure across all four modes). So the effective per-step normalization for constant is / (L * step_global_batch_size): L sets the data-independent per-token scale, and the / step_global_batch_size step average is applied on top exactly as for the other modes. Documented in customization.md and the --loss-aggregation-divisor help.
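
As a worked example of the two-stage normalization, the arithmetic below uses made-up numbers (a masked token sum of 18, L = 16, and a step_global_batch_size of 4). The variable names are illustrative, not identifiers from the codebase.

```python
masked_token_sum = 18.0        # sum(token_loss * loss_mask) over the step
L = 16.0                       # --loss-aggregation-divisor
step_global_batch_size = 4

# Stage 1: the constant reducer divides the masked token sum by L.
reducer_output = masked_token_sum / L

# Stage 2: the usual step average is applied on top.
step_loss = reducer_output / step_global_batch_size

# The combined effect is a single division by L * step_global_batch_size.
print(step_loss, masked_token_sum / (L * step_global_batch_size))
```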

Scope: pg_loss only (deliberate)

Aggregation applies to pg_loss only. The diagnostic metrics — pg_clipfrac, ppo_kl, entropy_loss, kl_loss — keep the default sample-mean reducer so they stay interpretable and comparable across runs (e.g. a constant /L must not crush ppo_kl by the same factor and make it unreadable). This matches the existing scope of --custom-pg-loss-reducer-function-path, which still takes precedence when set. token_mean is the documented exception: because it reuses --calculate-per-token-loss, it is per-token everywhere.

Alternative for reviewers: one could apply the chosen aggregation uniformly to the metrics too (single normalizer for loss and diagnostics). We chose not to, to preserve metric comparability across aggregation modes; this is a one-line change at the call site if the project prefers the uniform convention.

Honesty: prompt_mean absolute scale

prompt_mean weights every prompt group equally — each group's token-weighted mean enters the step sum once, all under the same / step_global_batch_size divisor — so the relative per-prompt weighting (the property DAPO is about) is exact. Its absolute scale differs from a strict 1/P DAPO average by a constant factor (P / N, prompts over rollouts), which the learning rate absorbs. Documented in docs/en/get_started/customization.md and the flag help.
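
A small numeric check of the P / N factor, under the assumption that every prompt group's token-weighted mean enters the step sum once. The group sums, mask totals, and variable names are made up for illustration, and exact fractions are used so the ratio comes out exactly.

```python
from fractions import Fraction

# Two prompt groups (P = 2), two rollouts each (N = 4 rollouts in the step).
group_masked_sums = [Fraction(14), Fraction(4)]  # sum of loss over unmasked tokens
group_mask_totals = [8, 2]                       # unmasked tokens per group
n_prompts = 2
n_rollouts = 4  # plays the role of step_global_batch_size here

# Each group's token-weighted mean.
group_means = [s / d for s, d in zip(group_masked_sums, group_mask_totals)]

# prompt_mean as described: the step sum of group means, averaged over rollouts.
prompt_mean_loss = sum(group_means) / n_rollouts

# A strict DAPO-style average divides by the number of prompts instead.
strict_dapo_loss = sum(group_means) / n_prompts

# The two differ only by the constant factor P / N.
print(prompt_mean_loss / strict_dapo_loss, Fraction(n_prompts, n_rollouts))
```

Since the factor depends only on P and N, and not on the data, the relative weighting between prompt groups is the same in both forms.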

Fail-loud guards (no silent degradation)

Every way to misconfigure the new knob is rejected at startup or fails structurally, so no run silently normalizes against the wrong denominator:

  • constant requires --loss-aggregation-divisor > 0; None, 0, negative, and NaN are all rejected (not (divisor > 0) catches NaN).
  • --loss-aggregation-divisor set on any non-constant mode raises (it would otherwise be a silent no-op that misleads about normalization).
  • --calculate-per-token-loss combined with prompt_mean/constant raises (distinct objectives — see reconciliation above).
  • prompt_mean requires global_batch_size % n_samples_per_prompt == 0. Each training step is a contiguous slice of global_batch_size rollouts; if a prompt group straddled a step boundary its per-group denominator would fragment across optimizer updates, so this is rejected at startup.
  • prompt_mean requires every Sample.group_index to be set. A None group_index breaks the prompt-grouping invariant: grouping on it would silently collapse unrelated prompts into one denominator, and renumbering it into a singleton group would silently degrade prompt_mean → sample_mean for that sample. The convert step raises instead.
  • Under prompt_mean, the reducer checks batch.get("prompt_mask_sums") and raises if a custom convert path selected prompt_mean but dropped the key (rather than degrading to the per-sample mean).
  • get_sum_of_sample_mean zips its per-sample sample_denoms with strict=True: a caller that supplies a mismatched-length sample_denoms (a construction bug) fails loud instead of silently dropping samples.

Tests

  • tests/test_cp_utils.py: constant divides the masked token-sum by L; prompt_mean's per-group denominator is distinct from the per-rollout (sample_mean) and per-token (token_mean) results on uneven fixtures; the constant / --calculate-per-token-loss mutual-exclusion guard; CP rank-sum invariance for the new constant branch (prompt_mean reuses the per-rollout sample-mean CP path already pinned in the same file); the strict=True length-mismatch guard.
  • tests/test_megatron_argument_validation.py: --loss-aggregation-divisor rejected when missing / non-positive / NaN under constant, accepted when positive; rejected on every non-constant mode; token_mean aliases --calculate-per-token-loss; a bare --calculate-per-token-loss reconciles to token_mean; prompt_mean/constant + --calculate-per-token-loss rejected; prompt_mean requires global_batch_size a multiple of n_samples_per_prompt (and the other modes do not); default leaves the per-token flag off.
  • tests/test_rollout_validation.py: the prompt_mask_sums build in _convert_samples_to_train_data — fails loud on a None group_index under prompt_mean, builds the correct per-prompt-group totals, and (for sample_mean / constant / token_mean) neither consults group_index nor adds the prompt_mask_sums key, keeping the non-prompt_mean batch unchanged.

All guards are mutation-verified: removing a guard fails its test.

Supersedes #2060.

Add a unified `--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}`
(+ `--loss-aggregation-divisor L` for constant) selecting how pg_loss is
aggregated across a training step, riding the existing `sample_denoms` seam in
`get_sum_of_sample_mean`.

Modes follow the ScaleRL taxonomy (arXiv:2510.13786 §3.2):
- sample_mean (default): GRPO sample average — per-rollout token-weighted mean
  via `rollout_mask_sums`. Byte-identical to the prior default (no extra batch
  key in any non-prompt_mean mode).
- prompt_mean: DAPO prompt average — step-level `prompt_mask_sums` grouped by
  Sample.group_index, built ONLY under prompt_mean and plumbed like
  `rollout_mask_sums` (CP- and variable-GBS-correct). The other three modes
  never read group_index and never build the key. A None group_index under
  prompt_mean fails loud (the prompt-grouping invariant is broken; silently
  renumbering it into a singleton group would degrade prompt_mean -> sample_mean
  for that sample). Every prompt group enters the step sum once under the same
  `/ step_global_batch_size` divisor, so relative per-prompt weighting is
  uniform; absolute scale differs from a strict 1/P DAPO average by a constant
  factor (P/N), which the learning rate absorbs.
- token_mean: token average — the global per-token mean. This is exactly the
  legacy `--calculate-per-token-loss` flag; the two spellings are now one axis.
- constant: Dr.GRPO (arXiv:2503.20783) — masked token sum / L via a new
  `constant_divisor` branch in cp_utils.

Aggregation applies to pg_loss only (metrics keep sum_of_sample_mean);
`--custom-pg-loss-reducer-function-path` still takes precedence. `L` is validated
> 0 at startup only for constant. Supersedes the open THUDM#2060 `--pg-loss-divisor`.

Consolidate `--calculate-per-token-loss` with `--loss-aggregation=token_mean`
into one coherent axis, backward-compatibly. `--calculate-per-token-loss` is the
legacy spelling of `token_mean` (get_pg_loss_reducer returns the default reducer
for both sample_mean and token_mean, and that default reducer is the per-token
path iff `calculate_per_token_loss` is set). slime_validate_args reconciles them
bidirectionally: `token_mean` sets `calculate_per_token_loss=True` (forward), and
the legacy flag on the default `sample_mean` is relabeled to `token_mean`
(backward) — no behavior change (sample_mean already returned the per-token
default reducer when the flag was set), which closes the last silent
sample_mean->token_mean override. The flag is kept as a legacy alias (Megatron-
inherited; existing recipes keep working, just honestly labeled); docs and help
point new recipes at `--loss-aggregation=token_mean`.

Fail-loud guards (each load-bearing and mutation-verified):
- Reject prompt_mean and constant combined with --calculate-per-token-loss
  (both would silently override the reducer's denominator with the global
  per-token mean); reject a stray --loss-aggregation-divisor on any
  non-constant mode (silently ignored otherwise); structurally drop
  calculate_per_token_loss from the prompt_mean reducer call.
- cp_utils: zip sample_denoms strict=True so a per-sample denom length
  mismatch fails loud (equal to loss_masks by construction).
- Document the constant effective denominator (/ (L * global_batch_size)) and
  the whole-rollout prompt_mask_sums invariant (groups never partial across mb).
- Reject prompt_mean unless global_batch_size is a multiple of
  n_samples_per_prompt: dp_schedule slices each step as a contiguous run of
  global_batch_size rollouts, so a non-multiple lets a prompt group straddle a
  step boundary and fragments its per-group normalization across optimizer
  updates (wrong objective). Enforced in slime_validate_args after
  global_batch_size is finalized.

Tests: test_cp_utils.py pins constant divides by L, prompt_mean's per-group
denominator distinct from sample/token mean, the constant/per-token mutual-
exclusion guard, and CP rank-sum invariance for the new constant branch.
test_megatron_argument_validation.py pins divisor validation, the token_mean
alias both directions (token_mean sets the flag; the bare legacy flag reconciles
to token_mean), the per-token-incompatibility guards, the default sample_mean
staying per-token-off, and the prompt_mean single-step-per-group requirement.
test_rollout_validation.py pins the prompt_mask_sums build. Mutation-verified.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>