Add --loss-aggregation for the four ScaleRL pg_loss aggregation modes #2090
Open
EazyReal wants to merge 1 commit into
Add a unified `--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}`
(+ `--loss-aggregation-divisor L` for constant) selecting how pg_loss is
aggregated across a training step, riding the existing `sample_denoms` seam in
`get_sum_of_sample_mean`.
Modes follow the ScaleRL taxonomy (arXiv:2510.13786 §3.2):
- sample_mean (default): GRPO sample average — per-rollout token-weighted mean
via `rollout_mask_sums`. Byte-identical to the prior default (no extra batch
key in any non-prompt_mean mode).
- prompt_mean: DAPO prompt average — step-level `prompt_mask_sums` grouped by
Sample.group_index, built ONLY under prompt_mean and plumbed like
`rollout_mask_sums` (CP- and variable-GBS-correct). The other three modes
never read group_index and never build the key. A None group_index under
prompt_mean fails loud (the prompt-grouping invariant is broken; silently
renumbering it into a singleton group would degrade prompt_mean -> sample_mean
for that sample). Every prompt group enters the step sum once under the same
`/ step_global_batch_size` divisor, so relative per-prompt weighting is
uniform; absolute scale differs from a strict 1/P DAPO average by a constant
factor (P/N), which the learning rate absorbs.
- token_mean: token average — the global per-token mean. This is exactly the
legacy `--calculate-per-token-loss` flag; the two spellings are now one axis.
- constant: Dr.GRPO (arXiv:2503.20783) — masked token sum / L via a new
`constant_divisor` branch in cp_utils.
Aggregation applies to pg_loss only (metrics keep sum_of_sample_mean);
`--custom-pg-loss-reducer-function-path` still takes precedence. `L` is validated
> 0 at startup only for constant. Supersedes the open THUDM#2060 `--pg-loss-divisor`.
Consolidate `--calculate-per-token-loss` with `--loss-aggregation=token_mean`
into one coherent axis, backward-compatibly. `--calculate-per-token-loss` is the
legacy spelling of `token_mean` (get_pg_loss_reducer returns the default reducer
for both sample_mean and token_mean, and that default reducer is the per-token
path iff `calculate_per_token_loss` is set). slime_validate_args reconciles them
bidirectionally: `token_mean` sets `calculate_per_token_loss=True` (forward), and
the legacy flag on the default `sample_mean` is relabeled to `token_mean`
(backward) — no behavior change (sample_mean already returned the per-token
default reducer when the flag was set), which closes the last silent
sample_mean->token_mean override. The flag is kept as a legacy alias (Megatron-
inherited; existing recipes keep working, just honestly labeled); docs and help
point new recipes at `--loss-aggregation=token_mean`.
Fail-loud guards (each load-bearing and mutation-verified):
- Reject prompt_mean and constant combined with --calculate-per-token-loss
(both would silently override the reducer's denominator with the global
per-token mean); reject a stray --loss-aggregation-divisor on any
non-constant mode (silently ignored otherwise); structurally drop
calculate_per_token_loss from the prompt_mean reducer call.
- cp_utils: zip sample_denoms strict=True so a per-sample denom length
mismatch fails loud (equal to loss_masks by construction).
- Document the constant effective denominator (/ (L * global_batch_size)) and
the whole-rollout prompt_mask_sums invariant (groups never partial across mb).
- Reject prompt_mean unless global_batch_size is a multiple of
n_samples_per_prompt: dp_schedule slices each step as a contiguous run of
global_batch_size rollouts, so a non-multiple lets a prompt group straddle a
step boundary and fragments its per-group normalization across optimizer
updates (wrong objective). Enforced in slime_validate_args after
global_batch_size is finalized.
Tests: test_cp_utils.py pins constant divides by L, prompt_mean's per-group
denominator distinct from sample/token mean, the constant/per-token mutual-
exclusion guard, and CP rank-sum invariance for the new constant branch.
test_megatron_argument_validation.py pins divisor validation, the token_mean
alias both directions (token_mean sets the flag; the bare legacy flag reconciles
to token_mean), the per-token-incompatibility guards, the default sample_mean
staying per-token-off, and the prompt_mean single-step-per-group requirement.
test_rollout_validation.py pins the prompt_mask_sums build. Mutation-verified.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Problem
slime aggregates `pg_loss` across a training step as a per-rollout token-weighted sample mean. Recent RL recipes deliberately choose a different aggregation: DAPO averages per prompt group, Dr.GRPO divides by a constant, and some recipes use a global per-token mean. ScaleRL (arXiv:2510.13786 §3.2) catalogs these as one knob and reports that the choice materially affects stability and final reward. Today the only escape hatch is `--custom-pg-loss-reducer-function-path` (write your own reducer per recipe); there is no first-class flag, and the prior #2060 only adds a single Dr.GRPO `--pg-loss-divisor`.

What this adds
A single `--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}` flag (plus `--loss-aggregation-divisor L` for `constant`) selecting how `pg_loss` is aggregated. Modes follow the ScaleRL taxonomy:
- `sample_mean` (default): GRPO sample average, the per-rollout token-weighted mean.
- `prompt_mean`: DAPO prompt average (rollouts with the same `Sample.group_index` share one denominator). ScaleRL's recommended default for new recipes.
- `token_mean`: the global per-token mean, which is what the legacy `--calculate-per-token-loss` flag does; the two are reconciled at startup (see below).
- `constant`: Dr.GRPO, `sum(token_loss * loss_mask) / L`, with `L = --loss-aggregation-divisor` (e.g. the max context length).

Before / after
- Before: the `pg_loss` reducer is the per-rollout `sum_of_sample_mean`. No flag.
- After: the `pg_loss` reducer is chosen by `--loss-aggregation`. The default `sample_mean` returns the same reducer object, so an existing run's `pg_loss` is byte-identical (verified by `test_default_reduces_to_per_sample_mean` and the validation-side `test_loss_aggregation_default_leaves_per_token_loss_off`). `L` is validated `> 0` at startup, only for `constant`.

Why this shape
It rides the `sample_denoms` seam already added in #1933 to `cp_utils.get_sum_of_sample_mean`. That function takes pre-computed per-sample denominators that are CP-correct and remain correct when a rollout's samples are packed across micro-batches. The default path feeds `rollout_mask_sums` (per-rollout totals).
- `prompt_mean` is the only new step-level computation: `prompt_mask_sums` (per-prompt-group mask totals grouped by `Sample.group_index`), computed in `RolloutManager` right beside `rollout_mask_sums` and plumbed through the identical path (`data.py` log filter, DP split, `model.py` pad list, actor GPU promotion). It is a per-sample broadcast of the whole-group total, so CP, the DP split, and micro-batch packing all sum the partial `(x·m)/D_group` terms to `(Σ x·m)/D_group`: the same correctness as `sample_mean`, for free.
- `constant` is a small `constant_divisor` branch in `get_sum_of_sample_mean` (`sum_of_token(x) / L`); because `L` is identical on every CP rank, Megatron's gradient sum-allreduce already yields the full-batch value (no extra all-reduce).
- `token_mean` reuses the existing `--calculate-per-token-loss` path, so loss scaling and reporting stay consistent rather than introducing a second per-token codepath.

`prompt_mask_sums` is computed only under `--loss-aggregation prompt_mean` (the mode that consumes it); the other three modes never read `Sample.group_index` and never build the key, so the default (and every non-`prompt_mean`) batch is unchanged: no extra batch key, no always-on group-aggregation compute or bandwidth.

This keeps the new code minimal and reuses the verified reducer rather than adding a parallel aggregation stack.
Reconciling `--calculate-per-token-loss` (one honest axis)
`--calculate-per-token-loss` is not a separate axis: it is the `token_mean` point on `--loss-aggregation`. The two spellings are reconciled in `slime_validate_args` so there is exactly one knob and the reported objective is honest:
- `--loss-aggregation=token_mean` sets `--calculate-per-token-loss=True` (drives the per-token reducer / normalizer path).
- `--calculate-per-token-loss` on the default `sample_mean` is relabeled to `token_mean`. This is a pure label fix with no behavior change: `sample_mean` already returned the per-token reducer when the flag was set, so the loss is byte-identical; only the reported mode name changes. Existing `--calculate-per-token-loss` recipes keep working unchanged.
- `--calculate-per-token-loss` combined with `prompt_mean` or `constant` (distinct objectives) raises at startup, rather than silently letting the per-token path override the prompt-group / constant denominator.

Prefer `--loss-aggregation=token_mean` in new recipes.

Effective normalization for `constant`
The `constant` divisor `L` is per-token, not per-step. Each mode's reducer returns a step sum that is then averaged by the usual `/ step_global_batch_size` (identical structure across all four modes). So the effective per-step normalization for `constant` is `/ (L * step_global_batch_size)`: `L` sets the data-independent per-token scale, and the `/ step_global_batch_size` step average is applied on top exactly as for the other modes. Documented in `customization.md` and the `--loss-aggregation-divisor` help.

Scope: pg_loss only (deliberate)
Aggregation applies to `pg_loss` only. The diagnostic metrics (`pg_clipfrac`, `ppo_kl`, `entropy_loss`, `kl_loss`) keep the default sample-mean reducer so they stay interpretable and comparable across runs (e.g. a `constant` `/L` must not crush `ppo_kl` by the same factor and make it unreadable). This matches the existing scope of `--custom-pg-loss-reducer-function-path`, which still takes precedence when set. `token_mean` is the documented exception: because it reuses `--calculate-per-token-loss`, it is per-token everywhere.

Alternative for reviewers: one could apply the chosen aggregation uniformly to the metrics too (single normalizer for loss and diagnostics). We chose not to, to preserve metric comparability across aggregation modes; this is a one-line change at the call site if the project prefers the uniform convention.
Honesty: prompt_mean absolute scale
`prompt_mean` weights every prompt group equally: each group's token-weighted mean enters the step sum once, all under the same `/ step_global_batch_size` divisor, so the relative per-prompt weighting (the property DAPO is about) is exact. Its absolute scale differs from a strict `1/P` DAPO average by a constant factor (`P / N`, prompts over rollouts), which the learning rate absorbs. Documented in `docs/en/get_started/customization.md` and the flag help.

Fail-loud guards (no silent degradation)
Every way to misconfigure the new knob is rejected at startup or fails structurally, so no run silently normalizes against the wrong denominator:
- `constant` requires `--loss-aggregation-divisor > 0`; `None`, `0`, negative, and `NaN` are all rejected (`not (divisor > 0)` catches `NaN`).
- `--loss-aggregation-divisor` set on any non-`constant` mode raises (it would otherwise be a silent no-op that misleads about normalization).
- `--calculate-per-token-loss` combined with `prompt_mean` / `constant` raises (distinct objectives; see reconciliation above).
- `prompt_mean` requires `global_batch_size % n_samples_per_prompt == 0`. Each training step is a contiguous slice of `global_batch_size` rollouts; if a prompt group straddled a step boundary, its per-group denominator would fragment across optimizer updates, so this is rejected at startup.
- `prompt_mean` requires every `Sample.group_index` to be set. A `None` would silently collapse unrelated prompts into one denominator (degrading `prompt_mean` → `sample_mean` for that sample), so the convert step raises instead of renumbering it into a singleton group.
- `batch.get("prompt_mask_sums")` is checked under `prompt_mean`, and the run raises if a custom convert path selected `prompt_mean` but dropped the key (rather than degrading to the per-sample mean).
- `get_sum_of_sample_mean` zips its per-sample `sample_denoms` with `strict=True`: a caller that supplies a mismatched-length `sample_denoms` (a construction bug) fails loud instead of silently dropping samples.

Tests
- `tests/test_cp_utils.py`: `constant` divides the masked token-sum by `L`; `prompt_mean`'s per-group denominator is distinct from the per-rollout (`sample_mean`) and per-token (`token_mean`) results on uneven fixtures; the `constant` / `--calculate-per-token-loss` mutual-exclusion guard; CP rank-sum invariance for the new `constant` branch (`prompt_mean` reuses the per-rollout sample-mean CP path already pinned in the same file); the `strict=True` length-mismatch guard.
- `tests/test_megatron_argument_validation.py`: `--loss-aggregation-divisor` rejected when missing / non-positive / `NaN` under `constant`, accepted when positive; rejected on every non-`constant` mode; `token_mean` aliases `--calculate-per-token-loss`; a bare `--calculate-per-token-loss` reconciles to `token_mean`; `prompt_mean` / `constant` + `--calculate-per-token-loss` rejected; `prompt_mean` requires `global_batch_size` a multiple of `n_samples_per_prompt` (and the other modes do not); default leaves the per-token flag off.
- `tests/test_rollout_validation.py`: the `prompt_mask_sums` build in `_convert_samples_to_train_data` fails loud on a `None` `group_index` under `prompt_mean`, builds the correct per-prompt-group totals, and (for `sample_mean` / `constant` / `token_mean`) neither consults `group_index` nor adds the `prompt_mask_sums` key, keeping the non-`prompt_mean` batch unchanged.

All guards are mutation-verified: removing a guard fails its test.
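As a rough illustration of the startup guards and the flag reconciliation that these tests pin, the sketch below validates a plain namespace. It is written under assumptions: `validate_loss_aggregation` and `make_args` are hypothetical names (the PR does this inside `slime_validate_args`), the attribute names are guessed from the flag spellings, and `ValueError` is an assumed exception type.

```python
from types import SimpleNamespace


def make_args(**overrides):
    """Hypothetical parsed-args stand-in with the defaults described above."""
    base = dict(
        loss_aggregation="sample_mean",
        loss_aggregation_divisor=None,
        calculate_per_token_loss=False,
        global_batch_size=8,
        n_samples_per_prompt=4,
    )
    base.update(overrides)
    return SimpleNamespace(**base)


def validate_loss_aggregation(args):
    mode = args.loss_aggregation
    divisor = args.loss_aggregation_divisor

    if mode == "constant":
        # `not (divisor > 0)` also rejects NaN, since NaN > 0 is False.
        if divisor is None or not (divisor > 0):
            raise ValueError("constant requires --loss-aggregation-divisor > 0")
    elif divisor is not None:
        raise ValueError("--loss-aggregation-divisor is only valid with constant")

    if args.calculate_per_token_loss:
        if mode in ("prompt_mean", "constant"):
            raise ValueError(f"--calculate-per-token-loss conflicts with {mode}")
        # Backward direction: the legacy flag is relabeled to token_mean.
        mode = "token_mean"
    if mode == "token_mean":
        # Forward direction: token_mean drives the per-token path.
        args.calculate_per_token_loss = True

    if mode == "prompt_mean" and args.global_batch_size % args.n_samples_per_prompt != 0:
        raise ValueError("prompt_mean requires global_batch_size to be a "
                         "multiple of n_samples_per_prompt")

    args.loss_aggregation = mode
    return args


legacy = validate_loss_aggregation(make_args(calculate_per_token_loss=True))
print(legacy.loss_aggregation, legacy.calculate_per_token_loss)
```

The ordering matters in this sketch: the divisor checks run first so a stray divisor is rejected regardless of the legacy flag, and the batch-size check runs last, which mirrors the note that it is enforced after `global_batch_size` is finalized.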
Supersedes #2060.