Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/maxtext/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1286,7 +1286,7 @@ def create_sa_config(config, query, key, attn_logits_soft_cap):
raise ValueError("Sliding_window_size must be set if Local Sliding attention type")
mask &= mask_module.LocalMask(
shape=(query.shape[2], key.shape[2]),
window_size=(self.sliding_window_size, self.sliding_window_size),
window_size=(self.sliding_window_size - 1, self.sliding_window_size),
offset=0,
)
elif self.attention_type == AttentionType.CHUNK:
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from flax import nnx
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
from jax.sharding import AxisType, Mesh
from maxtext.utils import maxtext_utils
from maxtext.common.gcloud_stub import is_decoupled
Expand Down Expand Up @@ -56,6 +57,24 @@
from tests.utils.test_helpers import get_test_config_path


class SplashLocalMaskTest(unittest.TestCase):
"""Tests for Splash local masks."""

def test_local_window_matches_dense_mask(self):
seq_len = 8
window_size = 3
mask = splash_attention_mask.CausalMask((seq_len, seq_len)) & splash_attention_mask.LocalMask(
(seq_len, seq_len),
window_size=(window_size - 1, window_size),
offset=0,
)
q_sequence = np.arange(seq_len)[:, None]
kv_sequence = np.arange(seq_len)[None, :]
expected_mask = (kv_sequence <= q_sequence) & (kv_sequence > q_sequence - window_size)

np.testing.assert_array_equal(mask[:, :], expected_mask)


class BidirectionalBlockMaskTest(unittest.TestCase):
"""Test for make_bidirectional_block_mask."""

Expand Down
Loading