Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions slime/backends/megatron_utils/grad_reduce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Per-step grad-reduce sync-func setup for Megatron training.

Split out of ``train()`` so the idempotency guard can be unit-tested without a
full Megatron training step (see ``tests/test_overlap_grad_reduce.py``).
"""

from megatron.core.distributed import DistributedDataParallel as DDP


def configure_overlap_grad_reduce(model, config, args):
    """Install the overlap-grad-reduce sync hooks on ``config`` at most once.

    ``config`` is the model config from ``get_model_config(model[0])`` and
    outlives a single ``train()`` call, so this function may be invoked with
    a config it has already populated. The hooks never change between steps,
    hence a config whose ``no_sync_func`` is already set is left untouched
    instead of being re-populated (re-setting used to trip a
    ``no_sync_func is None`` invariant and crash, see #1779).

    The function returns without touching ``config`` when either

    * ``model[0]`` is not a ``DDP`` instance or ``args.overlap_grad_reduce``
      is falsy, or
    * ``config.no_sync_func`` is already non-``None``.

    Otherwise ``config.no_sync_func`` receives the chunks' ``no_sync``
    attributes and, if ``args.align_grad_reduce`` is truthy,
    ``config.grad_sync_func`` receives their ``start_grad_sync`` attributes.
    With exactly one chunk the bare callable is stored; with several, a list
    in chunk order.

    Args:
        model: Sequence of DDP-wrapped model chunks.
        config: The persistent model config; mutated in place.
        args: Megatron args; reads ``overlap_grad_reduce`` and
            ``align_grad_reduce``.
    """
    overlap_active = isinstance(model[0], DDP) and args.overlap_grad_reduce
    if not overlap_active:
        return

    # Already configured on an earlier step: nothing to do.
    if config.no_sync_func is not None:
        return

    no_sync_hooks = [chunk.no_sync for chunk in model]
    # A single chunk is exposed as the callable itself, not a 1-element list.
    config.no_sync_func = no_sync_hooks[0] if len(model) == 1 else no_sync_hooks

    if args.align_grad_reduce:
        grad_sync_hooks = [chunk.start_grad_sync for chunk in model]
        config.grad_sync_func = grad_sync_hooks[0] if len(model) == 1 else grad_sync_hooks
14 changes: 2 additions & 12 deletions slime/backends/megatron_utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from .checkpoint import load_checkpoint, save_checkpoint
from .cp_utils import reduce_train_step_metrics
from .data import DataIterator, get_batch
from .grad_reduce import configure_overlap_grad_reduce
from .loss import ROLLOUT_TOP_P_TOKEN_KEYS, get_rollout_top_p_logprob_kwargs, loss_function
from .model_provider import get_model_provider_func

Expand Down Expand Up @@ -669,18 +670,7 @@ def train(
config = get_model_config(model[0])
config.grad_scale_func = optimizer.scale_loss
config.timers = None
if isinstance(model[0], DDP) and args.overlap_grad_reduce:
assert config.no_sync_func is None, (
"When overlap_grad_reduce is True, config.no_sync_func must be None; "
"a custom no_sync_func is not supported when overlapping grad-reduce"
)
config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
if len(model) == 1:
config.no_sync_func = config.no_sync_func[0]
if args.align_grad_reduce:
config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model]
if len(model) == 1:
config.grad_sync_func = config.grad_sync_func[0]
configure_overlap_grad_reduce(model, config, args)
if args.overlap_param_gather and args.align_param_gather:
config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model]
if len(model) == 1:
Expand Down
85 changes: 85 additions & 0 deletions tests/test_overlap_grad_reduce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""CPU unit test for ``configure_overlap_grad_reduce`` idempotency (#2066, #1779).

``config`` persists across ``train()`` calls, so the constant sync funcs must be
set only once -- re-running on a later step previously tripped an
``assert config.no_sync_func is None`` and crashed. The CPU CI image has no
megatron, so we stub the single symbol the helper imports
(``megatron.core.distributed.DistributedDataParallel``) and load the helper file
directly, mirroring ``test_megatron_argument_validation.py``.
"""

import importlib.util
import sys
import types
from pathlib import Path

# This test needs no GPU (see the module docstring: it targets the CPU CI image).
# NOTE(review): presumably read by the test runner to schedule this file on a
# CPU-only worker -- the consumer is not visible from this file; confirm.
NUM_GPUS = 0


def _load_helper():
    """Load ``grad_reduce.py`` by path with a stubbed megatron DDP in place.

    Registers stand-in ``megatron``, ``megatron.core`` and
    ``megatron.core.distributed`` modules in ``sys.modules`` (overwriting any
    existing entries) so that the helper's single megatron import resolves,
    then executes the helper file directly under the module name
    ``slime_grad_reduce_under_test``.

    Returns:
        A ``(module, DistributedDataParallel)`` tuple: the freshly executed
        helper module and the stub class it imported as ``DDP``.
    """

    class DistributedDataParallel:  # megatron DDP, stubbed
        pass

    # Build the stub package chain from the root down to the leaf.
    megatron_mod = types.ModuleType("megatron")
    core_mod = types.ModuleType("megatron.core")
    dist_mod = types.ModuleType("megatron.core.distributed")
    megatron_mod.core = core_mod
    core_mod.distributed = dist_mod
    dist_mod.DistributedDataParallel = DistributedDataParallel

    for stub in (megatron_mod, core_mod, dist_mod):
        sys.modules[stub.__name__] = stub

    # The helper lives two directories above this test file's folder root.
    repo_root = Path(__file__).resolve().parents[1]
    helper_path = repo_root / "slime" / "backends" / "megatron_utils" / "grad_reduce.py"

    helper_spec = importlib.util.spec_from_file_location("slime_grad_reduce_under_test", helper_path)
    helper_module = importlib.util.module_from_spec(helper_spec)
    helper_spec.loader.exec_module(helper_module)
    return helper_module, DistributedDataParallel


def _make_model(DDP):
class Chunk(DDP):
def no_sync(self): ...

def start_grad_sync(self): ...

return [Chunk()]


def test_set_once_then_idempotent():
    """First call installs both sync funcs; a second call leaves them as-is."""
    helper, DDP = _load_helper()
    configure = helper.configure_overlap_grad_reduce
    model = _make_model(DDP)
    config = types.SimpleNamespace(no_sync_func=None, grad_sync_func=None)
    args = types.SimpleNamespace(overlap_grad_reduce=True, align_grad_reduce=True)

    # step 1: both hooks get populated
    configure(model, config, args)
    assert config.no_sync_func is not None
    assert config.grad_sync_func is not None
    first_no_sync = config.no_sync_func
    first_grad_sync = config.grad_sync_func

    # step 2 must be a no-op, not crash (the #2066 regression)
    configure(model, config, args)
    assert config.no_sync_func is first_no_sync
    assert config.grad_sync_func is first_grad_sync


def test_skipped_when_disabled_or_not_ddp():
    """``config`` stays untouched when overlap is off or the model is not DDP."""
    helper, DDP = _load_helper()
    configure = helper.configure_overlap_grad_reduce
    config = types.SimpleNamespace(no_sync_func=None, grad_sync_func=None)

    # overlap_grad_reduce off -> untouched
    overlap_off = types.SimpleNamespace(overlap_grad_reduce=False, align_grad_reduce=True)
    configure(_make_model(DDP), config, overlap_off)
    assert config.no_sync_func is None

    # model is not DDP-wrapped -> untouched
    overlap_on = types.SimpleNamespace(overlap_grad_reduce=True, align_grad_reduce=True)
    plain_model = [object()]
    configure(plain_model, config, overlap_on)
    assert config.no_sync_func is None
Loading