Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions examples/diffusers/quantization/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from models_utils import MODEL_DEFAULTS, ModelType
from pipeline_manager import PipelineManager
from quantize_config import CalibrationConfig
from qwen_image_dmd2_sampler import dmd2_sample
from tqdm import tqdm
from utils import load_calib_prompts

Expand Down Expand Up @@ -95,6 +96,9 @@ def run_calibration(self, batched_prompts: list[list[str]]) -> None:
elif self.model_type in [ModelType.WAN22_T2V_14b, ModelType.WAN22_T2V_5b]:
# Special handling for WAN video models
self._run_wan_video_calibration(prompt_batch, extra_args)
elif self.model_type == ModelType.QWEN_IMAGE_DMD2:
# DMD2 students use a custom few-step sampler, not the standard loop.
self._run_qwen_image_dmd2_calibration(prompt_batch)
else:
common_args = {
"prompt": prompt_batch,
Expand All @@ -105,6 +109,22 @@ def run_calibration(self, batched_prompts: list[list[str]]) -> None:
self.logger.debug(f"Completed calibration batch {i + 1}/{self.config.num_batches}")
self.logger.info("Calibration completed successfully")

def _run_qwen_image_dmd2_calibration(self, prompt_batch: list[str]) -> None:
"""Calibrate a DMD2 Qwen-Image student via its few-step sampler.

Drives the same few-step DMD unroll the student was trained/served with
(NOT the standard denoising loop) so the collected activation statistics
are representative of inference. The VAE decode is skipped — calibration
only needs the transformer forwards.
"""
cfg = self.pipeline_manager.dmd_sampler_cfg
if cfg is None:
raise RuntimeError(
"DMD2 sampler config is not set; the qwen-image-dmd2 pipeline must be created "
"via PipelineManager.create_pipeline() before calibration."
)
dmd2_sample(self.pipe, prompt_batch, decode=False, **cfg)

def _run_wan_video_calibration(
self, prompt_batch: list[str], extra_args: dict[str, Any]
) -> None:
Expand Down
19 changes: 19 additions & 0 deletions examples/diffusers/quantization/models_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ class ModelType(str, Enum):
WAN22_T2V_14b = "wan2.2-t2v-14b"
WAN22_T2V_5b = "wan2.2-t2v-5b"
QWEN_IMAGE = "qwen-image"
# DMD2-distilled few-step Qwen-Image student (from examples/diffusers/fastgen).
# Same architecture as QWEN_IMAGE, but loaded from a consolidated student dir
# and calibrated with the few-step DMD sampler instead of the standard loop.
QWEN_IMAGE_DMD2 = "qwen-image-dmd2"


_FILTER_FUNC_MAP: dict[ModelType, Callable[[str], bool]] = {
Expand All @@ -74,6 +78,7 @@ class ModelType(str, Enum):
ModelType.WAN22_T2V_14b: filter_func_wan_video,
ModelType.WAN22_T2V_5b: filter_func_wan_video,
ModelType.QWEN_IMAGE: filter_func_qwen_image,
ModelType.QWEN_IMAGE_DMD2: filter_func_qwen_image,
}

_VAE_FILTER_FUNC_MAP: dict[tuple[ModelType, str], Callable[[str], bool]] = {
Expand Down Expand Up @@ -107,6 +112,11 @@ def get_model_filter_func(
ModelType.WAN22_T2V_14b: "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
ModelType.WAN22_T2V_5b: "Wan-AI/Wan2.2-TI2V-5B-Diffusers",
ModelType.QWEN_IMAGE: "Qwen/Qwen-Image",
# Base pipeline (VAE / text-encoder / tokenizer / scheduler) for DMD2 students;
# the trained transformer is loaded separately from a consolidated dir via the
# ``student_path`` extra-param. Override with ``--override-model-path`` or
# ``--extra-param base_pipeline_path=...``.
ModelType.QWEN_IMAGE_DMD2: "Qwen/Qwen-Image",
}

MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline] | None] = {
Expand All @@ -122,6 +132,7 @@ def get_model_filter_func(
ModelType.WAN22_T2V_14b: WanPipeline,
ModelType.WAN22_T2V_5b: WanPipeline,
ModelType.QWEN_IMAGE: QwenImagePipeline,
ModelType.QWEN_IMAGE_DMD2: QwenImagePipeline,
}

# Shared dataset configurations
Expand Down Expand Up @@ -258,6 +269,14 @@ def get_model_filter_func(
},
}

# DMD2 students share Qwen-Image's architecture, so they reuse the same block-range
# recipe, high-precision filter, base pipeline, and calibration dataset. They differ
# only in (a) loading -- a consolidated student dir swapped into the base pipeline
# (PipelineManager._create_qwen_image_dmd2_pipeline) -- and (b) calibration, which
# drives the few-step DMD sampler instead of the standard denoising loop
# (Calibrator._run_qwen_image_dmd2_calibration). Inherit so the recipe stays in sync.
#
# NOTE: ``{**d}`` is a *shallow* copy. The DMD2 entry gets its own top-level dict
# (adding or rebinding a key here does not touch QWEN_IMAGE), but any nested
# dict/list values are the same objects as in the QWEN_IMAGE entry, so mutating
# one of them in place through either entry is visible through both.
MODEL_DEFAULTS[ModelType.QWEN_IMAGE_DMD2] = {**MODEL_DEFAULTS[ModelType.QWEN_IMAGE]}


def _coerce_extra_param_value(value: str) -> Any:
lowered = value.lower()
Expand Down
115 changes: 115 additions & 0 deletions examples/diffusers/quantization/pipeline_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def __init__(self, config: ModelConfig, logger: logging.Logger):
self.pipe_upsample: LTXLatentUpsamplePipeline | None = None # For LTX-Video upsampling
self._transformer: torch.nn.Module | None = None
self._video_decoder: torch.nn.Module | None = None
# Few-step sampler config for DMD2 students (populated when loading a
# qwen-image-dmd2 pipeline); consumed by the calibrator / sanity check.
self.dmd_sampler_cfg: dict[str, Any] | None = None

@staticmethod
def create_pipeline_from(
Expand Down Expand Up @@ -100,6 +103,11 @@ def create_pipeline(self) -> Any:
self.logger.info("LTX-2 pipeline created successfully")
return self.pipe

if self.config.model_type == ModelType.QWEN_IMAGE_DMD2:
self.pipe = self._create_qwen_image_dmd2_pipeline()
self.logger.info("Qwen-Image DMD2 pipeline created successfully")
return self.pipe

pipeline_cls = MODEL_PIPELINE[self.config.model_type]
if pipeline_cls is None:
raise ValueError(
Expand Down Expand Up @@ -266,6 +274,113 @@ def _create_ltx2_pipeline(self) -> Any:
pipeline_kwargs.update(params)
return TI2VidTwoStagesPipeline(**pipeline_kwargs)

def _create_qwen_image_dmd2_pipeline(self) -> Any:
    """Build a QwenImagePipeline whose transformer is a DMD2-trained student.

    Loads the consolidated student transformer (the ``model/consolidated`` dir
    produced by ``examples/diffusers/fastgen`` training), optionally overlays an
    EMA shadow, and swaps it into the base Qwen-Image pipeline so the VAE /
    text-encoder / tokenizer / scheduler come from the base checkpoint.

    The few-step sampler settings are parsed *before* any weights are loaded so
    that a malformed extra-param fails immediately instead of after the
    expensive checkpoint loads. ``self.dmd_sampler_cfg`` is only assigned once
    the pipeline has been built successfully.

    Reads from ``extra_params``:
        student_path (required): consolidated student dir.
        base_pipeline_path: base Qwen-Image dir/HF id (defaults to
            ``self.config.model_path``).
        ema_path: optional ``ema_shadow.pt`` to overlay onto the student. Either
            a flat ``{name: tensor}`` mapping or a dict holding that mapping
            under the ``"shadow"`` key.
        sample_steps / t_list / max_t: few-step sampler schedule, resolved via
            ``resolve_schedule`` (``sample_steps`` defaults to 4; ``t_list`` may
            be a comma-separated string). The schedule MUST match training.
        sample_type (default ``"ode"``) / guidance_scale (default ``1.0``) /
            negative_prompt (default ``None``).
        height / width: default to the model's ``inference_extra_args`` entry in
            ``MODEL_DEFAULTS``, else 1024.
        max_sequence_length: defaults to 512.

    Returns:
        The assembled ``QwenImagePipeline`` with the student transformer.

    Raises:
        ImportError: If diffusers lacks the Qwen-Image classes.
        ValueError: If ``student_path`` is missing, the EMA payload is not a
            dict, or the EMA payload matches none of the student's parameters.
    """
    from qwen_image_dmd2_sampler import DEFAULT_MAX_T, resolve_schedule

    try:
        from diffusers import QwenImagePipeline, QwenImageTransformer2DModel
    except ImportError as e:
        raise ImportError(
            "qwen-image-dmd2 requires a diffusers version providing QwenImagePipeline "
            "and QwenImageTransformer2DModel; upgrade diffusers."
        ) from e

    params = dict(self.config.extra_params)
    student_path = params.get("student_path")
    if not student_path:
        raise ValueError(
            "Missing required extra_param: student_path (the consolidated DMD2 student "
            "dir, e.g. .../epoch_4_step_17999/model/consolidated)."
        )
    base_pipeline_path = params.get("base_pipeline_path") or self.config.model_path
    ema_path = params.get("ema_path")

    # Resolve the few-step sampler config up front: it is cheap, and doing it
    # before the checkpoint loads means a bad value fails fast. Defaults match
    # the canonical 4-step shift=3 student; the schedule MUST match training.
    sample_steps = params.get("sample_steps")
    sample_steps = int(sample_steps) if sample_steps is not None else 4
    t_list = params.get("t_list")
    if isinstance(t_list, str):
        t_list = [float(x) for x in t_list.split(",") if x.strip()]
    max_t = float(params.get("max_t", DEFAULT_MAX_T))
    schedule = resolve_schedule(t_list, sample_steps, max_t)
    defaults = MODEL_DEFAULTS[self.config.model_type].get("inference_extra_args", {})
    sampler_cfg: dict[str, Any] = {
        "schedule": schedule,
        "sample_type": str(params.get("sample_type", "ode")),
        "guidance_scale": float(params.get("guidance_scale", 1.0)),
        "negative_prompt": params.get("negative_prompt"),
        "height": int(params.get("height", defaults.get("height", 1024))),
        "width": int(params.get("width", defaults.get("width", 1024))),
        "max_sequence_length": int(params.get("max_sequence_length", 512)),
    }

    default_dtype = self.config.model_dtype["default"]
    transformer_dtype = self.config.model_dtype.get("transformer", default_dtype)
    if torch.float16 in (default_dtype, transformer_dtype):
        self.logger.warning(
            "Qwen-Image is trained/served in bfloat16; float16 (Half) can overflow the "
            "VAE and produce NaNs. Consider --model-dtype BFloat16."
        )

    self.logger.info("Loading DMD2 student transformer from %s", student_path)
    transformer = QwenImageTransformer2DModel.from_pretrained(
        student_path, torch_dtype=transformer_dtype
    )

    if ema_path:
        self.logger.info("Overlaying EMA shadow from %s", ema_path)
        # ema_path is user-supplied. The payload is expected to be a plain
        # tensor mapping, so restrict unpickling to tensors/primitives rather
        # than allowing arbitrary pickled objects to execute code on load.
        ema_state = torch.load(str(ema_path), map_location="cpu", weights_only=True)
        shadow = (
            ema_state.get("shadow", ema_state) if isinstance(ema_state, dict) else ema_state
        )
        if not isinstance(shadow, dict):
            raise ValueError(
                f"ema_path content has unexpected type {type(shadow).__name__}; "
                "expected dict[str, Tensor]."
            )
        missing, unexpected = transformer.load_state_dict(shadow, strict=False)
        # Every unexpected key is a shadow key the student does not have. If
        # that covers the whole shadow, nothing was overlaid and the student
        # would silently keep its non-EMA weights.
        if len(unexpected) >= len(shadow):
            raise ValueError(
                f"EMA overlay from {ema_path} matched none of the student's parameters "
                "(check for a key-prefix mismatch); refusing to continue with "
                "non-EMA weights."
            )
        if unexpected:
            self.logger.warning("EMA overlay had %d unexpected key(s)", len(unexpected))
        if missing:
            self.logger.warning("EMA overlay missed %d student key(s)", len(missing))

    transformer.eval()

    self.logger.info(
        "Loading base Qwen-Image pipeline from %s (transformer replaced by student)",
        base_pipeline_path,
    )
    pipe = QwenImagePipeline.from_pretrained(
        base_pipeline_path, transformer=transformer, torch_dtype=default_dtype
    )
    pipe.set_progress_bar_config(disable=True)

    # Publish the sampler config only now that the pipeline exists, so a failed
    # load never leaves a config behind without a matching pipeline.
    self.dmd_sampler_cfg = sampler_cfg
    self.logger.info(
        "DMD2 few-step sampler: steps=%d schedule=%s sample_type=%s guidance_scale=%s "
        "(schedule must match the student's training t_list)",
        len(schedule) - 1,
        schedule,
        sampler_cfg["sample_type"],
        sampler_cfg["guidance_scale"],
    )
    return pipe

def print_quant_summary(self):
for name, backbone in self.iter_backbones():
self.logger.info(f"{name} quantization info:")
Expand Down
21 changes: 15 additions & 6 deletions examples/diffusers/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
get_model_filter_func,
parse_extra_params,
)
from onnx_utils.export import generate_fp8_scales, modelopt_export_sd
from pipeline_manager import PipelineManager
from quantize_config import (
CalibrationConfig,
Expand All @@ -51,9 +50,8 @@
QuantFormat,
QuantizationConfig,
)
from utils import check_conv_and_mha, check_lora
from utils import check_conv_and_mha, check_lora, restore_quantizer_state, save_quantizer_state

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint

Expand Down Expand Up @@ -295,8 +293,11 @@ def save_checkpoint(
filename = f"{backbone_name}.pt" if backbone_name else "backbone.pt"
target_path = ckpt_path / filename

self.logger.info(f"Saving backbone to {target_path}")
mto.save(backbone, str(target_path))
# Save ONLY the quantization state (recipe + quantizer buffers incl. amax),
# not the model weights. The weights live in the base HF/diffusers checkpoint
# and are reloaded there on restore; this keeps the artifact tiny.
self.logger.info(f"Saving quantizer state (amax + recipe, no weights) to {target_path}")
save_quantizer_state(backbone, str(target_path))

self.logger.info("Checkpoint saved successfully")

Expand All @@ -319,6 +320,12 @@ def export_onnx(
if not self.config.onnx_dir:
return

# Imported lazily: the ONNX export tooling (onnx_graphsurgeon, etc.) is only
# needed when --onnx-dir is set, so the calibration + amax-save path still
# works in environments without the ONNX deps (e.g. the diffusers/fastgen
# container used for Qwen-Image DMD2 students).
from onnx_utils.export import generate_fp8_scales, modelopt_export_sd

self.logger.info(f"Starting ONNX export to {self.config.onnx_dir}")

if quant_format == QuantFormat.FP8 and self._has_conv_layers(backbone):
Expand Down Expand Up @@ -362,7 +369,9 @@ def restore_checkpoint(self) -> None:
f"Checkpoint not found for '{backbone_name}' in {restore_path}"
)
self.logger.info(f"Restoring {backbone_name} from {source_path}")
mto.restore(backbone, str(source_path))
# The pipeline was just created with the base (unquantized) weights, so
# this re-applies the quantization recipe + amax on top of them.
restore_quantizer_state(backbone, str(source_path))

self.logger.info("Checkpoints restored successfully")

Expand Down
Loading
Loading