diff --git a/src/maxdiffusion/kernels/custom_splash_attention.py b/src/maxdiffusion/kernels/custom_splash_attention.py index fb50a51a9..32fd001fe 100644 --- a/src/maxdiffusion/kernels/custom_splash_attention.py +++ b/src/maxdiffusion/kernels/custom_splash_attention.py @@ -17,7 +17,6 @@ """Custom Pallas flash attention kernel for TPU.""" import functools -import math import jax import jax.numpy as jnp @@ -25,30 +24,21 @@ from jax import lax from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu -from jax.experimental.shard_map import shard_map -from jax.sharding import PartitionSpec as P DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) NUM_LANES = 128 NUM_SUBLANES = 8 NT_DIM_NUMBERS = (((1,), (1,)), ((), ())) -# Default block sizes (tuned for 720p Wan2.1 on v6e/v7x) -DEFAULT_BQSIZE = 3328 -DEFAULT_BKVSIZE = 2816 -# Cranked up to 1024 for massive MXU throughput -DEFAULT_BKVCOMPUTESIZE = 1024 -# Kept at 256 to protect VPU registers (V1 Optimization) -DEFAULT_BKVCOMPUTEINSIZE = 256 - class _BlockSizes: - __slots__ = ("block_q", "block_kv", "block_kv_compute") + __slots__ = ("block_q", "block_kv", "block_kv_compute", "block_kv_compute_in") - def __init__(self, block_q: int, block_kv: int, block_kv_compute: int | None = None): + def __init__(self, block_q: int, block_kv: int, block_kv_compute: int | None = None, block_kv_compute_in: int = 256): self.block_q = block_q self.block_kv = block_kv self.block_kv_compute = block_kv_compute if block_kv_compute is not None else block_kv + self.block_kv_compute_in = block_kv_compute_in def _flash_attention_kernel( @@ -62,12 +52,10 @@ def _flash_attention_kernel( *, mask_value: float, grid_width: int, - bq: int, bkv: int, bkv_compute: int, bkv_compute_in: int, head_dim_v: int, - q_seq_len: int, kv_seq_len: int, use_base2_exp: bool = True, ): @@ -207,12 +195,10 @@ def _flash_attention_kernel_mhpt( *, mask_value: float, grid_width: int, - bq: int, bkv: int, bkv_compute: int, bkv_compute_in: int, head_dim_v: int, - q_seq_len: int, kv_seq_len: int, heads_per_tile: int, use_base2_exp: bool = True, @@ -354,7 +340,6 @@ def _splash_attention_forward( k: jax.Array, v: jax.Array, block_sizes: _BlockSizes, - bkv_compute_in: int, q_seq_len: int | None = None, kv_seq_len: int | None = None, use_base2_exp: bool = True, @@ -365,6 +350,7 @@ def _splash_attention_forward( head_dim_v = v.shape[-1] bq, bkv = block_sizes.block_q, block_sizes.block_kv bkv_compute = block_sizes.block_kv_compute + bkv_compute_in = block_sizes.block_kv_compute_in num_kv_heads = k.shape[0] padded_kv_seq_len = k.shape[1] @@ -410,12 +396,10 @@ def v_index_map(h, i, j, *_): _flash_attention_kernel, mask_value=DEFAULT_MASK_VALUE, grid_width=grid_width, - bq=bq, bkv=bkv, bkv_compute=bkv_compute, bkv_compute_in=bkv_compute_in, head_dim_v=head_dim_v, - q_seq_len=actual_q_seq_len, kv_seq_len=actual_kv_seq_len, use_base2_exp=use_base2_exp, ), @@ -442,7 +426,6 @@ def _splash_attention_forward_mhpt( k: jax.Array, v: jax.Array, block_sizes: _BlockSizes, - bkv_compute_in: int, heads_per_tile: int, q_seq_len: int | None = None, kv_seq_len: int | None = None, @@ -454,6 +437,7 @@ def _splash_attention_forward_mhpt( head_dim_v = v.shape[-1] bq, bkv = block_sizes.block_q, block_sizes.block_kv bkv_compute = block_sizes.block_kv_compute + bkv_compute_in = block_sizes.block_kv_compute_in num_kv_heads = k.shape[0] actual_q_seq_len = q_seq_len if q_seq_len is not None else padded_q_seq_len actual_kv_seq_len = kv_seq_len if kv_seq_len is not None else k.shape[1] @@ -500,12 +484,10 @@ def out_index_map(h, i, j, *_): _flash_attention_kernel_mhpt, mask_value=DEFAULT_MASK_VALUE, grid_width=grid_width, - bq=bq, bkv=bkv, bkv_compute=bkv_compute, bkv_compute_in=bkv_compute_in, head_dim_v=head_dim_v, - q_seq_len=actual_q_seq_len, kv_seq_len=actual_kv_seq_len, heads_per_tile=hpt, use_base2_exp=use_base2_exp, @@ -530,7 +512,6 @@ def out_index_map(h, i, j, *_): def make_splash_mha( block_sizes: _BlockSizes, - bkv_compute_in: int = DEFAULT_BKVCOMPUTEINSIZE, orig_q_seq_len: int | None = None, orig_kv_seq_len: int | None = None, heads_per_tile: int = 1, @@ -545,7 +526,6 @@ def _splash_attention(q, k, v): k, v, block_sizes, - bkv_compute_in, heads_per_tile, q_seq_len=orig_q_seq_len, kv_seq_len=orig_kv_seq_len, @@ -558,7 +538,6 @@ def _splash_attention(q, k, v): k, v, block_sizes, - bkv_compute_in, q_seq_len=orig_q_seq_len, kv_seq_len=orig_kv_seq_len, use_base2_exp=use_base2_exp, @@ -567,200 +546,3 @@ def _splash_attention(q, k, v): ) return _splash_attention - - -# --------------------------------------------------------------------------- -# High-level attention function with shard_map -# --------------------------------------------------------------------------- - - -def tpu_custom_attention( - query, - key, - value, - mesh, - *, - scale=None, - block_q=None, - block_kv=None, - block_kv_compute=None, - block_kv_compute_in=None, - heads_per_tile=None, - use_base2_exp=True, - use_experimental_scheduler=False, - vmem_limit_bytes=None, - flash_block_sizes=None, -): - _LOG2_E = 1.44269504 - num_heads = query.shape[1] - - if flash_block_sizes is not None: - block_q = flash_block_sizes.get("block_q", block_q) - block_kv = flash_block_sizes.get("block_kv", block_kv) - block_kv_compute = flash_block_sizes.get("block_kv_compute", block_kv_compute) - block_kv_compute_in = flash_block_sizes.get("block_kv_compute_in", block_kv_compute_in) - heads_per_tile = flash_block_sizes.get("heads_per_tile", heads_per_tile) - vmem_limit_bytes = flash_block_sizes.get("vmem_limit_bytes", vmem_limit_bytes) - - block_q = block_q if block_q is not None else DEFAULT_BQSIZE - block_kv = block_kv if block_kv is not None else DEFAULT_BKVSIZE - block_kv_compute = block_kv_compute if block_kv_compute is not None else DEFAULT_BKVCOMPUTESIZE - block_kv_compute_in = block_kv_compute_in if block_kv_compute_in is not None else DEFAULT_BKVCOMPUTEINSIZE - heads_per_tile = heads_per_tile if heads_per_tile is not None else 1 - - def _attention_on_slices(q, k, v): - scale_factor = 1.0 / math.sqrt(q.shape[-1]) if scale is None else scale - if use_base2_exp: - q = q * scale_factor * _LOG2_E - else: - q = q * scale_factor - - def _pad_to_multiple(x, multiple, axis): - seq_len = x.shape[axis] - pad_len = (multiple - seq_len % multiple) % multiple - if pad_len == 0: - return x, seq_len - pad_width = [(0, 0)] * x.ndim - pad_width[axis] = (0, pad_len) - return jnp.pad(x, pad_width), seq_len - - def _kernel_3d(q_3d, k_3d, v_3d): - q_orig_len = q_3d.shape[1] - kv_orig_len = k_3d.shape[1] - - q_3d_padded, _ = _pad_to_multiple(q_3d, block_q, axis=1) - k_3d_padded, _ = _pad_to_multiple(k_3d, block_kv, axis=1) - v_3d_padded, _ = _pad_to_multiple(v_3d, block_kv, axis=1) - - padded_q_seq_len = q_3d_padded.shape[1] - padded_kv_seq_len = k_3d_padded.shape[1] - - bsizes = _BlockSizes( - block_q=min(block_q, padded_q_seq_len), - block_kv=min(block_kv, padded_kv_seq_len), - block_kv_compute=min(block_kv_compute, padded_kv_seq_len), - ) - splash_kernel = make_splash_mha( - block_sizes=bsizes, - bkv_compute_in=block_kv_compute_in, - orig_q_seq_len=q_orig_len, - orig_kv_seq_len=kv_orig_len, - heads_per_tile=heads_per_tile, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, - vmem_limit_bytes=vmem_limit_bytes, - ) - out = splash_kernel( - q_3d_padded.astype(jnp.bfloat16), - k_3d_padded, - v_3d_padded, - ) - out = jnp.swapaxes(out, 1, 2) - return out[:, :q_orig_len, ...] - - return jax.vmap(_kernel_3d, in_axes=(0, 0, 0), out_axes=0)(q, k, v) - - batch_size = query.shape[0] - if num_heads < mesh.size: - q_partition_spec = P() - kv_partition_spec = P() - out_constraint = P() - else: - axis_names = mesh.axis_names - if len(axis_names) == 1: - tp_axis = axis_names[0] - q_partition_spec = P(None, tp_axis, None, None) - kv_partition_spec = P(None, tp_axis, None, None) - out_constraint = P(None, None, tp_axis, None) - elif len(axis_names) == 2: - dp_axis, tp_axis = axis_names[0], axis_names[1] - dp_size = mesh.shape[dp_axis] - if batch_size >= dp_size: - q_partition_spec = P(dp_axis, tp_axis, None, None) - kv_partition_spec = P(dp_axis, tp_axis, None, None) - out_constraint = P(dp_axis, None, tp_axis, None) - else: - all_axes = tuple(axis_names) - q_partition_spec = P(None, all_axes, None, None) - kv_partition_spec = P(None, all_axes, None, None) - out_constraint = P(None, None, all_axes, None) - else: - q_partition_spec = P(axis_names[0], axis_names[1], axis_names[2], None) - kv_partition_spec = P(axis_names[0], axis_names[1], None, None) - out_constraint = P(axis_names[0], None, (axis_names[1], axis_names[2]), None) - - sharded_fn = shard_map( - _attention_on_slices, - mesh=mesh, - in_specs=(q_partition_spec, kv_partition_spec, kv_partition_spec), - out_specs=q_partition_spec, - check_rep=False, - ) - out = sharded_fn(query, key, value) - out = jax.lax.with_sharding_constraint(out, out_constraint) - return out - - -# --------------------------------------------------------------------------- -# TorchAX SDPA wrapper -# --------------------------------------------------------------------------- - - -def make_custom_splash_sdpa(mesh, env, **kwargs): - flash_block_sizes = kwargs.get("flash_block_sizes", None) - bq = kwargs.get("block_q", DEFAULT_BQSIZE) - bkv = kwargs.get("block_kv", DEFAULT_BKVSIZE) - bkv_compute = kwargs.get("block_kv_compute", DEFAULT_BKVCOMPUTESIZE) - bkv_compute_in = kwargs.get("block_kv_compute_in", DEFAULT_BKVCOMPUTEINSIZE) - hpt = kwargs.get("heads_per_tile", 1) - use_k_smooth = kwargs.get("use_k_smooth", True) - use_base2_exp = kwargs.get("use_base2_exp", True) - use_experimental_scheduler = kwargs.get("use_experimental_scheduler", False) - vmem_limit_bytes = kwargs.get("vmem_limit_bytes", None) - - def _simple_attention(q, k, v, scale=None): - s = scale if scale is not None else 1.0 / math.sqrt(q.shape[-1]) - attn = jnp.einsum("bhsd,bhtd->bhst", q * s, k) - attn = jax.nn.softmax(attn.astype(jnp.float32), axis=-1).astype(q.dtype) - return jnp.einsum("bhst,bhtd->bhsd", attn, v) - - def _sdpa( - query, - key, - value, - attn_mask=None, - dropout_p=0.0, - is_causal=False, - scale=None, - enable_gqa=False, - ): - jquery, jkey, jvalue = env.t2j_iso((query, key, value)) - num_heads = jquery.shape[1] - - if num_heads <= 8: - result = _simple_attention(jquery, jkey, jvalue, scale=scale) - return env.j2t_iso(result) - - if use_k_smooth: - key_mean = jnp.mean(jkey, axis=2, keepdims=True) - jkey = jkey - key_mean - - result = tpu_custom_attention( - jquery, - jkey, - jvalue, - mesh, - scale=scale, - block_q=bq, - block_kv=bkv, - block_kv_compute=bkv_compute, - block_kv_compute_in=bkv_compute_in, - heads_per_tile=hpt, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, - vmem_limit_bytes=vmem_limit_bytes, - flash_block_sizes=flash_block_sizes, - ) - return env.j2t_iso(result) - - return _sdpa diff --git a/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py b/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py index 1fad541c6..a3c4d0d38 100644 --- a/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py +++ b/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py @@ -76,8 +76,6 @@ def translate_fn(nnx_path_str): # the merge_fn warns about unmatched keys in each dict, so we only warn about any leftovers unmatched_keys = set(h_state_dict) - set(transformer_state_dict) - set(connector_state_dict) if unmatched_keys: - max_logging.log( - f"{len(unmatched_keys)} key(s) in LoRA dictionary routed to no merge target: {unmatched_keys}" - ) + max_logging.log(f"{len(unmatched_keys)} key(s) in LoRA dictionary routed to no merge target: {unmatched_keys}") return pipeline diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index edc9f4f7b..32b388321 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -656,11 +656,12 @@ def wrap_ulysses_attention(query, key, value): key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv) value, _, _ = _pad_data_for_flash(value, heads, bkv) - bsizes = custom_splash._BlockSizes(block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute) + bsizes = custom_splash._BlockSizes( + block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute, block_kv_compute_in=bkv_compute_in + ) splash_kernel = custom_splash.make_splash_mha( block_sizes=bsizes, - bkv_compute_in=bkv_compute_in, orig_q_seq_len=query_seq_len, orig_kv_seq_len=key_seq_len, heads_per_tile=heads_per_tile,