From ca09065281fef23728146e58faa44c3429d4faae Mon Sep 17 00:00:00 2001 From: Haozhe Zhang Date: Thu, 11 Jun 2026 23:29:47 -0700 Subject: [PATCH 1/2] [megatron] don't re-assert no_sync_func every step with overlap_grad_reduce `train()` sets up `config.no_sync_func` on every step, but `config` is the model config and persists across steps. With `--overlap-grad-reduce` the first step sets it, then the second step trips `assert config.no_sync_func is None` and crashes. Guard the setup with `if config.no_sync_func is None:` so the sync funcs are set once (they are constant, so skipping later steps is a no-op). Fixes #1779 --- slime/backends/megatron_utils/model.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index db6020a94d..dc3dbc88df 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -657,17 +657,17 @@ def train( config.grad_scale_func = optimizer.scale_loss config.timers = None if isinstance(model[0], DDP) and args.overlap_grad_reduce: - assert config.no_sync_func is None, ( - "When overlap_grad_reduce is True, config.no_sync_func must be None; " - "a custom no_sync_func is not supported when overlapping grad-reduce" - ) - config.no_sync_func = [model_chunk.no_sync for model_chunk in model] - if len(model) == 1: - config.no_sync_func = config.no_sync_func[0] - if args.align_grad_reduce: - config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model] + # `config` is the model config and persists across steps, so set the sync + # funcs only once — re-running trips `config.no_sync_func is None` on the + # second step (#1779). The funcs are constant, so skipping later is a no-op. + if config.no_sync_func is None: + config.no_sync_func = [model_chunk.no_sync for model_chunk in model] if len(model) == 1: - config.grad_sync_func = config.grad_sync_func[0] + config.no_sync_func = config.no_sync_func[0] + if args.align_grad_reduce: + config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model] + if len(model) == 1: + config.grad_sync_func = config.grad_sync_func[0] if args.overlap_param_gather and args.align_param_gather: config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model] if len(model) == 1: From f34341c13ae9ed5cd221340ad727a2267cc796a2 Mon Sep 17 00:00:00 2001 From: Haozhe Zhang Date: Sat, 20 Jun 2026 12:23:17 -0700 Subject: [PATCH 2/2] [megatron] extract overlap_grad_reduce setup into a helper + add idempotency test --- slime/backends/megatron_utils/grad_reduce.py | 35 ++++++++ slime/backends/megatron_utils/model.py | 14 +--- tests/test_overlap_grad_reduce.py | 85 ++++++++++++++++++++ 3 files changed, 122 insertions(+), 12 deletions(-) create mode 100644 slime/backends/megatron_utils/grad_reduce.py create mode 100644 tests/test_overlap_grad_reduce.py diff --git a/slime/backends/megatron_utils/grad_reduce.py b/slime/backends/megatron_utils/grad_reduce.py new file mode 100644 index 0000000000..cd9fa5d6cf --- /dev/null +++ b/slime/backends/megatron_utils/grad_reduce.py @@ -0,0 +1,35 @@ +"""Per-step grad-reduce sync-func setup for Megatron training. + +Split out of ``train()`` so the idempotency guard can be unit-tested without a +full Megatron training step (see ``tests/test_overlap_grad_reduce.py``). +""" + +from megatron.core.distributed import DistributedDataParallel as DDP + + +def configure_overlap_grad_reduce(model, config, args): + """Set ``no_sync_func`` / ``grad_sync_func`` on ``config`` for + ``overlap_grad_reduce`` -- exactly once. + + ``config`` is the model config from ``get_model_config(model[0])`` and + persists across ``train()`` calls. The sync funcs are constant, so they are + set only on the first call; re-setting them on a later step would trip the + ``no_sync_func is None`` invariant and crash (#1779). Skipping when they are + already set is a no-op. + + Args: + model: Sequence of DDP-wrapped model chunks. + config: The persistent model ``TransformerConfig``. + args: Megatron args; reads ``overlap_grad_reduce`` and ``align_grad_reduce``. + """ + if not (isinstance(model[0], DDP) and args.overlap_grad_reduce): + return + if config.no_sync_func is not None: + return + config.no_sync_func = [model_chunk.no_sync for model_chunk in model] + if len(model) == 1: + config.no_sync_func = config.no_sync_func[0] + if args.align_grad_reduce: + config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model] + if len(model) == 1: + config.grad_sync_func = config.grad_sync_func[0] diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index dc3dbc88df..46331c8e71 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -33,6 +33,7 @@ from .checkpoint import load_checkpoint, save_checkpoint from .cp_utils import reduce_train_step_metrics from .data import DataIterator, get_batch +from .grad_reduce import configure_overlap_grad_reduce from .loss import loss_function from .model_provider import get_model_provider_func @@ -656,18 +657,7 @@ def train( config = get_model_config(model[0]) config.grad_scale_func = optimizer.scale_loss config.timers = None - if isinstance(model[0], DDP) and args.overlap_grad_reduce: - # `config` is the model config and persists across steps, so set the sync - # funcs only once — re-running trips `config.no_sync_func is None` on the - # second step (#1779). The funcs are constant, so skipping later is a no-op. - if config.no_sync_func is None: - config.no_sync_func = [model_chunk.no_sync for model_chunk in model] - if len(model) == 1: - config.no_sync_func = config.no_sync_func[0] - if args.align_grad_reduce: - config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model] - if len(model) == 1: - config.grad_sync_func = config.grad_sync_func[0] + configure_overlap_grad_reduce(model, config, args) if args.overlap_param_gather and args.align_param_gather: config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model] if len(model) == 1: diff --git a/tests/test_overlap_grad_reduce.py b/tests/test_overlap_grad_reduce.py new file mode 100644 index 0000000000..cdaca02ac0 --- /dev/null +++ b/tests/test_overlap_grad_reduce.py @@ -0,0 +1,85 @@ +"""CPU unit test for ``configure_overlap_grad_reduce`` idempotency (#2066, #1779). + +``config`` persists across ``train()`` calls, so the constant sync funcs must be +set only once -- re-running on a later step previously tripped an +``assert config.no_sync_func is None`` and crashed. The CPU CI image has no +megatron, so we stub the single symbol the helper imports +(``megatron.core.distributed.DistributedDataParallel``) and load the helper file +directly, mirroring ``test_megatron_argument_validation.py``. +""" + +import importlib.util +import sys +import types +from pathlib import Path + +NUM_GPUS = 0 + + +def _load_helper(): + dist_mod = types.ModuleType("megatron.core.distributed") + + class DistributedDataParallel: # megatron DDP, stubbed + pass + + dist_mod.DistributedDataParallel = DistributedDataParallel + core_mod = types.ModuleType("megatron.core") + core_mod.distributed = dist_mod + megatron_mod = types.ModuleType("megatron") + megatron_mod.core = core_mod + sys.modules["megatron"] = megatron_mod + sys.modules["megatron.core"] = core_mod + sys.modules["megatron.core.distributed"] = dist_mod + + path = Path(__file__).resolve().parents[1] / "slime" / "backends" / "megatron_utils" / "grad_reduce.py" + spec = importlib.util.spec_from_file_location("slime_grad_reduce_under_test", path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module, DistributedDataParallel + + +def _make_model(DDP): + class Chunk(DDP): + def no_sync(self): ... + + def start_grad_sync(self): ... + + return [Chunk()] + + +def test_set_once_then_idempotent(): + helper, DDP = _load_helper() + model = _make_model(DDP) + config = types.SimpleNamespace(no_sync_func=None, grad_sync_func=None) + args = types.SimpleNamespace(overlap_grad_reduce=True, align_grad_reduce=True) + + helper.configure_overlap_grad_reduce(model, config, args) # step 1 + assert config.no_sync_func is not None + assert config.grad_sync_func is not None + no_sync, grad_sync = config.no_sync_func, config.grad_sync_func + + # step 2 must be a no-op, not crash (the #2066 regression) + helper.configure_overlap_grad_reduce(model, config, args) + assert config.no_sync_func is no_sync + assert config.grad_sync_func is grad_sync + + +def test_skipped_when_disabled_or_not_ddp(): + helper, DDP = _load_helper() + config = types.SimpleNamespace(no_sync_func=None, grad_sync_func=None) + + # overlap_grad_reduce off -> untouched + helper.configure_overlap_grad_reduce( + _make_model(DDP), + config, + types.SimpleNamespace(overlap_grad_reduce=False, align_grad_reduce=True), + ) + assert config.no_sync_func is None + + # model is not DDP-wrapped -> untouched + helper.configure_overlap_grad_reduce( + [object()], + config, + types.SimpleNamespace(overlap_grad_reduce=True, align_grad_reduce=True), + ) + assert config.no_sync_func is None