From a689f3100640f44770d4e9ba5b3b6cfe891b0210 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 25 Jun 2026 11:21:44 -0700 Subject: [PATCH] Add Qwen-Image DMD2 PTQ support; save quantizer state (amax) without weights Two related changes to the diffusers quantization example. 1) Weight-free quantizer-state checkpoint (all models) The torch checkpoint written by `--quantized-torch-ckpt-save-path` now stores ONLY ModelOpt's quantization state -- the recipe plus the quantizer buffers (amax, pre_quant_scale, ...) -- and NOT the model weights. The weights live in the base HF/diffusers checkpoint and are reloaded there on restore. This uses ModelOpt's own idiom (mirrors plugins/transformers_trainer.py): save: modelopt_state = mto.modelopt_state(model) modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(model) torch.save(modelopt_state, path) restore: modelopt_state = mto.load_modelopt_state(path) weights = modelopt_state.pop("modelopt_state_weights", None) mto.restore_from_modelopt_state(model, modelopt_state) set_quantizer_state_dict(model, weights) Wrapped as utils.save_quantizer_state / restore_quantizer_state and wired into ExportManager.save_checkpoint / restore_checkpoint. Restore auto-applies on top of the freshly-loaded base weights (the pipeline is created before restore). Effect: the artifact drops from a full-model checkpoint to KBs-MBs (a 60-layer Qwen-Image student: 40.8 GB -> 2.0 MB) while amax round-trips bit-identically. 2) qwen-image-dmd2 model type (DMD2 few-step Qwen-Image students) For students distilled by examples/diffusers/fastgen. Reuses the existing Qwen-Image quantization stack from the base branch (filter_func_qwen_image, the block-range recipe, QwenImagePipeline registration) and adds only what differs: - pipeline_manager: load the consolidated student transformer (+ optional EMA) and swap it into the base QwenImagePipeline; stash the few-step sampler config (defaults to the canonical 4-step shift=3 [1.0, 0.9, 0.75, 0.5, 0.0] ODE schedule, guidance_scale=1.0). - calibration: drive the few-step DMD sampler instead of the standard denoising loop, so collected amax matches how the student is actually run. - qwen_image_dmd2_sampler.py (new): vendored compact DMD unroll, bit-aligned with fastgen/inference_dmd2_qwen_image.py; transformer forwards only for calibration (decode=False), optional VAE decode for sanity inference (decode=True). - sanity_check_dmd2.py (new): restore the quantizer-state checkpoint and run one few-step inference to validate the round trip. - quantize.py: import the ONNX export tooling lazily so the calibration + save path runs without onnx_graphsurgeon (e.g. the diffusers/fastgen container). Validated end-to-end (FP8, single GB200): calibrate -> save (2.0 MB) -> restore_quantizer_state reproduces amax bit-identically (e.g. block 30 attn.to_q input amax 1.26e+03) and the restored quantized student renders a finite, non-constant 1024x1024 image. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Jingyu Xin --- .../diffusers/quantization/calibration.py | 20 ++ .../diffusers/quantization/models_utils.py | 19 ++ .../quantization/pipeline_manager.py | 115 ++++++++ examples/diffusers/quantization/quantize.py | 21 +- .../quantization/qwen_image_dmd2_sampler.py | 260 ++++++++++++++++++ .../quantization/sanity_check_dmd2.py | 175 ++++++++++++ examples/diffusers/quantization/utils.py | 38 +++ 7 files changed, 642 insertions(+), 6 deletions(-) create mode 100644 examples/diffusers/quantization/qwen_image_dmd2_sampler.py create mode 100644 examples/diffusers/quantization/sanity_check_dmd2.py diff --git a/examples/diffusers/quantization/calibration.py b/examples/diffusers/quantization/calibration.py index 27b1ec22436..bebc61970a3 100644 --- a/examples/diffusers/quantization/calibration.py +++ b/examples/diffusers/quantization/calibration.py @@ -21,6 +21,7 @@ from models_utils import MODEL_DEFAULTS, ModelType from pipeline_manager import PipelineManager from quantize_config import CalibrationConfig +from qwen_image_dmd2_sampler import dmd2_sample from tqdm import tqdm from utils import load_calib_prompts @@ -95,6 +96,9 @@ def run_calibration(self, batched_prompts: list[list[str]]) -> None: elif self.model_type in [ModelType.WAN22_T2V_14b, ModelType.WAN22_T2V_5b]: # Special handling for WAN video models self._run_wan_video_calibration(prompt_batch, extra_args) + elif self.model_type == ModelType.QWEN_IMAGE_DMD2: + # DMD2 students use a custom few-step sampler, not the standard loop. + self._run_qwen_image_dmd2_calibration(prompt_batch) else: common_args = { "prompt": prompt_batch, @@ -105,6 +109,22 @@ def run_calibration(self, batched_prompts: list[list[str]]) -> None: self.logger.debug(f"Completed calibration batch {i + 1}/{self.config.num_batches}") self.logger.info("Calibration completed successfully") + def _run_qwen_image_dmd2_calibration(self, prompt_batch: list[str]) -> None: + """Calibrate a DMD2 Qwen-Image student via its few-step sampler. + + Drives the same few-step DMD unroll the student was trained/served with + (NOT the standard denoising loop) so the collected activation statistics + are representative of inference. The VAE decode is skipped — calibration + only needs the transformer forwards. + """ + cfg = self.pipeline_manager.dmd_sampler_cfg + if cfg is None: + raise RuntimeError( + "DMD2 sampler config is not set; the qwen-image-dmd2 pipeline must be created " + "via PipelineManager.create_pipeline() before calibration." + ) + dmd2_sample(self.pipe, prompt_batch, decode=False, **cfg) + def _run_wan_video_calibration( self, prompt_batch: list[str], extra_args: dict[str, Any] ) -> None: diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py index 9d366dc7402..99e8b5e6c12 100644 --- a/examples/diffusers/quantization/models_utils.py +++ b/examples/diffusers/quantization/models_utils.py @@ -64,6 +64,10 @@ class ModelType(str, Enum): WAN22_T2V_14b = "wan2.2-t2v-14b" WAN22_T2V_5b = "wan2.2-t2v-5b" QWEN_IMAGE = "qwen-image" + # DMD2-distilled few-step Qwen-Image student (from examples/diffusers/fastgen). + # Same architecture as QWEN_IMAGE, but loaded from a consolidated student dir + # and calibrated with the few-step DMD sampler instead of the standard loop. + QWEN_IMAGE_DMD2 = "qwen-image-dmd2" _FILTER_FUNC_MAP: dict[ModelType, Callable[[str], bool]] = { @@ -74,6 +78,7 @@ class ModelType(str, Enum): ModelType.WAN22_T2V_14b: filter_func_wan_video, ModelType.WAN22_T2V_5b: filter_func_wan_video, ModelType.QWEN_IMAGE: filter_func_qwen_image, + ModelType.QWEN_IMAGE_DMD2: filter_func_qwen_image, } _VAE_FILTER_FUNC_MAP: dict[tuple[ModelType, str], Callable[[str], bool]] = { @@ -107,6 +112,11 @@ def get_model_filter_func( ModelType.WAN22_T2V_14b: "Wan-AI/Wan2.2-T2V-A14B-Diffusers", ModelType.WAN22_T2V_5b: "Wan-AI/Wan2.2-TI2V-5B-Diffusers", ModelType.QWEN_IMAGE: "Qwen/Qwen-Image", + # Base pipeline (VAE / text-encoder / tokenizer / scheduler) for DMD2 students; + # the trained transformer is loaded separately from a consolidated dir via the + # ``student_path`` extra-param. Override with ``--override-model-path`` or + # ``--extra-param base_pipeline_path=...``. + ModelType.QWEN_IMAGE_DMD2: "Qwen/Qwen-Image", } MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline] | None] = { @@ -122,6 +132,7 @@ def get_model_filter_func( ModelType.WAN22_T2V_14b: WanPipeline, ModelType.WAN22_T2V_5b: WanPipeline, ModelType.QWEN_IMAGE: QwenImagePipeline, + ModelType.QWEN_IMAGE_DMD2: QwenImagePipeline, } # Shared dataset configurations @@ -258,6 +269,14 @@ def get_model_filter_func( }, } +# DMD2 students share Qwen-Image's architecture, so they reuse the same block-range +# recipe, high-precision filter, base pipeline, and calibration dataset. They differ +# only in (a) loading -- a consolidated student dir swapped into the base pipeline +# (PipelineManager._create_qwen_image_dmd2_pipeline) -- and (b) calibration, which +# drives the few-step DMD sampler instead of the standard denoising loop +# (Calibrator._run_qwen_image_dmd2_calibration). Inherit so the recipe stays in sync. +MODEL_DEFAULTS[ModelType.QWEN_IMAGE_DMD2] = {**MODEL_DEFAULTS[ModelType.QWEN_IMAGE]} + def _coerce_extra_param_value(value: str) -> Any: lowered = value.lower() diff --git a/examples/diffusers/quantization/pipeline_manager.py b/examples/diffusers/quantization/pipeline_manager.py index af89ed568ff..4e76ee0d661 100644 --- a/examples/diffusers/quantization/pipeline_manager.py +++ b/examples/diffusers/quantization/pipeline_manager.py @@ -43,6 +43,9 @@ def __init__(self, config: ModelConfig, logger: logging.Logger): self.pipe_upsample: LTXLatentUpsamplePipeline | None = None # For LTX-Video upsampling self._transformer: torch.nn.Module | None = None self._video_decoder: torch.nn.Module | None = None + # Few-step sampler config for DMD2 students (populated when loading a + # qwen-image-dmd2 pipeline); consumed by the calibrator / sanity check. + self.dmd_sampler_cfg: dict[str, Any] | None = None @staticmethod def create_pipeline_from( @@ -100,6 +103,11 @@ def create_pipeline(self) -> Any: self.logger.info("LTX-2 pipeline created successfully") return self.pipe + if self.config.model_type == ModelType.QWEN_IMAGE_DMD2: + self.pipe = self._create_qwen_image_dmd2_pipeline() + self.logger.info("Qwen-Image DMD2 pipeline created successfully") + return self.pipe + pipeline_cls = MODEL_PIPELINE[self.config.model_type] if pipeline_cls is None: raise ValueError( @@ -266,6 +274,113 @@ def _create_ltx2_pipeline(self) -> Any: pipeline_kwargs.update(params) return TI2VidTwoStagesPipeline(**pipeline_kwargs) + def _create_qwen_image_dmd2_pipeline(self) -> Any: + """Build a QwenImagePipeline whose transformer is a DMD2-trained student. + + Loads the consolidated student transformer (the ``model/consolidated`` dir + produced by ``examples/diffusers/fastgen`` training), optionally overlays an + EMA shadow, and swaps it into the base Qwen-Image pipeline so the VAE / + text-encoder / tokenizer / scheduler come from the base checkpoint. + + Reads from ``extra_params``: + student_path (required): consolidated student dir. + base_pipeline_path: base Qwen-Image dir/HF id (defaults to the + registry id or ``--override-model-path``). + ema_path: optional ``ema_shadow.pt`` to overlay onto the student. + sample_steps / t_list / sample_type / guidance_scale / max_t: + few-step sampler schedule (defaults match the canonical 4-step + shift=3 student); stashed in ``self.dmd_sampler_cfg``. + """ + from qwen_image_dmd2_sampler import DEFAULT_MAX_T, resolve_schedule + + try: + from diffusers import QwenImagePipeline, QwenImageTransformer2DModel + except ImportError as e: + raise ImportError( + "qwen-image-dmd2 requires a diffusers version providing QwenImagePipeline " + "and QwenImageTransformer2DModel; upgrade diffusers." + ) from e + + params = dict(self.config.extra_params) + student_path = params.get("student_path") + if not student_path: + raise ValueError( + "Missing required extra_param: student_path (the consolidated DMD2 student " + "dir, e.g. .../epoch_4_step_17999/model/consolidated)." + ) + base_pipeline_path = params.get("base_pipeline_path") or self.config.model_path + ema_path = params.get("ema_path") + + default_dtype = self.config.model_dtype["default"] + transformer_dtype = self.config.model_dtype.get("transformer", default_dtype) + if torch.float16 in (default_dtype, transformer_dtype): + self.logger.warning( + "Qwen-Image is trained/served in bfloat16; float16 (Half) can overflow the " + "VAE and produce NaNs. Consider --model-dtype BFloat16." + ) + + self.logger.info("Loading DMD2 student transformer from %s", student_path) + transformer = QwenImageTransformer2DModel.from_pretrained( + student_path, torch_dtype=transformer_dtype + ) + + if ema_path: + self.logger.info("Overlaying EMA shadow from %s", ema_path) + ema_state = torch.load(str(ema_path), map_location="cpu") + shadow = ( + ema_state.get("shadow", ema_state) if isinstance(ema_state, dict) else ema_state + ) + if not isinstance(shadow, dict): + raise ValueError( + f"ema_path content has unexpected type {type(shadow).__name__}; " + "expected dict[str, Tensor]." + ) + missing, unexpected = transformer.load_state_dict(shadow, strict=False) + if unexpected: + self.logger.warning("EMA overlay had %d unexpected key(s)", len(unexpected)) + if missing: + self.logger.warning("EMA overlay missed %d student key(s)", len(missing)) + + transformer.eval() + + self.logger.info( + "Loading base Qwen-Image pipeline from %s (transformer replaced by student)", + base_pipeline_path, + ) + pipe = QwenImagePipeline.from_pretrained( + base_pipeline_path, transformer=transformer, torch_dtype=default_dtype + ) + pipe.set_progress_bar_config(disable=True) + + # Resolve and stash the few-step sampler config. Defaults match the + # canonical 4-step shift=3 student; the schedule MUST match training. + sample_steps = params.get("sample_steps") + sample_steps = int(sample_steps) if sample_steps is not None else 4 + t_list = params.get("t_list") + if isinstance(t_list, str): + t_list = [float(x) for x in t_list.split(",") if x.strip()] + max_t = float(params.get("max_t", DEFAULT_MAX_T)) + schedule = resolve_schedule(t_list, sample_steps, max_t) + defaults = MODEL_DEFAULTS[self.config.model_type].get("inference_extra_args", {}) + self.dmd_sampler_cfg = { + "schedule": schedule, + "sample_type": str(params.get("sample_type", "ode")), + "guidance_scale": float(params.get("guidance_scale", 1.0)), + "negative_prompt": params.get("negative_prompt"), + "height": int(params.get("height", defaults.get("height", 1024))), + "width": int(params.get("width", defaults.get("width", 1024))), + "max_sequence_length": int(params.get("max_sequence_length", 512)), + } + self.logger.info( + "DMD2 few-step sampler: steps=%d schedule=%s sample_type=%s guidance_scale=%s " + "(schedule must match the student's training t_list)", + len(schedule) - 1, + schedule, + self.dmd_sampler_cfg["sample_type"], + self.dmd_sampler_cfg["guidance_scale"], + ) + return pipe + def print_quant_summary(self): for name, backbone in self.iter_backbones(): self.logger.info(f"{name} quantization info:") diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 41a90089129..fa1cc729f92 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -39,7 +39,6 @@ get_model_filter_func, parse_extra_params, ) -from onnx_utils.export import generate_fp8_scales, modelopt_export_sd from pipeline_manager import PipelineManager from quantize_config import ( CalibrationConfig, @@ -51,9 +50,8 @@ QuantFormat, QuantizationConfig, ) -from utils import check_conv_and_mha, check_lora +from utils import check_conv_and_mha, check_lora, restore_quantizer_state, save_quantizer_state -import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq from modelopt.torch.export import export_hf_checkpoint @@ -295,8 +293,11 @@ def save_checkpoint( filename = f"{backbone_name}.pt" if backbone_name else "backbone.pt" target_path = ckpt_path / filename - self.logger.info(f"Saving backbone to {target_path}") - mto.save(backbone, str(target_path)) + # Save ONLY the quantization state (recipe + quantizer buffers incl. amax), + # not the model weights. The weights live in the base HF/diffusers checkpoint + # and are reloaded there on restore; this keeps the artifact tiny. + self.logger.info(f"Saving quantizer state (amax + recipe, no weights) to {target_path}") + save_quantizer_state(backbone, str(target_path)) self.logger.info("Checkpoint saved successfully") @@ -319,6 +320,12 @@ def export_onnx( if not self.config.onnx_dir: return + # Imported lazily: the ONNX export tooling (onnx_graphsurgeon, etc.) is only + # needed when --onnx-dir is set, so the calibration + amax-save path still + # works in environments without the ONNX deps (e.g. the diffusers/fastgen + # container used for Qwen-Image DMD2 students). + from onnx_utils.export import generate_fp8_scales, modelopt_export_sd + self.logger.info(f"Starting ONNX export to {self.config.onnx_dir}") if quant_format == QuantFormat.FP8 and self._has_conv_layers(backbone): @@ -362,7 +369,9 @@ def restore_checkpoint(self) -> None: f"Checkpoint not found for '{backbone_name}' in {restore_path}" ) self.logger.info(f"Restoring {backbone_name} from {source_path}") - mto.restore(backbone, str(source_path)) + # The pipeline was just created with the base (unquantized) weights, so + # this re-applies the quantization recipe + amax on top of them. + restore_quantizer_state(backbone, str(source_path)) self.logger.info("Checkpoints restored successfully") diff --git a/examples/diffusers/quantization/qwen_image_dmd2_sampler.py b/examples/diffusers/quantization/qwen_image_dmd2_sampler.py new file mode 100644 index 00000000000..518bc6b47d5 --- /dev/null +++ b/examples/diffusers/quantization/qwen_image_dmd2_sampler.py @@ -0,0 +1,260 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Compact DMD2 few-step sampler for Qwen-Image students. + +This is a vendored, calibration-friendly version of the few-step unroll in +``examples/diffusers/fastgen/inference_dmd2_qwen_image.py``. It is kept here so +the quantization example is self-contained (no cross-example ``sys.path`` +imports) and so calibration can run the **same forward logic the student was +trained/served with** — which is what makes the collected ``amax`` statistics +representative. + +The single :func:`dmd2_sample` entry point serves two callers: + +* **Calibration** (``decode=False``): runs only the transformer forwards of the + DMD unroll and returns ``None``. The VAE / image post-processing is skipped + because quantization only needs the transformer's activation statistics, and + skipping the VAE saves substantial time and memory on the 60-layer student. +* **Sanity inference** (``decode=True``): additionally runs the VAE decode and + returns a list of images, used to confirm a restored (quantized) student + still produces a finite image. + +The math is bit-aligned with the training-time ``_build_student_input`` in +``modelopt/torch/fastgen/methods/dmd.py`` and with the inference reference: + + for (t_cur, t_next) in pairwise(t_list): + v = student(x, t=t_cur, text_emb) # flow at t_cur + x_0 = x - t_cur * v # RF identity -> x_0 estimate + if t_next > 0: + eps = (x - (1 - t_cur) * x_0) / t_cur # ODE: invert RF forward + x = (1 - t_next) * x_0 + t_next * eps # re-noise to t_next + else: + x = x_0 # final step + +``t_list`` MUST match the student's training schedule (e.g. the LightX2V +"shift=3" 4-step shape ``[1.0, 0.9, 0.75, 0.5, 0.0]``); a mismatch produces a +train/inference gap and therefore misleading calibration statistics. +""" + +from __future__ import annotations + +import itertools + +import torch +from diffusers.utils.torch_utils import randn_tensor + +# Canonical 4-step "shift=3" student schedule (LightX2V-Qwen-Image-Lightning +# shape). t_list has student_sample_steps + 1 entries: the first N are the +# timesteps the student is evaluated at, the trailing 0.0 is the terminal the +# final Euler step lands on (NOT an extra evaluation). +DEFAULT_T_LIST: tuple[float, ...] = (1.0, 0.9, 0.75, 0.5, 0.0) +DEFAULT_MAX_T: float = 0.999 + + +def resolve_schedule( + t_list: list[float] | tuple[float, ...] | None, + sample_steps: int | None, + max_t: float = DEFAULT_MAX_T, +) -> list[float]: + """Resolve the sampling schedule (timesteps + terminal 0.0). + + Priority: + 1. An explicit ``t_list`` (must end at 0.0 and have ``sample_steps + 1`` + entries when ``sample_steps`` is given). + 2. ``sample_steps == 1`` -> ``[max_t, 0.0]`` (canonical single-step). + 3. ``sample_steps == 4`` (or None) with no ``t_list`` -> ``DEFAULT_T_LIST``. + 4. Otherwise a linear ``linspace(max_t, 0, sample_steps + 1)`` fallback. + """ + if t_list is not None: + schedule = [float(t) for t in t_list] + if abs(schedule[-1]) > 1e-6: + raise ValueError( + f"t_list must end at 0.0 (got {schedule[-1]}); the final step lands on x_0." + ) + if sample_steps is not None and len(schedule) != sample_steps + 1: + raise ValueError( + f"t_list must have sample_steps+1 entries " + f"(got {len(schedule)} for sample_steps={sample_steps})." + ) + return schedule + + if sample_steps == 1: + return [float(max_t), 0.0] + if sample_steps in (None, 4): + return list(DEFAULT_T_LIST) + return torch.linspace(float(max_t), 0.0, sample_steps + 1).tolist() + + +@torch.no_grad() +def dmd2_sample( + pipe, + prompt: str | list[str], + *, + schedule: list[float], + sample_type: str = "ode", + guidance_scale: float = 1.0, + negative_prompt: str | list[str] | None = None, + height: int = 1024, + width: int = 1024, + num_images_per_prompt: int = 1, + generator: torch.Generator | None = None, + max_sequence_length: int = 512, + decode: bool = False, + output_type: str = "pil", +) -> list | None: + """Run the DMD few-step unroll on ``pipe.transformer``. + + Args: + pipe: A ``QwenImagePipeline`` whose ``transformer`` is the DMD2 student. + prompt: A prompt or list of prompts (one calibration batch). + schedule: Full timestep schedule incl. trailing 0.0 (see + :func:`resolve_schedule`). + sample_type: ``"ode"`` (deterministic, recover eps via RF identity) or + ``"sde"`` (fresh Gaussian noise between steps). Must match training. + guidance_scale: Inference-time CFG. Leave at ``1.0`` for students trained + with an internalised (non-null) ``dmd2.guidance_scale`` — passing + ``> 1.0`` there would double-apply CFG. + negative_prompt: Negative prompt for CFG; defaults to ``""`` when CFG is + engaged and none is given. + height/width: Output spatial size (must be VAE-compatible). + num_images_per_prompt: Images per prompt. + generator: Optional RNG for reproducible noise. + max_sequence_length: Text-encoder max sequence length. + decode: If ``True`` run VAE decode + post-process and return images. If + ``False`` (calibration) skip the VAE and return ``None``. + output_type: Passed to the image processor when ``decode=True``. + + Returns: + A list of images when ``decode=True``, else ``None``. + """ + if sample_type not in ("ode", "sde"): + raise ValueError(f"sample_type must be 'ode' or 'sde', got {sample_type!r}") + + do_cfg = guidance_scale != 1.0 + if do_cfg and negative_prompt is None: + negative_prompt = "" + + device = pipe.transformer.device + dtype = next(pipe.transformer.parameters()).dtype + + # ---- Encode prompt(s) ------------------------------------------------ + prompt_embeds, prompt_embeds_mask = pipe.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + neg_prompt_embeds = neg_prompt_embeds_mask = None + if do_cfg: + neg_prompt_embeds, neg_prompt_embeds_mask = pipe.encode_prompt( + prompt=negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + txt_seq_lens = ( + prompt_embeds_mask.sum(dim=1).int().tolist() if prompt_embeds_mask is not None else None + ) + neg_txt_seq_lens = ( + neg_prompt_embeds_mask.sum(dim=1).int().tolist() + if neg_prompt_embeds_mask is not None + else None + ) + + # ---- Build initial noisy latents at t = schedule[0] ------------------ + batch_size = (1 if isinstance(prompt, str) else len(prompt)) * num_images_per_prompt + num_channels_latents = pipe.transformer.config.in_channels // 4 # 64 // 4 = 16 + h_lat = 2 * (height // (pipe.vae_scale_factor * 2)) + w_lat = 2 * (width // (pipe.vae_scale_factor * 2)) + latent_shape = (batch_size, 1, num_channels_latents, h_lat, w_lat) + + noise = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype) + latents_5d = noise * schedule[0] # RF: sigma(t0) = t0 + x_packed = pipe._pack_latents(latents_5d, batch_size, num_channels_latents, h_lat, w_lat) + img_shapes = [[(1, h_lat // 2, w_lat // 2)]] * batch_size + + # ---- DMD few-step unroll (transformer forwards) ---------------------- + for t_cur, t_next in itertools.pairwise(schedule): + timestep = torch.tensor([t_cur], device=device, dtype=dtype).expand(batch_size) + flow_packed = pipe.transformer( + hidden_states=x_packed, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + timestep=timestep, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + guidance=None, + return_dict=False, + )[0] + if do_cfg: + neg_flow_packed = pipe.transformer( + hidden_states=x_packed, + encoder_hidden_states=neg_prompt_embeds, + encoder_hidden_states_mask=neg_prompt_embeds_mask, + timestep=timestep, + img_shapes=img_shapes, + txt_seq_lens=neg_txt_seq_lens, + guidance=None, + return_dict=False, + )[0] + flow_packed = ( + neg_flow_packed.to(torch.float64) + + float(guidance_scale) + * (flow_packed.to(torch.float64) - neg_flow_packed.to(torch.float64)) + ).to(dtype) + + # RF identity: x_0 = x_t - t_cur * v (fp64 for stability). + x0_packed = (x_packed.to(torch.float64) - float(t_cur) * flow_packed.to(torch.float64)).to( + dtype + ) + + if t_next > 1e-6: + if sample_type == "ode": + alpha_cur = 1.0 - float(t_cur) + eps_packed = ( + (x_packed.to(torch.float64) - alpha_cur * x0_packed.to(torch.float64)) + / max(float(t_cur), 1e-6) + ).to(dtype) + else: + eps_packed = torch.randn( + x_packed.shape, generator=generator, device=device, dtype=dtype + ) + alpha_next = 1.0 - float(t_next) + x_packed = ( + alpha_next * x0_packed.to(torch.float64) + + float(t_next) * eps_packed.to(torch.float64) + ).to(dtype) + else: + x_packed = x0_packed + + if not decode: + # Calibration path: transformer forwards already ran; nothing to decode. + return None + + # ---- VAE decode (sanity-inference path only) ------------------------- + x0_5d = pipe._unpack_latents(x_packed, height, width, pipe.vae_scale_factor) + latents_mean = ( + torch.tensor(pipe.vae.config.latents_mean) + .view(1, pipe.vae.config.z_dim, 1, 1, 1) + .to(device=device, dtype=dtype) + ) + latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view( + 1, pipe.vae.config.z_dim, 1, 1, 1 + ).to(device=device, dtype=dtype) + x0_scaled = x0_5d / latents_std + latents_mean + image_5d = pipe.vae.decode(x0_scaled, return_dict=False)[0] + image_4d = image_5d[:, :, 0] # Qwen-Image treats images as 1-frame videos + return pipe.image_processor.postprocess(image_4d, output_type=output_type) diff --git a/examples/diffusers/quantization/sanity_check_dmd2.py b/examples/diffusers/quantization/sanity_check_dmd2.py new file mode 100644 index 00000000000..9f78901f20c --- /dev/null +++ b/examples/diffusers/quantization/sanity_check_dmd2.py @@ -0,0 +1,175 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Restore a quantized DMD2 Qwen-Image student and run one few-step inference. + +Confirms the round trip of the new ``qwen-image-dmd2`` quantization flow: + + 1. Load the base Qwen-Image pipeline with the consolidated student swapped in + (via the same :class:`PipelineManager` path quantize.py uses) -- this brings + the original (unquantized) weights. + 2. Reapply the weight-free quantization checkpoint saved by ``quantize.py`` + (``save_quantizer_state`` -> ``transformer.pt``) via + ``restore_quantizer_state``, which re-applies the quantizer recipe **and the + calibrated amax** buffers on top of the loaded weights. + 3. Run a single few-step DMD inference (with VAE decode) and assert the image + is finite and non-constant. + +This deliberately reuses :class:`PipelineManager` and +:func:`qwen_image_dmd2_sampler.dmd2_sample` so the inference path is identical to +calibration's (minus the VAE decode, which is enabled here). + +Usage:: + + python sanity_check_dmd2.py \\ + --quantized-ckpt ./qwen_dmd2_fp8/transformer.pt \\ + --student-path /.../epoch_4_step_17999/model/consolidated \\ + --base-pipeline-path /.../models/Qwen-Image \\ + --output-png ./qwen_dmd2_fp8/sanity.png +""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import sys + +import torch +from models_utils import ModelType +from pipeline_manager import PipelineManager +from quantize_config import ModelConfig +from qwen_image_dmd2_sampler import dmd2_sample +from utils import restore_quantizer_state + +import modelopt.torch.quantization as mtq + +logger = logging.getLogger("sanity_check_dmd2") + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--quantized-ckpt", + required=True, + help="Path to the quantized checkpoint saved by quantize.py (e.g. .../transformer.pt).", + ) + parser.add_argument( + "--student-path", + required=True, + help="Consolidated DMD2 student dir (provides architecture + base weights to restore into).", + ) + parser.add_argument( + "--base-pipeline-path", + default="Qwen/Qwen-Image", + help="Base Qwen-Image dir/HF id for the VAE / text-encoder / tokenizer / scheduler.", + ) + parser.add_argument("--ema-path", default=None, help="Optional EMA shadow overlaid on load.") + parser.add_argument("--output-png", default="./qwen_dmd2_sanity.png") + parser.add_argument("--prompt", default="a small red cube on a white table") + parser.add_argument("--height", type=int, default=1024) + parser.add_argument("--width", type=int, default=1024) + parser.add_argument("--seed", type=int, default=42) + # Few-step sampler knobs (defaults match the canonical 4-step shift=3 student). + parser.add_argument("--sample-steps", type=int, default=4) + parser.add_argument( + "--t-list", + default=None, + help="Comma-separated schedule incl. trailing 0.0, e.g. '1.0,0.9,0.75,0.5,0.0'.", + ) + parser.add_argument("--sample-type", default="ode", choices=["ode", "sde"]) + parser.add_argument("--guidance-scale", type=float, default=1.0) + args = parser.parse_args() + + logging.basicConfig( + level=logging.INFO, format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s" + ) + + # 1. Build the base pipeline with the student swapped in (unquantized). + extra_params: dict[str, object] = { + "student_path": args.student_path, + "base_pipeline_path": args.base_pipeline_path, + "sample_steps": args.sample_steps, + "sample_type": args.sample_type, + "guidance_scale": args.guidance_scale, + "height": args.height, + "width": args.width, + } + if args.ema_path: + extra_params["ema_path"] = args.ema_path + if args.t_list: + extra_params["t_list"] = args.t_list + + model_config = ModelConfig( + model_type=ModelType.QWEN_IMAGE_DMD2, + model_dtype={"default": torch.bfloat16}, + backbone=["transformer"], + extra_params=extra_params, + ) + pm = PipelineManager(model_config, logger) + pipe = pm.create_pipeline() + + # 2. Restore the quantized architecture + calibrated amax into the student. + logger.info( + "Restoring quantizer state (amax + recipe) from %s onto the loaded student", + args.quantized_ckpt, + ) + restore_quantizer_state(pipe.transformer, args.quantized_ckpt) + mtq.print_quant_summary(pipe.transformer) + pm.setup_device() + + # 3. One few-step inference (with VAE decode). + gen = torch.Generator(device=pipe.transformer.device).manual_seed(args.seed) + images = dmd2_sample(pipe, [args.prompt], decode=True, generator=gen, **pm.dmd_sampler_cfg) + image = images[0] + + import numpy as np + + arr = np.asarray(image) + stats = { + "prompt": args.prompt, + "quantized_ckpt": args.quantized_ckpt, + "schedule": pm.dmd_sampler_cfg["schedule"], + "image_shape": list(arr.shape), + "image_dtype": str(arr.dtype), + "image_min": float(arr.min()), + "image_max": float(arr.max()), + "image_mean": float(arr.mean()), + "image_std": float(arr.std()), + "is_finite": bool(np.isfinite(arr).all()), + "is_not_constant": bool(arr.std() > 0), + } + + os.makedirs(os.path.dirname(os.path.abspath(args.output_png)), exist_ok=True) + image.save(args.output_png) + with open(args.output_png.replace(".png", "_stats.json"), "w") as f: + json.dump(stats, f, indent=2) + print(json.dumps(stats, indent=2)) + + if not stats["is_finite"]: + logger.error("Sanity check FAILED: image contains non-finite values.") + sys.exit(1) + if not stats["is_not_constant"]: + logger.error("Sanity check FAILED: image is constant (std == 0).") + sys.exit(1) + logger.info( + "Sanity check PASSED: restored quantized student produced a finite image -> %s", + args.output_png, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index c3cfdcd5cdd..9a6b841a52e 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ -24,8 +24,13 @@ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear from diffusers.utils import load_image +import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq from modelopt.torch.quantization.plugins.diffusion.diffusers import AttentionModuleMixin +from modelopt.torch.quantization.utils.core_utils import ( + get_quantizer_state_dict, + set_quantizer_state_dict, +) USE_PEFT = True try: @@ -193,3 +198,36 @@ def mha_filter_func(name): if hasattr(F, "scaled_dot_product_attention"): mtq.disable_quantizer(backbone, mha_filter_func) + + +def save_quantizer_state(model: torch.nn.Module, path: str) -> None: + """Save ONLY ModelOpt's quantization state -- the recipe plus the quantizer + buffers (amax, pre_quant_scale, ...) -- and NOT the model weights. + + This is the same idiom ModelOpt uses internally (see + ``modelopt.torch.quantization.plugins.transformers_trainer``): the + ``modelopt_state`` (architecture/recipe from :func:`mto.modelopt_state`) is + bundled with the per-quantizer state from + :func:`get_quantizer_state_dict` under the ``modelopt_state_weights`` key. + The resulting checkpoint is tiny (KBs-MBs) and is reloaded on top of the + original (unquantized) model via :func:`restore_quantizer_state`. + """ + modelopt_state = mto.modelopt_state(model) + modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(model) + torch.save(modelopt_state, str(path)) + + +def restore_quantizer_state(model: torch.nn.Module, path: str) -> torch.nn.Module: + """Reload a checkpoint written by :func:`save_quantizer_state` onto ``model``. + + ``model`` must already hold its original (unquantized) weights (e.g. freshly + loaded from the base HF/diffusers checkpoint); this re-applies the + quantization recipe and loads the calibrated amax/quantizer buffers on top. + Mirrors ModelOpt's ``_restore_modelopt_state_with_weights``. + """ + modelopt_state = mto.load_modelopt_state(str(path)) + quantizer_state = modelopt_state.pop("modelopt_state_weights", None) + mto.restore_from_modelopt_state(model, modelopt_state) + if quantizer_state is not None: + set_quantizer_state_dict(model, quantizer_state) + return model