From 1b70bf31ae5f5d80bcc2df982831c4e4b731a1cf Mon Sep 17 00:00:00 2001 From: EazyReal <8047065+EazyReal@users.noreply.github.com> Date: Mon, 15 Jun 2026 20:43:20 +0000 Subject: [PATCH] feat(rl): composable off-policy importance-sampling correction Expose the current grad-carrying log-probs to the policy-loss TIS hook as `cur_log_probs`, and add `off_policy_is_function` (in ppo_utils, next to compute_policy_loss/compute_cispo_loss) -- a truncated-IS correction between the *current* policy and the *actual rollout generator*: the (detached) weight is `clip(pi_theta / pi_rollout)` against the real rollout logprob, so one weight corrects both the train/inference mismatch and async (multi-version) staleness. The existing TIS hook only had pi_theta_old / pi_rollout, which equals this only in the single-update-per-rollout limit. On a plain REINFORCE base (`--advantage-estimator reinforce`) this reproduces the CISPO surrogate expressed as a correction rather than the dedicated `compute_cispo_loss` estimator. Existing corrections ignore the new kwarg via **kwargs. Co-Authored-By: Claude Opus 4.8 (1M context) --- slime/backends/megatron_utils/loss.py | 6 ++ slime/utils/ppo_utils.py | 32 ++++++++++ tests/test_off_policy_is.py | 84 +++++++++++++++++++++++++++ 3 files changed, 122 insertions(+) create mode 100644 tests/test_off_policy_is.py diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index a456939e74..d4816aa1fa 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -846,6 +846,9 @@ def policy_loss_function( ) log_probs = log_probs_and_entropy["log_probs"] + # Current pi_theta (grad-carrying), captured before the cat below; passed to TIS-hook + # corrections so they can form pi_theta / pi_rollout (see off_policy_is_function). + cur_log_probs_list = log_probs if not args.use_rollout_logprobs and not old_log_probs: old_log_probs = [log_prob.detach() for log_prob in log_probs] train_log_probs_for_tis = batch.get("log_probs") @@ -919,9 +922,12 @@ def policy_loss_function( assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" ois = (-ppo_kl).exp() + # Pass cur_log_probs (current pi_theta, grad-carrying) so corrections can form + # pi_theta/pi_rollout, not just the frozen pi_theta_old/pi_rollout of vanilla TIS. tis_kwargs = { "args": args, "pg_loss": pg_loss, + "cur_log_probs": cur_log_probs_list, "train_log_probs": train_log_probs_for_tis, "rollout_log_probs": batch["rollout_log_probs"], "loss_masks": batch["loss_masks"], diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py index 327dec2de6..1c0befa1b5 100644 --- a/slime/utils/ppo_utils.py +++ b/slime/utils/ppo_utils.py @@ -2,6 +2,7 @@ # and https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ppo_utils/experience_maker.py from argparse import Namespace +from typing import Any import torch import torch.distributed as dist @@ -171,6 +172,37 @@ def compute_cispo_loss( return pg_losses, clipfrac +def off_policy_is_function( + args: Namespace, + *, + pg_loss: torch.Tensor, + cur_log_probs: list[torch.Tensor], + rollout_log_probs: list[torch.Tensor], + loss_masks: list[torch.Tensor], + **kwargs: Any, +) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]: + """Off-policy truncated IS (TIS hook): like ``vanilla_tis_function`` but with the + *current* policy in the numerator instead of the old recompute, so the (detached) + weight is ``clip(pi_theta / pi_rollout)`` against the actual rollout logprob -- one + weight that corrects both the train/inference mismatch and async (multi-version) + staleness. Composed with ``--advantage-estimator reinforce`` it is the CISPO surrogate + (https://arxiv.org/abs/2506.13585), expressed as a correction rather than the dedicated + ``compute_cispo_loss`` estimator; ``--eps-clip 1.0`` gives canonical single-sided clipping. + Same ``(pg_loss, loss_masks, metrics)`` contract; ``loss_masks`` unchanged. + """ + cur = torch.cat([lp.detach() for lp in cur_log_probs], dim=0) + rollout = torch.cat(rollout_log_probs, dim=0) + ratio = torch.exp(cur - rollout) + is_weights = torch.clamp(ratio, min=1.0 - args.eps_clip, max=1.0 + args.eps_clip_high) + is_clipfrac = (is_weights != ratio).float() + metrics = { + "is_weight": ratio.clone().detach(), + "is_clipfrac": is_clipfrac.clone().detach(), + } + pg_loss = pg_loss * is_weights + return pg_loss, loss_masks, metrics + + def compute_log_probs(logits: torch.Tensor, tokens: torch.Tensor, process_group: dist.ProcessGroup | None): # TODO: when megatron is not installed, fall back to naive implementation from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy diff --git a/tests/test_off_policy_is.py b/tests/test_off_policy_is.py new file mode 100644 index 0000000000..e249dd47df --- /dev/null +++ b/tests/test_off_policy_is.py @@ -0,0 +1,84 @@ +"""Unit tests for the ``off_policy_is_function`` importance-sampling correction +(slime/utils/ppo_utils.py). + +It is truncated IS between the *current* policy and the *actual rollout generator*: +the (detached) weight is ``clip(pi_theta / pi_rollout)``. On a plain REINFORCE base +``-A * log pi`` it reproduces the CISPO surrogate (https://arxiv.org/abs/2506.13585). + +Pure-torch (no megatron), like tests/test_chunked_gae.py; runs on CPU. NUM_GPUS = 0 +selects the CPU runner in the changed-test CI matrix; the __main__ block lets CI run +it as a script. (The hook wiring in loss.py that supplies cur_log_probs imports +megatron and is exercised in the GPU CI suites.) +""" + +from argparse import Namespace + +import torch + +from slime.utils.ppo_utils import off_policy_is_function + +# CPU-only test: selects the 0-GPU runner in the changed-test CI matrix. +NUM_GPUS = 0 + + +def test_off_policy_is_function_clips_weight_and_passes_masks_through(): + # ratio = exp(cur - rollout): ln(2) -> 2 -> clamp 1.2; ln(0.5) -> 0.5 -> 0.8; 0 -> 1.0 + cur = torch.tensor([1.0, 1.0, 1.0]) + rollout = cur - torch.tensor([2.0, 0.5, 1.0]).log() + pg_loss = torch.tensor([1.0, 1.0, 1.0]) + loss_masks = [torch.ones(3)] + args = Namespace(eps_clip=0.2, eps_clip_high=0.2) + + out_loss, out_masks, metrics = off_policy_is_function( + args, pg_loss=pg_loss, cur_log_probs=[cur], rollout_log_probs=[rollout], loss_masks=loss_masks + ) + + expected_w = torch.tensor([1.2, 0.8, 1.0]) + assert torch.allclose(out_loss, pg_loss * expected_w) + assert torch.allclose(metrics["is_clipfrac"], torch.tensor([1.0, 1.0, 0.0])) + assert out_masks is loss_masks # no rejection-sampling masking + + +def test_off_policy_is_on_reinforce_base_equals_cispo_surrogate(): + # On a plain REINFORCE base (-A * log pi), off_policy_is_function reproduces the + # CISPO surrogate exactly, with gradient flowing ONLY through log_probs. + advantages = torch.tensor([2.0, -1.0, 0.5, 1.5]) + rollout = torch.tensor([-0.5, -0.2, -0.9, -0.3]) # behavior policy mu (frozen) + log_probs = torch.tensor([-0.1, -0.4, -0.3, -0.8], requires_grad=True) + args = Namespace(eps_clip=0.2, eps_clip_high=0.2) + + pg_loss = -advantages * log_probs # plain REINFORCE base + pg_loss, _, _ = off_policy_is_function( + args, pg_loss=pg_loss, cur_log_probs=[log_probs], rollout_log_probs=[rollout], loss_masks=[torch.ones(4)] + ) + + ratio = torch.exp(log_probs.detach() - rollout) # pi_theta / pi_rollout + clipped = ratio.clamp(1 - args.eps_clip, 1 + args.eps_clip_high) + assert torch.allclose(pg_loss, -clipped * advantages * log_probs.detach()) + + pg_loss.sum().backward() + # d/d log_probs [ -clip(ratio).detach() * A * log_probs ] = -clip(ratio) * A + assert torch.allclose(log_probs.grad, -clipped * advantages) + + +def test_off_policy_is_single_sided_when_eps_clip_one(): + # Canonical CISPO: eps_clip=1.0 disables the lower bound (ratio >= 0 never clipped low). + cur = torch.tensor([0.0, 0.0]) + rollout = cur - torch.tensor([10.0, 0.01]).log() # ratios 10.0 (high) and ~0.01 (very low) + pg_loss = torch.tensor([1.0, 1.0]) + args = Namespace(eps_clip=1.0, eps_clip_high=4.0) + + _, _, metrics = off_policy_is_function( + args, pg_loss=pg_loss, cur_log_probs=[cur], rollout_log_probs=[rollout], loss_masks=[torch.ones(2)] + ) + + # high ratio 10.0 > 1+eps_clip_high=5.0 clipped; low ratio ~0.01 >= 1-eps_clip=0.0 NOT clipped + assert torch.allclose(metrics["is_clipfrac"], torch.tensor([1.0, 0.0])) + + +if __name__ == "__main__": + for name, fn in sorted(globals().items()): + if name.startswith("test_") and callable(fn): + fn() + print(f"PASSED {name}") + print("OK")