Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ def translate_fn(nnx_path_str):
# the merge_fn warns about unmatched keys in each dict, so we only warn about any leftovers
unmatched_keys = set(h_state_dict) - set(transformer_state_dict) - set(connector_state_dict)
if unmatched_keys:
max_logging.log(
f"{len(unmatched_keys)} key(s) in LoRA dictionary routed to no merge target: {unmatched_keys}"
)
max_logging.log(f"{len(unmatched_keys)} key(s) in LoRA dictionary routed to no merge target: {unmatched_keys}")

return pipeline
3 changes: 2 additions & 1 deletion src/maxdiffusion/models/quantizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import jax.numpy as jnp
from jax.tree_util import tree_flatten_with_path, tree_unflatten
from typing import Tuple, Sequence
from maxdiffusion import max_logging

# Params used to define mixed precision quantization configs
DEFAULT = "__default__" # default config
Expand Down Expand Up @@ -139,7 +140,7 @@ def _get_quant_config(config):
else:
drhs_bits = 8
drhs_accumulator_dtype = jnp.int32
print(config.quantization_local_shard_count) # -1
max_logging.log(config.quantization_local_shard_count) # -1
drhs_local_aqt = aqt_config.LocalAqt(contraction_axis_shard_count=config.quantization_local_shard_count)
return aqt_config.config_v4(
fwd_bits=8,
Expand Down
3 changes: 2 additions & 1 deletion src/maxdiffusion/pipelines/wan/wan_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,8 @@ def _prepare_model_inputs_i2v(

prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
image_embeds = jax.device_put(image_embeds, data_sharding)
if image_embeds is not None:
image_embeds = jax.device_put(image_embeds, data_sharding)

return prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size

Expand Down
22 changes: 19 additions & 3 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import numpy as np
import time
from ... import max_utils
from maxdiffusion import max_logging


class WanPipeline2_1(WanPipeline):
Expand Down Expand Up @@ -101,6 +102,7 @@ def __call__(
magcache_K: Optional[int] = None,
retention_ratio: Optional[float] = None,
use_kv_cache: bool = False,
output_type: str = "pil",
):
config = getattr(self, "config", None)
if max_sequence_length is None:
Expand Down Expand Up @@ -170,6 +172,9 @@ def __call__(
latents.block_until_ready()
trace["denoise_total"] = time.perf_counter() - t_denoise_start

if output_type == "latent":
return latents, trace

t_decode_start = time.perf_counter()
video = self._decode_latents_to_video(latents, trace=trace)
if hasattr(video, "block_until_ready"):
Expand Down Expand Up @@ -222,6 +227,18 @@ def run_inference_2_1(
do_cfg = guidance_scale > 1.0
bsz = latents.shape[0]

data_shards = 1
try:
if hasattr(latents, "sharding") and hasattr(latents.sharding, "mesh"):
data_shards = latents.sharding.mesh.shape["data"] * latents.sharding.mesh.shape.get("fsdp", 1)
except Exception:
pass

if use_cfg_cache and do_cfg and bsz % data_shards != 0:
max_logging.log(
f"Warning: Disabling CFG cache because batch size {bsz} is not divisible by data shards {data_shards}. This often happens with data_parallelism > 1 and per_device_batch_size = 1."
)
use_cfg_cache = False
# Resolution-dependent CFG cache config (FasterCache / MixCache guidance)
if height >= 720:
# 720p: conservative — protect last 40%, interval=5
Expand Down Expand Up @@ -306,10 +323,9 @@ def run_inference_2_1(
)

scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False
timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)

if scan_diffusion_loop and not use_magcache and not use_cfg_cache:
timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)

scheduler_state = scheduler_state.replace(last_sample=jnp.zeros_like(latents), step_index=jnp.array(0, dtype=jnp.int32))

def scan_body(carry, t):
Expand Down Expand Up @@ -365,7 +381,7 @@ def scan_body(carry, t):
profiler = max_utils.Profiler(config)
profiler.start()

t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
t = timesteps[step]

if use_magcache and do_cfg:
timestep = jnp.broadcast_to(t, bsz * 2 if do_cfg else bsz)
Expand Down
53 changes: 35 additions & 18 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __call__(
use_cfg_cache: bool = False,
use_sen_cache: bool = False,
use_kv_cache: bool = False,
output_type: str = "pil",
):
config = getattr(self, "config", None)
if max_sequence_length is None:
Expand Down Expand Up @@ -203,6 +204,9 @@ def __call__(
latents.block_until_ready()
trace["denoise_total"] = time.perf_counter() - t_denoise_start

if output_type == "latent":
return latents, trace

t_decode_start = time.perf_counter()
video = self._decode_latents_to_video(latents, trace=trace)
if hasattr(video, "block_until_ready"):
Expand Down Expand Up @@ -252,6 +256,19 @@ def run_inference_2_2(
do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
bsz = latents.shape[0]

data_shards = 1
try:
if hasattr(latents, "sharding") and hasattr(latents.sharding, "mesh"):
data_shards = latents.sharding.mesh.shape["data"] * latents.sharding.mesh.shape.get("fsdp", 1)
except Exception:
pass

if use_cfg_cache and do_classifier_free_guidance and bsz % data_shards != 0:
max_logging.log(
f"Warning: Disabling CFG cache because batch size {bsz} is not divisible by data shards {data_shards}. This often happens with data_parallelism > 1 and per_device_batch_size = 1."
)
use_cfg_cache = False

prompt_embeds_combined = (
jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_classifier_free_guidance else prompt_embeds
)
Expand Down Expand Up @@ -279,6 +296,8 @@ def run_inference_2_2(
high_transformer = nnx.merge(high_noise_graphdef, high_noise_state, high_noise_rest)
kv_cache_high, encoder_attention_mask_high = high_transformer.compute_kv_cache(prompt_embeds_combined)

timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)

# ── SenCache path (arXiv:2602.24208) ──
if use_sen_cache and do_classifier_free_guidance:
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
Expand All @@ -303,16 +322,18 @@ def run_inference_2_2(
num_train_timesteps = float(scheduler.config.num_train_timesteps)

# SenCache state
ref_noise_pred = None # y^r: cached denoiser output
ref_latent = None # x^r: latent at last cache refresh
ref_timestep = 0.0 # t^r: timestep (normalized to [0,1]) at last cache refresh
accum_dx = 0.0 # accumulated ||Δx|| since last refresh
accum_dt = 0.0 # accumulated |Δt| since last refresh
reuse_count = 0 # consecutive cache reuses
cache_count = 0
ref_noise_pred = jnp.zeros(
(bsz * 2, latents.shape[1], latents.shape[2], latents.shape[3], latents.shape[4]), dtype=latents.dtype
)
ref_latent = jnp.zeros_like(latents)
ref_timestep = jnp.array(0.0, dtype=jnp.float32)
accum_dx = jnp.array(0.0, dtype=jnp.float32)
accum_dt = jnp.array(0.0, dtype=jnp.float32)
reuse_count = jnp.array(0, dtype=jnp.int32)
cache_count = jnp.array(0, dtype=jnp.int32)

for step in range(num_inference_steps):
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
t = timesteps[step]
t_float = float(timesteps_np[step]) / num_train_timesteps # normalize to [0, 1]

# Select transformer and guidance scale
Expand Down Expand Up @@ -358,10 +379,10 @@ def run_inference_2_2(
)
ref_noise_pred = noise_pred
ref_latent = latents
ref_timestep = t_float
accum_dx = 0.0
accum_dt = 0.0
reuse_count = 0
ref_timestep = jnp.array(t_float, dtype=jnp.float32)
accum_dx = jnp.array(0.0, dtype=jnp.float32)
accum_dt = jnp.array(0.0, dtype=jnp.float32)
reuse_count = jnp.array(0, dtype=jnp.int32)
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
continue

Expand All @@ -375,12 +396,10 @@ def run_inference_2_2(
score = alpha_x * accum_dx + alpha_t * accum_dt

if score <= sen_epsilon and reuse_count < max_reuse:
# Cache hit: reuse previous output
noise_pred = ref_noise_pred
reuse_count += 1
cache_count += 1
else:
# Cache miss: full CFG forward pass
latents_doubled = jnp.concatenate([latents] * 2)
timestep = jnp.broadcast_to(t, bsz * 2)
noise_pred, _, _ = transformer_forward_pass_full_cfg(
Expand Down Expand Up @@ -470,7 +489,7 @@ def run_inference_2_2(
cached_noise_uncond = None

for step in range(num_inference_steps):
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
t = timesteps[step]
is_cache_step = step_is_cache[step]

# Select transformer and guidance scale based on precomputed schedule
Expand Down Expand Up @@ -607,8 +626,6 @@ def low_noise_branch(operands):
)

if scan_diffusion_loop:
timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)

scheduler_state = scheduler_state.replace(last_sample=jnp.zeros_like(latents), step_index=jnp.array(0, dtype=jnp.int32))

def scan_body(carry, t):
Expand Down Expand Up @@ -657,7 +674,7 @@ def scan_body(carry, t):
profiler = max_utils.Profiler(config)
profiler.start()

t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
t = timesteps[step]

if step_uses_high[step]:
graphdef, state, rest = (
Expand Down
23 changes: 11 additions & 12 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,19 +213,17 @@ def __call__(
last_image,
)

def _process_image_input(img_input, height, width, num_videos_per_prompt):
def _process_image_input(img_input, height, width):
if img_input is None:
return None
tensor = self.video_processor.preprocess(img_input, height=height, width=width)
jax_array = jnp.array(tensor.cpu().numpy())
if jax_array.ndim == 3:
jax_array = jax_array[None, ...] # Add batch dimension
if num_videos_per_prompt > 1:
jax_array = jnp.repeat(jax_array, num_videos_per_prompt, axis=0)
return jax_array

image_tensor = _process_image_input(image, height, width, effective_batch_size)
last_image_tensor = _process_image_input(last_image, height, width, effective_batch_size)
image_tensor = _process_image_input(image, height, width)
last_image_tensor = _process_image_input(last_image, height, width)

if rng is None:
rng = jax.random.key(self.config.seed)
Expand Down Expand Up @@ -352,6 +350,8 @@ def run_inference_2_1_i2v(
image_embeds_combined = image_embeds
condition_combined = condition

condition_combined = jnp.transpose(condition_combined, (0, 4, 1, 2, 3))

transformer_obj = nnx.merge(graphdef, sharded_state, rest_of_state)

# Compute RoPE once as it only depends on shape
Expand All @@ -373,10 +373,9 @@ def run_inference_2_1_i2v(
)

scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False
timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)

if scan_diffusion_loop and not use_magcache:
timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)

scheduler_state = scheduler_state.replace(last_sample=jnp.zeros_like(latents), step_index=jnp.array(0, dtype=jnp.int32))

def scan_body(carry, t):
Expand All @@ -386,9 +385,9 @@ def scan_body(carry, t):
if do_cfg:
latents_input = jnp.concatenate([current_latents, current_latents], axis=0)

latent_model_input = jnp.concatenate([latents_input, condition_combined], axis=-1)
latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3))
latent_model_input = jnp.concatenate([latents_input, condition_combined], axis=1)
timestep = jnp.broadcast_to(t, latents_input.shape[0])
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))

outputs = transformer_forward_pass(
graphdef,
Expand Down Expand Up @@ -429,7 +428,7 @@ def scan_body(carry, t):
profiler = max_utils.Profiler(config)
profiler.start()

t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
t = timesteps[step]

skip_blocks = False
if use_magcache and do_cfg:
Expand All @@ -446,9 +445,9 @@ def scan_body(carry, t):
if do_cfg:
latents_input = jnp.concatenate([latents, latents], axis=0)

latent_model_input = jnp.concatenate([latents_input, condition_combined], axis=-1)
latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3))
latent_model_input = jnp.concatenate([latents_input, condition_combined], axis=1)
timestep = jnp.broadcast_to(t, latents_input.shape[0])
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))

outputs = transformer_forward_pass(
graphdef,
Expand Down
Loading
Loading