Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion nemo_automodel/components/distributed/magi_attn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ def get_active_cp_group() -> Optional["dist.ProcessGroup"]:
# id(cp_group) -> (spec_fingerprint, built_key), so all layers in a step reuse one key.
_FLEX_KEY_CACHE: dict = {}

# Consumption guard for the per-step AttnMaskSpec. Both flags are module-level state,
# i.e. one pair for the whole process rather than one pair per model:
#   * _SPEC_ARMED    -- set by _set_attn_spec_on_attention() to whether the spec it just
#                       stamped on the attention modules is non-None.
#   * _SPEC_CONSUMED -- set to True by _mark_attn_spec_consumed() when a magi attention
#                       forward actually reads a spec; reset to False on every stamping.
# If a spec is armed but never consumed, the model silently used a non-magi attention
# (e.g. flash_attention_2 / sdpa) and the mask would be dropped. Detection is deferred:
# the *next* _set_attn_spec_on_attention() call raises RuntimeError instead of stamping.
_SPEC_ARMED: bool = False
_SPEC_CONSUMED: bool = False


@dataclass
class AttnMaskSpec:
Expand Down Expand Up @@ -358,6 +365,7 @@ def magi_attn_func(q, k, v, **call_kwargs):
if spec is None:
spec = get_active_attn_spec()
if spec is not None:
_mark_attn_spec_consumed()
key = _flex_key_for(
cp_group,
spec,
Expand Down Expand Up @@ -435,7 +443,23 @@ def magi_attention_forward(
**kwargs,
)

if getattr(module, "_magi_self_key", False):
attn_spec = getattr(module, "_magi_attn_spec", None)
if attn_spec is not None:
# Arbitrary AttnSlice mask (packing / sliding-window / prefix-tree) carried on
# the module by :func:`set_attn_spec_on_attention`. The mask rides on ``module``
# (already in the HF attention signature) rather than a process-global, so the
# HF-registered forward honors it exactly as the custom-model attn_func does --
# giving attn_implementation="magi" the same mask support as backend.attn="magi".
# cp_size==1 (no dispatch); the flex key encodes the full mask.
magi_attn_key = _flex_key_for(
cp_group,
attn_spec,
num_heads_q=query.shape[1],
num_heads_kv=key.shape[1],
head_dim=query.shape[3],
)
_mark_attn_spec_consumed()
elif getattr(module, "_magi_self_key", False):
# VLM (cp_size==1) path: the post-image-merge LM sequence length is only
# known here (query dim 2: [b, nh, s, hd]) and may differ from input_ids
# length. Build a no-dispatch causal key matching the actual q length.
Expand Down Expand Up @@ -505,6 +529,41 @@ def _set_cp_group_on_attention(model, cp_group) -> None:
module.cp_group = cp_group


def _set_attn_spec_on_attention(model, spec: Optional["AttnMaskSpec"]) -> None:
    """Attach this step's :class:`AttnMaskSpec` to each attention sub-module of ``model``.

    The spec travels on the module object (``module._magi_attn_spec``), which the
    HF-registered magi forward already receives, instead of on a process-global. That
    keeps the mask scoped to this model's attention modules and leaves the HF attention
    signature untouched. Intended to be called once per step, with ``None`` to clear.

    A sub-module counts as "attention" when its class name contains ``"Attention"``.

    The call also drives the consumption guard: before stamping, it checks whether the
    previously stamped spec was armed but never marked consumed. If so the model ran a
    non-magi attention and dropped the mask, so the call fails loudly rather than
    letting training continue on the wrong mask. After stamping, the guard is re-armed
    (only for a non-``None`` spec) and the consumed flag is reset.

    Args:
        model: the (possibly FSDP-wrapped) model whose attention modules to stamp.
        spec: the mask spec for this step, or ``None`` for plain causal / no mask.

    Raises:
        RuntimeError: the previous spec was armed and no magi forward consumed it.
            Nothing is stamped and the guard state is left unchanged in that case.
    """
    global _SPEC_ARMED, _SPEC_CONSUMED

    # Judge the *previous* step first; bail out before mutating any module.
    previous_spec_unread = _SPEC_ARMED and not _SPEC_CONSUMED
    if previous_spec_unread:
        raise RuntimeError(
            "A magi AttnMaskSpec was activated for the previous step but no magi attention "
            "consumed it: the model is not routing attention through magi (it silently used "
            "its default attention, e.g. flash_attention_2). Set model.attn_implementation='magi' "
            "on an HF-style model, or model.backend.attn='magi' on a factory-based custom model."
        )

    # Name-based match, evaluated lazily while walking the module tree.
    attention_modules = (
        submodule for submodule in model.modules() if "Attention" in type(submodule).__name__
    )
    for attention in attention_modules:
        attention._magi_attn_spec = spec

    # Start a fresh guard cycle for the spec that was just stamped.
    _SPEC_CONSUMED = False
    _SPEC_ARMED = spec is not None


def _mark_attn_spec_consumed() -> None:
    """Flag the active :class:`AttnMaskSpec` as read by a magi attention forward.

    Sets the module-level ``_SPEC_CONSUMED`` flag to ``True``. The flag is checked
    (and then reset) by the next :func:`_set_attn_spec_on_attention` call.
    """
    global _SPEC_CONSUMED
    _SPEC_CONSUMED = True


def magi_prepare_batch( # pragma: no cover - requires GPU + magi_attention
model,
batch: dict,
Expand Down
93 changes: 93 additions & 0 deletions tests/unit_tests/distributed/test_magi_attn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,3 +376,96 @@ def test_make_magi_attn_func_matches_sdpa_causal():
mu.set_active_cp_group(None)
if dist.is_initialized():
dist.destroy_process_group()


@pytest.fixture(autouse=True)
def _reset_spec_guard():
    """Run each test against a disarmed, unconsumed guard and leave it that way.

    The consumption guard lives in module globals of ``mu``, so state set by one test
    would otherwise carry over into the next.
    """

    def _disarm():
        # Same two writes before and after the test body.
        mu._SPEC_ARMED = False
        mu._SPEC_CONSUMED = False

    _disarm()
    yield
    _disarm()


class TestAttnSpecOnModule:
    """``_set_attn_spec_on_attention``: which modules are stamped and how the guard reacts."""

    def test_stamps_only_attention_modules(self):
        """A non-None spec lands on attention modules only and arms the guard."""
        vlm = _VLM()
        mask_spec = AttnMaskSpec.causal(8)

        mu._set_attn_spec_on_attention(vlm, mask_spec)

        # The (Linear) vision tower is left alone; the language-backbone attention
        # carries the very same spec object.
        assert not hasattr(vlm.visual, "_magi_attn_spec")
        assert vlm.language_model.self_attn._magi_attn_spec is mask_spec
        # Armed by the non-None spec, but nothing has consumed it yet.
        assert mu._SPEC_CONSUMED is False
        assert mu._SPEC_ARMED is True

    def test_none_clears_and_does_not_arm(self):
        """Stamping ``None`` writes ``None`` on the module and leaves the guard disarmed."""
        lm = _LM()

        mu._set_attn_spec_on_attention(lm, None)

        assert mu._SPEC_ARMED is False
        assert lm.self_attn._magi_attn_spec is None

    def test_guard_raises_when_armed_spec_not_consumed(self):
        """Two stampings in a row with no consumption in between must fail loudly."""
        lm = _LM()
        # First stamping arms the guard; no forward ever consumes the spec.
        mu._set_attn_spec_on_attention(lm, AttnMaskSpec.causal(8))

        with pytest.raises(RuntimeError, match="no magi attention consumed it"):
            mu._set_attn_spec_on_attention(lm, AttnMaskSpec.causal(8))

    def test_guard_passes_when_consumed(self):
        """Once a magi forward has read the spec, the next step's stamping succeeds."""
        lm = _LM()
        mu._set_attn_spec_on_attention(lm, AttnMaskSpec.causal(8))
        # Stand-in for a magi forward reading the spec.
        mu._mark_attn_spec_consumed()

        # Next step: must not raise, and re-arms for the new spec.
        mu._set_attn_spec_on_attention(lm, AttnMaskSpec.causal(8))

        assert mu._SPEC_ARMED is True


class TestHFForwardKeySelection:
    """The registered "magi" HF forward picks its dist key from module state.

    magi_attention is stubbed so register_magi_attention() and the forward run on
    CPU; the key builders are spied to assert which branch is taken.
    """

    def _register_with_stub(self, monkeypatch):
        """Register the "magi" forward against a stubbed ``magi_attention`` package.

        Args:
            monkeypatch: the pytest ``monkeypatch`` fixture; every patch made here is
                undone by pytest when the calling test finishes.

        Returns:
            A ``(forward, picked)`` pair: the callable stored under
            ``ALL_ATTENTION_FUNCTIONS["magi"]`` and a dict that gains the key
            ``"flex"`` / ``"self_key"`` (value ``1``) whenever the corresponding key
            builder is invoked.
        """
        import sys
        import types

        api = types.ModuleType("magi_attention.api")
        api.calc_attn = lambda q, k, v, key, **kw: (q,)  # echo q; shape-preserving
        api.get_most_recent_key = lambda cp_group: "RECENT_KEY"
        monkeypatch.setitem(sys.modules, "magi_attention", types.ModuleType("magi_attention"))
        monkeypatch.setitem(sys.modules, "magi_attention.api", api)
        # Force re-registration so the forward is rebuilt for this test.
        monkeypatch.setattr(mu, "_MAGI_REGISTERED", False)

        picked = {}

        # Two parallel spies: each only records that its builder ran and returns None.
        # The stubbed calc_attn ignores the key, so no real key object is needed.
        def _spy_flex_key(*args, **kwargs):
            picked["flex"] = 1

        def _spy_self_key(*args, **kwargs):
            picked["self_key"] = 1

        monkeypatch.setattr(mu, "_flex_key_for", _spy_flex_key)
        monkeypatch.setattr(mu, "_self_key_for", _spy_self_key)

        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

        mu.register_magi_attention()
        return ALL_ATTENTION_FUNCTIONS["magi"], picked

    def test_module_spec_selects_flex_key_and_marks_consumed(self, monkeypatch):
        """A spec on the module routes through ``_flex_key_for`` and marks it consumed."""
        fwd, picked = self._register_with_stub(monkeypatch)
        module = nn.Module()
        module.cp_group = _FakeGroup(1)
        module._magi_attn_spec = AttnMaskSpec.causal(3)
        q = torch.randn(1, 2, 3, 4)  # [b, nh, s, hd]
        out, _ = fwd(module, q, q, q, scaling=0.5)
        assert picked.get("flex") == 1
        assert "self_key" not in picked
        assert mu._SPEC_CONSUMED is True  # forward consumed the spec on the module
        assert tuple(out.shape) == (1, 3, 8)  # [b, s, nh*hd]

    def test_no_spec_falls_back_to_dispatched_key(self, monkeypatch):
        """Without a spec or self-key flag, neither key builder runs and nothing is consumed."""
        fwd, picked = self._register_with_stub(monkeypatch)
        module = nn.Module()
        module.cp_group = _FakeGroup(2)
        # no _magi_attn_spec, no _magi_self_key -> get_most_recent_key path
        q = torch.randn(1, 2, 3, 4)
        fwd(module, q, q, q, scaling=0.5)
        assert "flex" not in picked
        assert "self_key" not in picked
        assert mu._SPEC_CONSUMED is False
Loading