diff --git a/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py b/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py index 1fad541c..a3c4d0d3 100644 --- a/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py +++ b/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py @@ -76,8 +76,6 @@ def translate_fn(nnx_path_str): # the merge_fn warns about unmatched keys in each dict, so we only warn about any leftovers unmatched_keys = set(h_state_dict) - set(transformer_state_dict) - set(connector_state_dict) if unmatched_keys: - max_logging.log( - f"{len(unmatched_keys)} key(s) in LoRA dictionary routed to no merge target: {unmatched_keys}" - ) + max_logging.log(f"{len(unmatched_keys)} key(s) in LoRA dictionary routed to no merge target: {unmatched_keys}") return pipeline diff --git a/src/maxdiffusion/models/quantizations.py b/src/maxdiffusion/models/quantizations.py index 766842d1..fbc00a05 100644 --- a/src/maxdiffusion/models/quantizations.py +++ b/src/maxdiffusion/models/quantizations.py @@ -25,6 +25,7 @@ import jax.numpy as jnp from jax.tree_util import tree_flatten_with_path, tree_unflatten from typing import Tuple, Sequence +from maxdiffusion import max_logging # Params used to define mixed precision quantization configs DEFAULT = "__default__" # default config @@ -139,7 +140,7 @@ def _get_quant_config(config): else: drhs_bits = 8 drhs_accumulator_dtype = jnp.int32 - print(config.quantization_local_shard_count) # -1 + max_logging.log(config.quantization_local_shard_count) # -1 drhs_local_aqt = aqt_config.LocalAqt(contraction_axis_shard_count=config.quantization_local_shard_count) return aqt_config.config_v4( fwd_bits=8, diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 8b0493ed..06982152 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -997,7 +997,8 @@ def _prepare_model_inputs_i2v( prompt_embeds = jax.device_put(prompt_embeds, data_sharding) negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) - image_embeds = jax.device_put(image_embeds, data_sharding) + if image_embeds is not None: + image_embeds = jax.device_put(image_embeds, data_sharding) return prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py index 85daec33..7d2ac763 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -25,6 +25,7 @@ import numpy as np import time from ... import max_utils +from maxdiffusion import max_logging class WanPipeline2_1(WanPipeline): @@ -101,6 +102,7 @@ def __call__( magcache_K: Optional[int] = None, retention_ratio: Optional[float] = None, use_kv_cache: bool = False, + output_type: str = "pil", ): config = getattr(self, "config", None) if max_sequence_length is None: @@ -170,6 +172,9 @@ def __call__( latents.block_until_ready() trace["denoise_total"] = time.perf_counter() - t_denoise_start + if output_type == "latent": + return latents, trace + t_decode_start = time.perf_counter() video = self._decode_latents_to_video(latents, trace=trace) if hasattr(video, "block_until_ready"): @@ -222,6 +227,18 @@ def run_inference_2_1( do_cfg = guidance_scale > 1.0 bsz = latents.shape[0] + data_shards = 1 + try: + if hasattr(latents, "sharding") and hasattr(latents.sharding, "mesh"): + data_shards = latents.sharding.mesh.shape["data"] * latents.sharding.mesh.shape.get("fsdp", 1) + except Exception: + pass + + if use_cfg_cache and do_cfg and bsz % data_shards != 0: + max_logging.log( + f"Warning: Disabling CFG cache because batch size {bsz} is not divisible by data shards {data_shards}. This often happens with data_parallelism > 1 and per_device_batch_size = 1." + ) + use_cfg_cache = False # Resolution-dependent CFG cache config (FasterCache / MixCache guidance) if height >= 720: # 720p: conservative — protect last 40%, interval=5 @@ -306,10 +323,9 @@ def run_inference_2_1( ) scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False + timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32) if scan_diffusion_loop and not use_magcache and not use_cfg_cache: - timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32) - scheduler_state = scheduler_state.replace(last_sample=jnp.zeros_like(latents), step_index=jnp.array(0, dtype=jnp.int32)) def scan_body(carry, t): @@ -365,7 +381,7 @@ def scan_body(carry, t): profiler = max_utils.Profiler(config) profiler.start() - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + t = timesteps[step] if use_magcache and do_cfg: timestep = jnp.broadcast_to(t, bsz * 2 if do_cfg else bsz) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index 00d11f96..0874f4bd 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -118,6 +118,7 @@ def __call__( use_cfg_cache: bool = False, use_sen_cache: bool = False, use_kv_cache: bool = False, + output_type: str = "pil", ): config = getattr(self, "config", None) if max_sequence_length is None: @@ -203,6 +204,9 @@ def __call__( latents.block_until_ready() trace["denoise_total"] = time.perf_counter() - t_denoise_start + if output_type == "latent": + return latents, trace + t_decode_start = time.perf_counter() video = self._decode_latents_to_video(latents, trace=trace) if hasattr(video, "block_until_ready"): @@ -252,6 +256,19 @@ def run_inference_2_2( do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 bsz = latents.shape[0] + data_shards = 1 + try: + if hasattr(latents, "sharding") and hasattr(latents.sharding, "mesh"): + data_shards = latents.sharding.mesh.shape["data"] * latents.sharding.mesh.shape.get("fsdp", 1) + except Exception: + pass + + if use_cfg_cache and do_classifier_free_guidance and bsz % data_shards != 0: + max_logging.log( + f"Warning: Disabling CFG cache because batch size {bsz} is not divisible by data shards {data_shards}. This often happens with data_parallelism > 1 and per_device_batch_size = 1." + ) + use_cfg_cache = False + prompt_embeds_combined = ( jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_classifier_free_guidance else prompt_embeds ) @@ -279,6 +296,8 @@ def run_inference_2_2( high_transformer = nnx.merge(high_noise_graphdef, high_noise_state, high_noise_rest) kv_cache_high, encoder_attention_mask_high = high_transformer.compute_kv_cache(prompt_embeds_combined) + timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32) + # ── SenCache path (arXiv:2602.24208) ── if use_sen_cache and do_classifier_free_guidance: timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) @@ -303,16 +322,18 @@ def run_inference_2_2( num_train_timesteps = float(scheduler.config.num_train_timesteps) # SenCache state - ref_noise_pred = None # y^r: cached denoiser output - ref_latent = None # x^r: latent at last cache refresh - ref_timestep = 0.0 # t^r: timestep (normalized to [0,1]) at last cache refresh - accum_dx = 0.0 # accumulated ||Δx|| since last refresh - accum_dt = 0.0 # accumulated |Δt| since last refresh - reuse_count = 0 # consecutive cache reuses - cache_count = 0 + ref_noise_pred = jnp.zeros( + (bsz * 2, latents.shape[1], latents.shape[2], latents.shape[3], latents.shape[4]), dtype=latents.dtype + ) + ref_latent = jnp.zeros_like(latents) + ref_timestep = jnp.array(0.0, dtype=jnp.float32) + accum_dx = jnp.array(0.0, dtype=jnp.float32) + accum_dt = jnp.array(0.0, dtype=jnp.float32) + reuse_count = jnp.array(0, dtype=jnp.int32) + cache_count = jnp.array(0, dtype=jnp.int32) for step in range(num_inference_steps): - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + t = timesteps[step] t_float = float(timesteps_np[step]) / num_train_timesteps # normalize to [0, 1] # Select transformer and guidance scale @@ -358,10 +379,10 @@ def run_inference_2_2( ) ref_noise_pred = noise_pred ref_latent = latents - ref_timestep = t_float - accum_dx = 0.0 - accum_dt = 0.0 - reuse_count = 0 + ref_timestep = jnp.array(t_float, dtype=jnp.float32) + accum_dx = jnp.array(0.0, dtype=jnp.float32) + accum_dt = jnp.array(0.0, dtype=jnp.float32) + reuse_count = jnp.array(0, dtype=jnp.int32) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() continue @@ -375,12 +396,10 @@ def run_inference_2_2( score = alpha_x * accum_dx + alpha_t * accum_dt if score <= sen_epsilon and reuse_count < max_reuse: - # Cache hit: reuse previous output noise_pred = ref_noise_pred reuse_count += 1 cache_count += 1 else: - # Cache miss: full CFG forward pass latents_doubled = jnp.concatenate([latents] * 2) timestep = jnp.broadcast_to(t, bsz * 2) noise_pred, _, _ = transformer_forward_pass_full_cfg( @@ -470,7 +489,7 @@ def run_inference_2_2( cached_noise_uncond = None for step in range(num_inference_steps): - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + t = timesteps[step] is_cache_step = step_is_cache[step] # Select transformer and guidance scale based on precomputed schedule @@ -607,8 +626,6 @@ def low_noise_branch(operands): ) if scan_diffusion_loop: - timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32) - scheduler_state = scheduler_state.replace(last_sample=jnp.zeros_like(latents), step_index=jnp.array(0, dtype=jnp.int32)) def scan_body(carry, t): @@ -657,7 +674,7 @@ def scan_body(carry, t): profiler = max_utils.Profiler(config) profiler.start() - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + t = timesteps[step] if step_uses_high[step]: graphdef, state, rest = ( diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py index 3bfcc751..5fc35d8d 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py @@ -213,19 +213,17 @@ def __call__( last_image, ) - def _process_image_input(img_input, height, width, num_videos_per_prompt): + def _process_image_input(img_input, height, width): if img_input is None: return None tensor = self.video_processor.preprocess(img_input, height=height, width=width) jax_array = jnp.array(tensor.cpu().numpy()) if jax_array.ndim == 3: jax_array = jax_array[None, ...] # Add batch dimension - if num_videos_per_prompt > 1: - jax_array = jnp.repeat(jax_array, num_videos_per_prompt, axis=0) return jax_array - image_tensor = _process_image_input(image, height, width, effective_batch_size) - last_image_tensor = _process_image_input(last_image, height, width, effective_batch_size) + image_tensor = _process_image_input(image, height, width) + last_image_tensor = _process_image_input(last_image, height, width) if rng is None: rng = jax.random.key(self.config.seed) @@ -352,6 +350,8 @@ def run_inference_2_1_i2v( image_embeds_combined = image_embeds condition_combined = condition + condition_combined = jnp.transpose(condition_combined, (0, 4, 1, 2, 3)) + transformer_obj = nnx.merge(graphdef, sharded_state, rest_of_state) # Compute RoPE once as it only depends on shape @@ -373,10 +373,9 @@ def run_inference_2_1_i2v( ) scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False + timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32) if scan_diffusion_loop and not use_magcache: - timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32) - scheduler_state = scheduler_state.replace(last_sample=jnp.zeros_like(latents), step_index=jnp.array(0, dtype=jnp.int32)) def scan_body(carry, t): @@ -386,9 +385,9 @@ def scan_body(carry, t): if do_cfg: latents_input = jnp.concatenate([current_latents, current_latents], axis=0) - latent_model_input = jnp.concatenate([latents_input, condition_combined], axis=-1) + latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) + latent_model_input = jnp.concatenate([latents_input, condition_combined], axis=1) timestep = jnp.broadcast_to(t, latents_input.shape[0]) - latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) outputs = transformer_forward_pass( graphdef, @@ -429,7 +428,7 @@ def scan_body(carry, t): profiler = max_utils.Profiler(config) profiler.start() - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + t = timesteps[step] skip_blocks = False if use_magcache and do_cfg: @@ -446,9 +445,9 @@ def scan_body(carry, t): if do_cfg: latents_input = jnp.concatenate([latents, latents], axis=0) - latent_model_input = jnp.concatenate([latents_input, condition_combined], axis=-1) + latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) + latent_model_input = jnp.concatenate([latents_input, condition_combined], axis=1) timestep = jnp.broadcast_to(t, latents_input.shape[0]) - latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) outputs = transformer_forward_pass( graphdef, diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py index f071c231..d9a6003b 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py @@ -237,19 +237,17 @@ def __call__( last_image, ) - def _process_image_input(img_input, height, width, num_videos_per_prompt): + def _process_image_input(img_input, height, width): if img_input is None: return None tensor = self.video_processor.preprocess(img_input, height=height, width=width) jax_array = jnp.array(tensor.cpu().numpy()) if jax_array.ndim == 3: jax_array = jax_array[None, ...] # Add batch dimension - if num_videos_per_prompt > 1: - jax_array = jnp.repeat(jax_array, num_videos_per_prompt, axis=0) return jax_array - image_tensor = _process_image_input(image, height, width, effective_batch_size) - last_image_tensor = _process_image_input(last_image, height, width, effective_batch_size) + image_tensor = _process_image_input(image, height, width) + last_image_tensor = _process_image_input(last_image, height, width) if rng is None: rng = jax.random.key(self.config.seed) @@ -373,6 +371,19 @@ def run_inference_2_2_i2v( do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 bsz = latents.shape[0] + data_shards = 1 + try: + if hasattr(latents, "sharding") and hasattr(latents.sharding, "mesh"): + data_shards = latents.sharding.mesh.shape["data"] * latents.sharding.mesh.shape.get("fsdp", 1) + except Exception: + pass + + if use_cfg_cache and do_classifier_free_guidance and bsz % data_shards != 0: + max_logging.log( + f"Warning: Disabling CFG cache because batch size {bsz} is not divisible by data shards {data_shards}. This often happens with data_parallelism > 1 and per_device_batch_size = 1." + ) + use_cfg_cache = False + prompt_embeds_combined = ( jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_classifier_free_guidance else prompt_embeds ) @@ -404,6 +415,8 @@ def run_inference_2_2_i2v( prompt_embeds_combined, image_embeds_combined ) + timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32) + # ── SenCache path (arXiv:2602.24208) ── if use_sen_cache and do_classifier_free_guidance: timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) @@ -422,18 +435,21 @@ def run_inference_2_2_i2v( num_train_timesteps = float(scheduler.config.num_train_timesteps) condition_doubled = jnp.concatenate([condition] * 2) + condition_doubled = jnp.transpose(condition_doubled, (0, 4, 1, 2, 3)) # SenCache state - ref_noise_pred = None - ref_latent = None - ref_timestep = 0.0 - accum_dx = 0.0 - accum_dt = 0.0 - reuse_count = 0 - cache_count = 0 + ref_noise_pred = jnp.zeros( + (bsz * 2, latents.shape[1], latents.shape[2], latents.shape[3], latents.shape[4]), dtype=latents.dtype + ) + ref_latent = jnp.zeros_like(latents) + ref_timestep = jnp.array(0.0, dtype=jnp.float32) + accum_dx = jnp.array(0.0, dtype=jnp.float32) + accum_dt = jnp.array(0.0, dtype=jnp.float32) + reuse_count = jnp.array(0, dtype=jnp.int32) + cache_count = jnp.array(0, dtype=jnp.int32) for step in range(num_inference_steps): - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + t = timesteps[step] t_float = float(timesteps_np[step]) / num_train_timesteps if step_uses_high[step]: @@ -462,8 +478,8 @@ def run_inference_2_2_i2v( if force_compute: latents_doubled = jnp.concatenate([latents, latents], axis=0) - latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=-1) - latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) + latents_doubled = jnp.transpose(latents_doubled, (0, 4, 1, 2, 3)) + latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=1) timestep = jnp.broadcast_to(t, bsz * 2) noise_pred, _, _ = transformer_forward_pass_full_cfg( graphdef, @@ -481,10 +497,10 @@ def run_inference_2_2_i2v( noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) ref_noise_pred = noise_pred ref_latent = latents - ref_timestep = t_float - accum_dx = 0.0 - accum_dt = 0.0 - reuse_count = 0 + ref_timestep = jnp.array(t_float, dtype=jnp.float32) + accum_dx = jnp.array(0.0, dtype=jnp.float32) + accum_dt = jnp.array(0.0, dtype=jnp.float32) + reuse_count = jnp.array(0, dtype=jnp.int32) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() continue @@ -501,8 +517,8 @@ def run_inference_2_2_i2v( cache_count += 1 else: latents_doubled = jnp.concatenate([latents, latents], axis=0) - latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=-1) - latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) + latents_doubled = jnp.transpose(latents_doubled, (0, 4, 1, 2, 3)) + latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=1) timestep = jnp.broadcast_to(t, bsz * 2) noise_pred, _, _ = transformer_forward_pass_full_cfg( graphdef, @@ -559,8 +575,9 @@ def run_inference_2_2_i2v( image_embeds_cond = None # Keep condition in both single and doubled forms - condition_cond = condition + condition_cond = jnp.transpose(condition, (0, 4, 1, 2, 3)) condition_doubled = jnp.concatenate([condition] * 2) + condition_doubled = jnp.transpose(condition_doubled, (0, 4, 1, 2, 3)) # Determine the first low-noise step first_low_step = next( @@ -596,7 +613,7 @@ def run_inference_2_2_i2v( cached_noise_uncond = None for step in range(num_inference_steps): - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + t = timesteps[step] is_cache_step = step_is_cache[step] if step_uses_high[step]: @@ -621,9 +638,9 @@ def run_inference_2_2_i2v( if is_cache_step: # ── Cache step: cond-only forward + FFT frequency compensation ── w1, w2 = step_w1w2[step] - # Prepare cond-only input: concat condition, transpose BFHWC -> BCFHW - latent_model_input = jnp.concatenate([latents, condition_cond], axis=-1) - latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) + # Prepare cond-only input: transpose latents to BCFHW and concat with pre-transposed condition + latents_t = jnp.transpose(latents, (0, 4, 1, 2, 3)) + latent_model_input = jnp.concatenate([latents_t, condition_cond], axis=1) timestep = jnp.broadcast_to(t, bsz) kv_cache_cond = jax.tree.map(lambda x: x[:, :bsz], kv_cache) if kv_cache is not None else None encoder_attention_mask_cond = encoder_attention_mask[:bsz] if encoder_attention_mask is not None else None @@ -647,8 +664,8 @@ def run_inference_2_2_i2v( else: # ── Full CFG step: doubled batch, store raw cond/uncond for cache ── latents_doubled = jnp.concatenate([latents, latents], axis=0) - latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=-1) - latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) + latents_doubled = jnp.transpose(latents_doubled, (0, 4, 1, 2, 3)) + latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=1) timestep = jnp.broadcast_to(t, bsz * 2) ( noise_pred, @@ -685,7 +702,6 @@ def high_noise_branch(operands): mask_high, _, ) = operands - latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) noise_pred, latents_out = transformer_forward_pass( high_noise_graphdef, high_noise_state, @@ -714,7 +730,6 @@ def low_noise_branch(operands): _, mask_low, ) = operands - latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) noise_pred, latents_out = transformer_forward_pass( low_noise_graphdef, low_noise_state, @@ -733,6 +748,7 @@ def low_noise_branch(operands): if do_classifier_free_guidance: condition = jnp.concatenate([condition] * 2) + condition = jnp.transpose(condition, (0, 4, 1, 2, 3)) first_profiling_step = config.skip_first_n_steps_for_profiler if config else 0 profiler_steps = config.profiler_steps if config else 0 @@ -745,8 +761,6 @@ def low_noise_branch(operands): scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False if scan_diffusion_loop: - timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32) - scheduler_state = scheduler_state.replace(last_sample=jnp.zeros_like(latents), step_index=jnp.array(0, dtype=jnp.int32)) def scan_body(carry, t): @@ -755,7 +769,8 @@ def scan_body(carry, t): latents_input = current_latents if do_classifier_free_guidance: latents_input = jnp.concatenate([current_latents, current_latents], axis=0) - latent_model_input = jnp.concatenate([latents_input, condition], axis=-1) + latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) + latent_model_input = jnp.concatenate([latents_input, condition], axis=1) timestep = jnp.broadcast_to(t, latents_input.shape[0]) use_high_noise = jnp.greater_equal(t, boundary) @@ -795,11 +810,12 @@ def scan_body(carry, t): profiler = max_utils.Profiler(config) profiler.start() - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + t = timesteps[step] latents_input = latents if do_classifier_free_guidance: latents_input = jnp.concatenate([latents, latents], axis=0) - latent_model_input = jnp.concatenate([latents_input, condition], axis=-1) + latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) + latent_model_input = jnp.concatenate([latents_input, condition], axis=1) timestep = jnp.broadcast_to(t, latents_input.shape[0]) use_high_noise = jnp.greater_equal(t, boundary) diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 3b2e9dd1..934a51ca 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -368,5 +368,7 @@ def initialize(argv, **kwargs): if __name__ == "__main__": initialize(sys.argv) - print(config.steps) + from maxdiffusion import max_logging + + max_logging.log(config.steps) r = range(config.steps)