Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions slime/backends/megatron_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,9 @@ def policy_loss_function(
)

log_probs = log_probs_and_entropy["log_probs"]
# Current pi_theta (grad-carrying), captured before the cat below; passed to TIS-hook
# corrections so they can form pi_theta / pi_rollout (see off_policy_is_function).
cur_log_probs_list = log_probs
if not args.use_rollout_logprobs and not old_log_probs:
old_log_probs = [log_prob.detach() for log_prob in log_probs]
train_log_probs_for_tis = batch.get("log_probs")
Expand Down Expand Up @@ -919,9 +922,12 @@ def policy_loss_function(
assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS"

ois = (-ppo_kl).exp()
# Pass cur_log_probs (current pi_theta, grad-carrying) so corrections can form
# pi_theta/pi_rollout, not just the frozen pi_theta_old/pi_rollout of vanilla TIS.
tis_kwargs = {
"args": args,
"pg_loss": pg_loss,
"cur_log_probs": cur_log_probs_list,
"train_log_probs": train_log_probs_for_tis,
"rollout_log_probs": batch["rollout_log_probs"],
"loss_masks": batch["loss_masks"],
Expand Down
32 changes: 32 additions & 0 deletions slime/utils/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# and https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ppo_utils/experience_maker.py

from argparse import Namespace
from typing import Any

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -171,6 +172,37 @@ def compute_cispo_loss(
return pg_losses, clipfrac


def off_policy_is_function(
    args: Namespace,
    *,
    pg_loss: torch.Tensor,
    cur_log_probs: list[torch.Tensor],
    rollout_log_probs: list[torch.Tensor],
    loss_masks: list[torch.Tensor],
    **kwargs: Any,
) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]:
    """Off-policy truncated IS (TIS hook) using the *current* policy in the numerator.

    Like ``vanilla_tis_function`` but with the current policy instead of the old
    recompute, so the (detached) weight is ``clip(pi_theta / pi_rollout)`` against the
    actual rollout logprob -- one weight that corrects both the train/inference mismatch
    and async (multi-version) staleness. Composed with ``--advantage-estimator reinforce``
    it is the CISPO surrogate (https://arxiv.org/abs/2506.13585), expressed as a
    correction rather than the dedicated ``compute_cispo_loss`` estimator;
    ``--eps-clip 1.0`` gives canonical single-sided clipping.

    Args:
        args: Namespace providing ``eps_clip`` and ``eps_clip_high``; the weight is
            clamped to ``[1 - eps_clip, 1 + eps_clip_high]``.
        pg_loss: Per-token policy-gradient loss to be re-weighted. Must be
            broadcast-compatible with the concatenated log-prob tensors.
        cur_log_probs: Per-sample log-probs under the current policy. They are
            detached here, so no gradient flows through the weight.
        rollout_log_probs: Per-sample log-probs recorded by the rollout generator.
            Also detached here, so the weight is a constant even if a caller passes
            grad-carrying tensors.
        loss_masks: Returned unchanged (no rejection-sampling masking).
        **kwargs: Extra hook arguments; accepted and ignored.

    Returns:
        ``(pg_loss, loss_masks, metrics)`` where ``pg_loss`` is the input scaled by the
        clipped weight and ``metrics`` holds ``"is_weight"`` (the unclipped ratio) and
        ``"is_clipfrac"`` (1.0 where clipping changed the ratio, else 0.0).
    """
    # Detach BOTH operands: the weight must act as a constant coefficient on pg_loss,
    # so the only gradient path is the one already present inside pg_loss itself.
    cur = torch.cat([lp.detach() for lp in cur_log_probs], dim=0)
    rollout = torch.cat([lp.detach() for lp in rollout_log_probs], dim=0)

    ratio = torch.exp(cur - rollout)
    is_weights = torch.clamp(ratio, min=1.0 - args.eps_clip, max=1.0 + args.eps_clip_high)
    # A token counts as clipped exactly when clamping altered its ratio.
    is_clipfrac = (is_weights != ratio).float()

    metrics = {
        "is_weight": ratio.clone(),
        "is_clipfrac": is_clipfrac.clone(),
    }
    pg_loss = pg_loss * is_weights
    return pg_loss, loss_masks, metrics


def compute_log_probs(logits: torch.Tensor, tokens: torch.Tensor, process_group: dist.ProcessGroup | None):
# TODO: when megatron is not installed, fall back to naive implementation
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
Expand Down
84 changes: 84 additions & 0 deletions tests/test_off_policy_is.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Unit tests for the ``off_policy_is_function`` importance-sampling correction
(slime/utils/ppo_utils.py).

It is truncated IS between the *current* policy and the *actual rollout generator*:
the (detached) weight is ``clip(pi_theta / pi_rollout)``. On a plain REINFORCE base
``-A * log pi`` it reproduces the CISPO surrogate (https://arxiv.org/abs/2506.13585).

Pure-torch (no megatron), like tests/test_chunked_gae.py; runs on CPU. NUM_GPUS = 0
selects the CPU runner in the changed-test CI matrix; the __main__ block lets CI run
it as a script. (The hook wiring in loss.py that supplies cur_log_probs imports
megatron and is exercised in the GPU CI suites.)
"""

from argparse import Namespace

import torch

from slime.utils.ppo_utils import off_policy_is_function

# This test is pure-torch and runs on CPU; NUM_GPUS = 0 selects the 0-GPU (CPU)
# runner in the changed-test CI matrix.
NUM_GPUS = 0


def test_off_policy_is_function_clips_weight_and_passes_masks_through():
    """Weights are clamped to [0.8, 1.2] and the mask list is returned as-is."""
    clip_args = Namespace(eps_clip=0.2, eps_clip_high=0.2)

    # Build rollout log-probs so that exp(cur - rollout) equals the target ratios:
    # 2.0 is clamped to 1.2, 0.5 is clamped to 0.8, and 1.0 is left alone.
    target_ratios = torch.tensor([2.0, 0.5, 1.0])
    current_lp = torch.tensor([1.0, 1.0, 1.0])
    rollout_lp = current_lp - target_ratios.log()
    base_loss = torch.tensor([1.0, 1.0, 1.0])
    masks_in = [torch.ones(3)]

    weighted_loss, masks_out, stats = off_policy_is_function(
        clip_args,
        pg_loss=base_loss,
        cur_log_probs=[current_lp],
        rollout_log_probs=[rollout_lp],
        loss_masks=masks_in,
    )

    clipped_weights = torch.tensor([1.2, 0.8, 1.0])
    assert torch.allclose(weighted_loss, base_loss * clipped_weights)
    assert torch.allclose(stats["is_clipfrac"], torch.tensor([1.0, 1.0, 0.0]))
    # The very same list object comes back: no rejection-sampling masking.
    assert masks_out is masks_in


def test_off_policy_is_on_reinforce_base_equals_cispo_surrogate():
    """On a REINFORCE base (-A * log pi) the correction yields the CISPO surrogate.

    Also checks that gradient reaches ``log_probs`` only through the base loss, i.e.
    the clipped weight behaves as a constant coefficient.
    """
    clip_args = Namespace(eps_clip=0.2, eps_clip_high=0.2)
    adv = torch.tensor([2.0, -1.0, 0.5, 1.5])
    behavior_lp = torch.tensor([-0.5, -0.2, -0.9, -0.3])  # frozen behavior policy mu
    policy_lp = torch.tensor([-0.1, -0.4, -0.3, -0.8], requires_grad=True)

    reinforce_loss = -adv * policy_lp  # plain REINFORCE base
    corrected_loss, _, _ = off_policy_is_function(
        clip_args,
        pg_loss=reinforce_loss,
        cur_log_probs=[policy_lp],
        rollout_log_probs=[behavior_lp],
        loss_masks=[torch.ones(4)],
    )

    # Reference weight: clip(pi_theta / pi_rollout), computed without grad.
    lower = 1 - clip_args.eps_clip
    upper = 1 + clip_args.eps_clip_high
    reference_weight = torch.exp(policy_lp.detach() - behavior_lp).clamp(lower, upper)
    assert torch.allclose(corrected_loss, -reference_weight * adv * policy_lp.detach())

    corrected_loss.sum().backward()
    # d/d log_probs [ -clip(ratio).detach() * A * log_probs ] = -clip(ratio) * A
    assert torch.allclose(policy_lp.grad, -reference_weight * adv)


def test_off_policy_is_single_sided_when_eps_clip_one():
    """With eps_clip=1.0 the lower bound is 0.0, so only the upper side can clip."""
    clip_args = Namespace(eps_clip=1.0, eps_clip_high=4.0)

    # Ratios of 10.0 (far above the upper bound) and ~0.01 (very small, but >= 0.0).
    target_ratios = torch.tensor([10.0, 0.01])
    current_lp = torch.tensor([0.0, 0.0])
    rollout_lp = current_lp - target_ratios.log()
    base_loss = torch.tensor([1.0, 1.0])

    result = off_policy_is_function(
        clip_args,
        pg_loss=base_loss,
        cur_log_probs=[current_lp],
        rollout_log_probs=[rollout_lp],
        loss_masks=[torch.ones(2)],
    )
    stats = result[2]

    # 10.0 > 1 + eps_clip_high = 5.0 is clipped; ~0.01 >= 1 - eps_clip = 0.0 is not.
    assert torch.allclose(stats["is_clipfrac"], torch.tensor([1.0, 0.0]))


if __name__ == "__main__":
    # Script entry point so CI can run this file directly. sorted() snapshots the
    # module namespace before the loop starts binding its own (global) loop variables.
    for test_name, candidate in sorted(globals().items()):
        if not test_name.startswith("test_"):
            continue
        if not callable(candidate):
            continue
        candidate()
        print(f"PASSED {test_name}")
    print("OK")
Loading