From 69c19edcda99fc70590691bc21cd03cb8211ddce Mon Sep 17 00:00:00 2001 From: HuiyingLi Date: Wed, 17 Jun 2026 05:10:49 -0700 Subject: [PATCH 1/2] feat(magi): honor AttnMaskSpec on the HF attention backend The custom-model magi attn_func reads the active AttnMaskSpec (packing / sliding-window / prefix-tree masks via the flex key), but the HF-registered magi forward did not -- so attn_implementation="magi" silently dropped any non-causal mask while backend.attn="magi" applied it. Worse, a model whose attention dispatches on config._attn_implementation (e.g. the custom Qwen2) with backend.attn="magi" falls back to its default attention and drops the mask with no error. Bring the HF forward to parity: it now reads the per-step AttnMaskSpec stamped on the attention module by _set_attn_spec_on_attention() and builds the flex key from it (cp_size==1; the mask rides on `module`, already in the HF attention signature, so no process-global is read inside the interface). Add a consumption guard: a spec armed for a step but never read by a magi forward raises on the next step, turning the silent non-magi fallback into a loud error. CPU unit tests cover the stamping + guard; the GPU forward is exercised by the FFA parity tests. Co-Authored-By: Claude Opus 4.8 Signed-off-by: HuiyingLi --- .../components/distributed/magi_attn_utils.py | 61 ++++++++++++++++++- .../distributed/test_magi_attn_utils.py | 42 +++++++++++++ 2 files changed, 102 insertions(+), 1 deletion(-) diff --git a/nemo_automodel/components/distributed/magi_attn_utils.py b/nemo_automodel/components/distributed/magi_attn_utils.py index 3703dfb0f0..445e85502a 100644 --- a/nemo_automodel/components/distributed/magi_attn_utils.py +++ b/nemo_automodel/components/distributed/magi_attn_utils.py @@ -94,6 +94,13 @@ def get_active_cp_group() -> Optional["dist.ProcessGroup"]: # id(cp_group) -> (spec_fingerprint, built_key), so all layers in a step reuse one key. _FLEX_KEY_CACHE: dict = {} +# Consumption guard for the per-step AttnMaskSpec. A spec is "armed" when stamped on the +# attention modules and "consumed" when a magi attention forward actually reads it. If a +# spec is armed but never consumed, the model silently used a non-magi attention (e.g. +# flash_attention_2 / sdpa) and the mask would be dropped -- we raise instead. +_SPEC_ARMED: bool = False +_SPEC_CONSUMED: bool = False + @dataclass class AttnMaskSpec: @@ -358,6 +365,7 @@ def magi_attn_func(q, k, v, **call_kwargs): if spec is None: spec = get_active_attn_spec() if spec is not None: + _mark_attn_spec_consumed() key = _flex_key_for( cp_group, spec, @@ -435,7 +443,23 @@ def magi_attention_forward( **kwargs, ) - if getattr(module, "_magi_self_key", False): + attn_spec = getattr(module, "_magi_attn_spec", None) + if attn_spec is not None: + # Arbitrary AttnSlice mask (packing / sliding-window / prefix-tree) carried on + # the module by :func:`set_attn_spec_on_attention`. The mask rides on ``module`` + # (already in the HF attention signature) rather than a process-global, so the + # HF-registered forward honors it exactly as the custom-model attn_func does -- + # giving attn_implementation="magi" the same mask support as backend.attn="magi". + # cp_size==1 (no dispatch); the flex key encodes the full mask. + magi_attn_key = _flex_key_for( + cp_group, + attn_spec, + num_heads_q=query.shape[1], + num_heads_kv=key.shape[1], + head_dim=query.shape[3], + ) + _mark_attn_spec_consumed() + elif getattr(module, "_magi_self_key", False): # VLM (cp_size==1) path: the post-image-merge LM sequence length is only # known here (query dim 2: [b, nh, s, hd]) and may differ from input_ids # length. Build a no-dispatch causal key matching the actual q length. @@ -505,6 +529,41 @@ def _set_cp_group_on_attention(model, cp_group) -> None: module.cp_group = cp_group +def _set_attn_spec_on_attention(model, spec: Optional["AttnMaskSpec"]) -> None: + """Stamp the per-step :class:`AttnMaskSpec` on every attention sub-module. + + The HF-registered magi forward reads the spec from ``module`` (a carrier already in + the attention call signature) rather than a process-global, so the mask is scoped to + this model's attention modules and the HF interface stays signature-clean. Call every + step with ``spec`` (or ``None`` to clear). Arms the consumption guard: if the previous + step's spec was never consumed by a magi forward, the model silently fell back to a + non-magi attention and the mask was dropped -- raise rather than train on the wrong mask. + + Args: + model: the (possibly FSDP-wrapped) model whose attention modules to stamp. + spec: the mask spec for this step, or ``None`` for plain causal / no mask. + """ + global _SPEC_ARMED, _SPEC_CONSUMED + if _SPEC_ARMED and not _SPEC_CONSUMED: + raise RuntimeError( + "A magi AttnMaskSpec was activated for the previous step but no magi attention " + "consumed it: the model is not routing attention through magi (it silently used " + "its default attention, e.g. flash_attention_2). Set model.attn_implementation='magi' " + "on an HF-style model, or model.backend.attn='magi' on a factory-based custom model." + ) + for module in model.modules(): + if "Attention" in type(module).__name__: + module._magi_attn_spec = spec + _SPEC_ARMED = spec is not None + _SPEC_CONSUMED = False + + +def _mark_attn_spec_consumed() -> None: + """Record that a magi attention forward read the active :class:`AttnMaskSpec`.""" + global _SPEC_CONSUMED + _SPEC_CONSUMED = True + + def magi_prepare_batch( # pragma: no cover - requires GPU + magi_attention model, batch: dict, diff --git a/tests/unit_tests/distributed/test_magi_attn_utils.py b/tests/unit_tests/distributed/test_magi_attn_utils.py index 99c619d686..281c636f8b 100644 --- a/tests/unit_tests/distributed/test_magi_attn_utils.py +++ b/tests/unit_tests/distributed/test_magi_attn_utils.py @@ -376,3 +376,45 @@ def test_make_magi_attn_func_matches_sdpa_causal(): mu.set_active_cp_group(None) if dist.is_initialized(): dist.destroy_process_group() + + +@pytest.fixture(autouse=True) +def _reset_spec_guard(): + """Reset the module-global consumption-guard state around each test.""" + mu._SPEC_ARMED = False + mu._SPEC_CONSUMED = False + yield + mu._SPEC_ARMED = False + mu._SPEC_CONSUMED = False + + +class TestAttnSpecOnModule: + def test_stamps_only_attention_modules(self): + model = _VLM() + spec = AttnMaskSpec.causal(8) + mu._set_attn_spec_on_attention(model, spec) + # language-backbone attention is stamped; the (Linear) vision tower is not. + assert model.language_model.self_attn._magi_attn_spec is spec + assert not hasattr(model.visual, "_magi_attn_spec") + # a non-None spec arms the guard, not yet consumed. + assert mu._SPEC_ARMED is True + assert mu._SPEC_CONSUMED is False + + def test_none_clears_and_does_not_arm(self): + model = _LM() + mu._set_attn_spec_on_attention(model, None) + assert model.self_attn._magi_attn_spec is None + assert mu._SPEC_ARMED is False + + def test_guard_raises_when_armed_spec_not_consumed(self): + model = _LM() + mu._set_attn_spec_on_attention(model, AttnMaskSpec.causal(8)) # arm, never consume + with pytest.raises(RuntimeError, match="no magi attention consumed it"): + mu._set_attn_spec_on_attention(model, AttnMaskSpec.causal(8)) + + def test_guard_passes_when_consumed(self): + model = _LM() + mu._set_attn_spec_on_attention(model, AttnMaskSpec.causal(8)) + mu._mark_attn_spec_consumed() # a magi forward read it + mu._set_attn_spec_on_attention(model, AttnMaskSpec.causal(8)) # next step must not raise + assert mu._SPEC_ARMED is True From 81ef1129eb693811ed713019ef633eff01ee6b43 Mon Sep 17 00:00:00 2001 From: HuiyingLi Date: Wed, 17 Jun 2026 22:09:47 -0700 Subject: [PATCH 2/2] test(magi): cover HF forward AttnMaskSpec key selection Add CPU tests (magi_attention stubbed) that the registered "magi" HF forward builds the flex key and marks the spec consumed when _magi_attn_spec is on the module, and falls back to the dispatched key otherwise. Co-Authored-By: Claude Opus 4.8 Signed-off-by: HuiyingLi --- .../distributed/test_magi_attn_utils.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/unit_tests/distributed/test_magi_attn_utils.py b/tests/unit_tests/distributed/test_magi_attn_utils.py index 281c636f8b..a6d95db9ae 100644 --- a/tests/unit_tests/distributed/test_magi_attn_utils.py +++ b/tests/unit_tests/distributed/test_magi_attn_utils.py @@ -418,3 +418,54 @@ def test_guard_passes_when_consumed(self): mu._mark_attn_spec_consumed() # a magi forward read it mu._set_attn_spec_on_attention(model, AttnMaskSpec.causal(8)) # next step must not raise assert mu._SPEC_ARMED is True + + +class TestHFForwardKeySelection: + """The registered "magi" HF forward picks its dist key from module state. + + magi_attention is stubbed so register_magi_attention() and the forward run on + CPU; the key builders are spied to assert which branch is taken. + """ + + def _register_with_stub(self, monkeypatch): + import sys + import types + + api = types.ModuleType("magi_attention.api") + api.calc_attn = lambda q, k, v, key, **kw: (q,) # echo q; shape-preserving + api.get_most_recent_key = lambda cp_group: "RECENT_KEY" + monkeypatch.setitem(sys.modules, "magi_attention", types.ModuleType("magi_attention")) + monkeypatch.setitem(sys.modules, "magi_attention.api", api) + monkeypatch.setattr(mu, "_MAGI_REGISTERED", False) + + picked = {} + monkeypatch.setattr(mu, "_flex_key_for", lambda *a, **k: picked.setdefault("flex", 0) or picked.update(flex=1)) + monkeypatch.setattr(mu, "_self_key_for", lambda *a, **k: picked.update(self_key=1)) + + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + mu.register_magi_attention() + return ALL_ATTENTION_FUNCTIONS["magi"], picked + + def test_module_spec_selects_flex_key_and_marks_consumed(self, monkeypatch): + fwd, picked = self._register_with_stub(monkeypatch) + module = nn.Module() + module.cp_group = _FakeGroup(1) + module._magi_attn_spec = AttnMaskSpec.causal(3) + q = torch.randn(1, 2, 3, 4) # [b, nh, s, hd] + out, _ = fwd(module, q, q, q, scaling=0.5) + assert picked.get("flex") == 1 + assert "self_key" not in picked + assert mu._SPEC_CONSUMED is True # forward consumed the armed spec + assert tuple(out.shape) == (1, 3, 8) # [b, s, nh*hd] + + def test_no_spec_falls_back_to_dispatched_key(self, monkeypatch): + fwd, picked = self._register_with_stub(monkeypatch) + module = nn.Module() + module.cp_group = _FakeGroup(2) + # no _magi_attn_spec, no _magi_self_key -> get_most_recent_key path + q = torch.randn(1, 2, 3, 4) + fwd(module, q, q, q, scaling=0.5) + assert "flex" not in picked + assert "self_key" not in picked + assert mu._SPEC_CONSUMED is False