Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 5 additions & 223 deletions src/maxdiffusion/kernels/custom_splash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,28 @@
"""Custom Pallas flash attention kernel for TPU."""

import functools
import math

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)
NUM_LANES = 128
NUM_SUBLANES = 8
NT_DIM_NUMBERS = (((1,), (1,)), ((), ()))

# Default block sizes (tuned for 720p Wan2.1 on v6e/v7x)
DEFAULT_BQSIZE = 3328
DEFAULT_BKVSIZE = 2816
# Cranked up to 1024 for massive MXU throughput
DEFAULT_BKVCOMPUTESIZE = 1024
# Kept at 256 to protect VPU registers (V1 Optimization)
DEFAULT_BKVCOMPUTEINSIZE = 256


class _BlockSizes:
__slots__ = ("block_q", "block_kv", "block_kv_compute")
__slots__ = ("block_q", "block_kv", "block_kv_compute", "block_kv_compute_in")

def __init__(self, block_q: int, block_kv: int, block_kv_compute: int | None = None):
def __init__(self, block_q: int, block_kv: int, block_kv_compute: int | None = None, block_kv_compute_in: int = 256):
self.block_q = block_q
self.block_kv = block_kv
self.block_kv_compute = block_kv_compute if block_kv_compute is not None else block_kv
self.block_kv_compute_in = block_kv_compute_in


def _flash_attention_kernel(
Expand All @@ -62,12 +52,10 @@ def _flash_attention_kernel(
*,
mask_value: float,
grid_width: int,
bq: int,
bkv: int,
bkv_compute: int,
bkv_compute_in: int,
head_dim_v: int,
q_seq_len: int,
kv_seq_len: int,
use_base2_exp: bool = True,
):
Expand Down Expand Up @@ -207,12 +195,10 @@ def _flash_attention_kernel_mhpt(
*,
mask_value: float,
grid_width: int,
bq: int,
bkv: int,
bkv_compute: int,
bkv_compute_in: int,
head_dim_v: int,
q_seq_len: int,
kv_seq_len: int,
heads_per_tile: int,
use_base2_exp: bool = True,
Expand Down Expand Up @@ -354,7 +340,6 @@ def _splash_attention_forward(
k: jax.Array,
v: jax.Array,
block_sizes: _BlockSizes,
bkv_compute_in: int,
q_seq_len: int | None = None,
kv_seq_len: int | None = None,
use_base2_exp: bool = True,
Expand All @@ -365,6 +350,7 @@ def _splash_attention_forward(
head_dim_v = v.shape[-1]
bq, bkv = block_sizes.block_q, block_sizes.block_kv
bkv_compute = block_sizes.block_kv_compute
bkv_compute_in = block_sizes.block_kv_compute_in
num_kv_heads = k.shape[0]
padded_kv_seq_len = k.shape[1]

Expand Down Expand Up @@ -410,12 +396,10 @@ def v_index_map(h, i, j, *_):
_flash_attention_kernel,
mask_value=DEFAULT_MASK_VALUE,
grid_width=grid_width,
bq=bq,
bkv=bkv,
bkv_compute=bkv_compute,
bkv_compute_in=bkv_compute_in,
head_dim_v=head_dim_v,
q_seq_len=actual_q_seq_len,
kv_seq_len=actual_kv_seq_len,
use_base2_exp=use_base2_exp,
),
Expand All @@ -442,7 +426,6 @@ def _splash_attention_forward_mhpt(
k: jax.Array,
v: jax.Array,
block_sizes: _BlockSizes,
bkv_compute_in: int,
heads_per_tile: int,
q_seq_len: int | None = None,
kv_seq_len: int | None = None,
Expand All @@ -454,6 +437,7 @@ def _splash_attention_forward_mhpt(
head_dim_v = v.shape[-1]
bq, bkv = block_sizes.block_q, block_sizes.block_kv
bkv_compute = block_sizes.block_kv_compute
bkv_compute_in = block_sizes.block_kv_compute_in
num_kv_heads = k.shape[0]
actual_q_seq_len = q_seq_len if q_seq_len is not None else padded_q_seq_len
actual_kv_seq_len = kv_seq_len if kv_seq_len is not None else k.shape[1]
Expand Down Expand Up @@ -500,12 +484,10 @@ def out_index_map(h, i, j, *_):
_flash_attention_kernel_mhpt,
mask_value=DEFAULT_MASK_VALUE,
grid_width=grid_width,
bq=bq,
bkv=bkv,
bkv_compute=bkv_compute,
bkv_compute_in=bkv_compute_in,
head_dim_v=head_dim_v,
q_seq_len=actual_q_seq_len,
kv_seq_len=actual_kv_seq_len,
heads_per_tile=hpt,
use_base2_exp=use_base2_exp,
Expand All @@ -530,7 +512,6 @@ def out_index_map(h, i, j, *_):

def make_splash_mha(
block_sizes: _BlockSizes,
bkv_compute_in: int = DEFAULT_BKVCOMPUTEINSIZE,
orig_q_seq_len: int | None = None,
orig_kv_seq_len: int | None = None,
heads_per_tile: int = 1,
Expand All @@ -545,7 +526,6 @@ def _splash_attention(q, k, v):
k,
v,
block_sizes,
bkv_compute_in,
heads_per_tile,
q_seq_len=orig_q_seq_len,
kv_seq_len=orig_kv_seq_len,
Expand All @@ -558,7 +538,6 @@ def _splash_attention(q, k, v):
k,
v,
block_sizes,
bkv_compute_in,
q_seq_len=orig_q_seq_len,
kv_seq_len=orig_kv_seq_len,
use_base2_exp=use_base2_exp,
Expand All @@ -567,200 +546,3 @@ def _splash_attention(q, k, v):
)

return _splash_attention


# ---------------------------------------------------------------------------
# High-level attention function with shard_map
# ---------------------------------------------------------------------------


def tpu_custom_attention(
query,
key,
value,
mesh,
*,
scale=None,
block_q=None,
block_kv=None,
block_kv_compute=None,
block_kv_compute_in=None,
heads_per_tile=None,
use_base2_exp=True,
use_experimental_scheduler=False,
vmem_limit_bytes=None,
flash_block_sizes=None,
):
_LOG2_E = 1.44269504
num_heads = query.shape[1]

if flash_block_sizes is not None:
block_q = flash_block_sizes.get("block_q", block_q)
block_kv = flash_block_sizes.get("block_kv", block_kv)
block_kv_compute = flash_block_sizes.get("block_kv_compute", block_kv_compute)
block_kv_compute_in = flash_block_sizes.get("block_kv_compute_in", block_kv_compute_in)
heads_per_tile = flash_block_sizes.get("heads_per_tile", heads_per_tile)
vmem_limit_bytes = flash_block_sizes.get("vmem_limit_bytes", vmem_limit_bytes)

block_q = block_q if block_q is not None else DEFAULT_BQSIZE
block_kv = block_kv if block_kv is not None else DEFAULT_BKVSIZE
block_kv_compute = block_kv_compute if block_kv_compute is not None else DEFAULT_BKVCOMPUTESIZE
block_kv_compute_in = block_kv_compute_in if block_kv_compute_in is not None else DEFAULT_BKVCOMPUTEINSIZE
heads_per_tile = heads_per_tile if heads_per_tile is not None else 1

def _attention_on_slices(q, k, v):
scale_factor = 1.0 / math.sqrt(q.shape[-1]) if scale is None else scale
if use_base2_exp:
q = q * scale_factor * _LOG2_E
else:
q = q * scale_factor

def _pad_to_multiple(x, multiple, axis):
seq_len = x.shape[axis]
pad_len = (multiple - seq_len % multiple) % multiple
if pad_len == 0:
return x, seq_len
pad_width = [(0, 0)] * x.ndim
pad_width[axis] = (0, pad_len)
return jnp.pad(x, pad_width), seq_len

def _kernel_3d(q_3d, k_3d, v_3d):
q_orig_len = q_3d.shape[1]
kv_orig_len = k_3d.shape[1]

q_3d_padded, _ = _pad_to_multiple(q_3d, block_q, axis=1)
k_3d_padded, _ = _pad_to_multiple(k_3d, block_kv, axis=1)
v_3d_padded, _ = _pad_to_multiple(v_3d, block_kv, axis=1)

padded_q_seq_len = q_3d_padded.shape[1]
padded_kv_seq_len = k_3d_padded.shape[1]

bsizes = _BlockSizes(
block_q=min(block_q, padded_q_seq_len),
block_kv=min(block_kv, padded_kv_seq_len),
block_kv_compute=min(block_kv_compute, padded_kv_seq_len),
)
splash_kernel = make_splash_mha(
block_sizes=bsizes,
bkv_compute_in=block_kv_compute_in,
orig_q_seq_len=q_orig_len,
orig_kv_seq_len=kv_orig_len,
heads_per_tile=heads_per_tile,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
vmem_limit_bytes=vmem_limit_bytes,
)
out = splash_kernel(
q_3d_padded.astype(jnp.bfloat16),
k_3d_padded,
v_3d_padded,
)
out = jnp.swapaxes(out, 1, 2)
return out[:, :q_orig_len, ...]

return jax.vmap(_kernel_3d, in_axes=(0, 0, 0), out_axes=0)(q, k, v)

batch_size = query.shape[0]
if num_heads < mesh.size:
q_partition_spec = P()
kv_partition_spec = P()
out_constraint = P()
else:
axis_names = mesh.axis_names
if len(axis_names) == 1:
tp_axis = axis_names[0]
q_partition_spec = P(None, tp_axis, None, None)
kv_partition_spec = P(None, tp_axis, None, None)
out_constraint = P(None, None, tp_axis, None)
elif len(axis_names) == 2:
dp_axis, tp_axis = axis_names[0], axis_names[1]
dp_size = mesh.shape[dp_axis]
if batch_size >= dp_size:
q_partition_spec = P(dp_axis, tp_axis, None, None)
kv_partition_spec = P(dp_axis, tp_axis, None, None)
out_constraint = P(dp_axis, None, tp_axis, None)
else:
all_axes = tuple(axis_names)
q_partition_spec = P(None, all_axes, None, None)
kv_partition_spec = P(None, all_axes, None, None)
out_constraint = P(None, None, all_axes, None)
else:
q_partition_spec = P(axis_names[0], axis_names[1], axis_names[2], None)
kv_partition_spec = P(axis_names[0], axis_names[1], None, None)
out_constraint = P(axis_names[0], None, (axis_names[1], axis_names[2]), None)

sharded_fn = shard_map(
_attention_on_slices,
mesh=mesh,
in_specs=(q_partition_spec, kv_partition_spec, kv_partition_spec),
out_specs=q_partition_spec,
check_rep=False,
)
out = sharded_fn(query, key, value)
out = jax.lax.with_sharding_constraint(out, out_constraint)
return out


# ---------------------------------------------------------------------------
# TorchAX SDPA wrapper
# ---------------------------------------------------------------------------


def make_custom_splash_sdpa(mesh, env, **kwargs):
flash_block_sizes = kwargs.get("flash_block_sizes", None)
bq = kwargs.get("block_q", DEFAULT_BQSIZE)
bkv = kwargs.get("block_kv", DEFAULT_BKVSIZE)
bkv_compute = kwargs.get("block_kv_compute", DEFAULT_BKVCOMPUTESIZE)
bkv_compute_in = kwargs.get("block_kv_compute_in", DEFAULT_BKVCOMPUTEINSIZE)
hpt = kwargs.get("heads_per_tile", 1)
use_k_smooth = kwargs.get("use_k_smooth", True)
use_base2_exp = kwargs.get("use_base2_exp", True)
use_experimental_scheduler = kwargs.get("use_experimental_scheduler", False)
vmem_limit_bytes = kwargs.get("vmem_limit_bytes", None)

def _simple_attention(q, k, v, scale=None):
s = scale if scale is not None else 1.0 / math.sqrt(q.shape[-1])
attn = jnp.einsum("bhsd,bhtd->bhst", q * s, k)
attn = jax.nn.softmax(attn.astype(jnp.float32), axis=-1).astype(q.dtype)
return jnp.einsum("bhst,bhtd->bhsd", attn, v)

def _sdpa(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
enable_gqa=False,
):
jquery, jkey, jvalue = env.t2j_iso((query, key, value))
num_heads = jquery.shape[1]

if num_heads <= 8:
result = _simple_attention(jquery, jkey, jvalue, scale=scale)
return env.j2t_iso(result)

if use_k_smooth:
key_mean = jnp.mean(jkey, axis=2, keepdims=True)
jkey = jkey - key_mean

result = tpu_custom_attention(
jquery,
jkey,
jvalue,
mesh,
scale=scale,
block_q=bq,
block_kv=bkv,
block_kv_compute=bkv_compute,
block_kv_compute_in=bkv_compute_in,
heads_per_tile=hpt,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
vmem_limit_bytes=vmem_limit_bytes,
flash_block_sizes=flash_block_sizes,
)
return env.j2t_iso(result)

return _sdpa
4 changes: 1 addition & 3 deletions src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ def translate_fn(nnx_path_str):
# the merge_fn warns about unmatched keys in each dict, so we only warn about any leftovers
unmatched_keys = set(h_state_dict) - set(transformer_state_dict) - set(connector_state_dict)
if unmatched_keys:
max_logging.log(
f"{len(unmatched_keys)} key(s) in LoRA dictionary routed to no merge target: {unmatched_keys}"
)
max_logging.log(f"{len(unmatched_keys)} key(s) in LoRA dictionary routed to no merge target: {unmatched_keys}")

return pipeline
5 changes: 3 additions & 2 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,11 +656,12 @@ def wrap_ulysses_attention(query, key, value):
key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
value, _, _ = _pad_data_for_flash(value, heads, bkv)

bsizes = custom_splash._BlockSizes(block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute)
bsizes = custom_splash._BlockSizes(
block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute, block_kv_compute_in=bkv_compute_in
)

splash_kernel = custom_splash.make_splash_mha(
block_sizes=bsizes,
bkv_compute_in=bkv_compute_in,
orig_q_seq_len=query_seq_len,
orig_kv_seq_len=key_seq_len,
heads_per_tile=heads_per_tile,
Expand Down
Loading