From 4cb305191973e6af5e18bfe7de2612ecaad74111 Mon Sep 17 00:00:00 2001 From: continuousml Date: Sun, 28 Jun 2026 18:33:47 -0700 Subject: [PATCH] Fix TPU Splash local sliding window size --- src/maxtext/layers/attention_op.py | 2 +- tests/unit/attention_test.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index badcdb66ea..09c467706a 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -1286,7 +1286,7 @@ def create_sa_config(config, query, key, attn_logits_soft_cap): raise ValueError("Sliding_window_size must be set if Local Sliding attention type") mask &= mask_module.LocalMask( shape=(query.shape[2], key.shape[2]), - window_size=(self.sliding_window_size, self.sliding_window_size), + window_size=(self.sliding_window_size - 1, self.sliding_window_size), offset=0, ) elif self.attention_type == AttentionType.CHUNK: diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index 43bc5e3c97..45cd1c4d27 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -25,6 +25,7 @@ from flax import nnx import jax import jax.numpy as jnp +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask from jax.sharding import AxisType, Mesh from maxtext.utils import maxtext_utils from maxtext.common.gcloud_stub import is_decoupled @@ -56,6 +57,24 @@ from tests.utils.test_helpers import get_test_config_path +class SplashLocalMaskTest(unittest.TestCase): + """Tests for Splash local masks.""" + + def test_local_window_matches_dense_mask(self): + seq_len = 8 + window_size = 3 + mask = splash_attention_mask.CausalMask((seq_len, seq_len)) & splash_attention_mask.LocalMask( + (seq_len, seq_len), + window_size=(window_size - 1, window_size), + offset=0, + ) + q_sequence = np.arange(seq_len)[:, None] + kv_sequence = np.arange(seq_len)[None, :] + expected_mask = (kv_sequence <= q_sequence) & (kv_sequence > q_sequence - window_size) + + np.testing.assert_array_equal(mask[:, :], expected_mask) + + class BidirectionalBlockMaskTest(unittest.TestCase): """Test for make_bidirectional_block_mask."""