Add Qwen-Image DMD2 PTQ support; save quantizer state (amax) without weights #1827
jingyu-ml wants to merge 1 commit
Two related changes to the diffusers quantization example.
1) Weight-free quantizer-state checkpoint (all models)
The torch checkpoint written by `--quantized-torch-ckpt-save-path` now stores
ONLY ModelOpt's quantization state -- the recipe plus the quantizer buffers
(amax, pre_quant_scale, ...) -- and NOT the model weights. The weights live in
the base HF/diffusers checkpoint and are reloaded there on restore. This uses
ModelOpt's own idiom (mirrors plugins/transformers_trainer.py):
    save:     modelopt_state = mto.modelopt_state(model)
              modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(model)
              torch.save(modelopt_state, path)
    restore:  modelopt_state = mto.load_modelopt_state(path)
              weights = modelopt_state.pop("modelopt_state_weights", None)
              mto.restore_from_modelopt_state(model, modelopt_state)
              set_quantizer_state_dict(model, weights)
Wrapped as utils.save_quantizer_state / restore_quantizer_state and wired into
ExportManager.save_checkpoint / restore_checkpoint. Restore auto-applies on top
of the freshly-loaded base weights (the pipeline is created before restore).
Effect: the artifact drops from a full-model checkpoint to KBs-MBs (a 60-layer
Qwen-Image student: 40.8 GB -> 2.0 MB) while amax round-trips bit-identically.
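A rough, library-free sketch of the idea behind the weight-free checkpoint: split a flat state dict into quantizer buffers and everything else, serialize only the buffers, and confirm the values survive a round trip bit for bit. The key suffixes (`_amax`, `_pre_quant_scale`), the helper names, and the toy values are illustrative assumptions, not ModelOpt's API; the change itself relies on `get_quantizer_state_dict` / `set_quantizer_state_dict`.

```python
import pickle
import struct

# Assumed naming convention for quantizer buffers (illustrative only).
QUANTIZER_SUFFIXES = ("_amax", "_pre_quant_scale")


def split_quantizer_state(state):
    """Return (quantizer_buffers, remaining_entries) from a flat state dict."""
    buffers = {k: v for k, v in state.items() if k.endswith(QUANTIZER_SUFFIXES)}
    rest = {k: v for k, v in state.items() if k not in buffers}
    return buffers, rest


def float_bits(x):
    """Raw IEEE-754 double bits of x, used for a bit-exact comparison."""
    return struct.pack("<d", x)


# Toy "model state": one large weight plus two scalar quantizer buffers.
state = {
    "blocks.30.attn.to_q.weight": [0.01] * 4096,
    "blocks.30.attn.to_q.input_quantizer._amax": 1.26e3,
    "blocks.30.attn.to_q.weight_quantizer._amax": 0.37,
}

buffers, weights = split_quantizer_state(state)
blob = pickle.dumps(buffers)      # the weight-free artifact
restored = pickle.loads(blob)     # what a later restore would read back

full_size = len(pickle.dumps(state))
small_size = len(blob)
bit_identical = all(float_bits(restored[k]) == float_bits(buffers[k]) for k in buffers)
```

The size gap in this toy comes from the same source as the reported one: the weights dominate the full state, while the quantizer buffers are a handful of scalars per layer.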
2) qwen-image-dmd2 model type (DMD2 few-step Qwen-Image students)
For students distilled by examples/diffusers/fastgen. Reuses the existing
Qwen-Image quantization stack from the base branch (filter_func_qwen_image, the
block-range recipe, QwenImagePipeline registration) and adds only what differs:
- pipeline_manager: load the consolidated student transformer (+ optional EMA)
and swap it into the base QwenImagePipeline; stash the few-step sampler config
(defaults to the canonical 4-step shift=3 [1.0, 0.9, 0.75, 0.5, 0.0] ODE
schedule, guidance_scale=1.0).
- calibration: drive the few-step DMD sampler instead of the standard denoising
loop, so collected amax matches how the student is actually run.
- qwen_image_dmd2_sampler.py (new): vendored compact DMD unroll, bit-aligned with
fastgen/inference_dmd2_qwen_image.py; transformer forwards only for calibration
(decode=False), optional VAE decode for sanity inference (decode=True).
- sanity_check_dmd2.py (new): restore the quantizer-state checkpoint and run one
few-step inference to validate the round trip.
- quantize.py: import the ONNX export tooling lazily so the calibration + save
path runs without onnx_graphsurgeon (e.g. the diffusers/fastgen container).
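To make the calibration point concrete, here is a minimal sketch of a few-step unroll over the default schedule. It assumes a rectified-flow style velocity model with plain Euler updates, which may not match the vendored sampler in detail; `unroll`, `toy_velocity`, and the scalar "latent" are hypothetical stand-ins for the real sampler, transformer, and latent tensors.

```python
DEFAULT_T_LIST = [1.0, 0.9, 0.75, 0.5, 0.0]  # schedule quoted in the commit message


def unroll(x, schedule, velocity_fn, decode=False, decode_fn=None):
    """Run one model forward per schedule interval; optionally decode the result."""
    forwards = 0
    for t_cur, t_next in zip(schedule, schedule[1:]):
        v = velocity_fn(x, t_cur)        # stands in for the transformer forward
        forwards += 1
        x = x + (t_next - t_cur) * v     # Euler step toward t = 0
    if decode and decode_fn is not None:
        x = decode_fn(x)                 # stands in for the optional VAE decode
    return x, forwards


# Toy velocity field for x_t = (1 - t) * x0 + t * noise, i.e. v = noise - x0.
X0, NOISE = 2.0, -1.0


def toy_velocity(x, t):
    return NOISE - X0


# Calibration-style call: forwards only, no decode.
latent, n_forwards = unroll(NOISE, DEFAULT_T_LIST, toy_velocity, decode=False)
```

With `decode=False` the loop performs exactly `len(schedule) - 1` forwards, so the activations the quantizers observe come from the same timesteps the student sees at inference.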
Validated end-to-end (FP8, single GB200): calibrate -> save (2.0 MB) ->
restore_quantizer_state reproduces amax bit-identically (e.g. block 30
attn.to_q input amax 1.26e+03) and the restored quantized student renders a
finite, non-constant 1024x1024 image.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
📝 Walkthrough
Adds a QWEN_IMAGE_DMD2 model path with dedicated pipeline creation, DMD2 sampling, quantizer-state checkpointing, calibration support, and a sanity-check CLI.
Changes: Qwen-Image DMD2 quantization flow
Sequence Diagram(s)
sequenceDiagram
participant PipelineManager
participant Calibrator
participant dmd2_sample
participant QwenImagePipeline
PipelineManager->>QwenImagePipeline: build QWEN_IMAGE_DMD2 pipeline
Calibrator->>dmd2_sample: run calibration with decode=False
dmd2_sample->>QwenImagePipeline: encode prompts and unroll transformer steps
dmd2_sample-->>Calibrator: return without VAE decode
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Caution: Pre-merge checks failed. Please resolve all errors before merging. Addressing warnings is optional.
❌ Failed checks (1 error, 1 warning)
✅ Passed checks (4 passed)
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files
@@ Coverage Diff @@
## feature/qwen-image-svdquant-nvfp4 #1827 +/- ##
==================================================================
Coverage 76.41% 76.41%
==================================================================
Files 511 511
Lines 56751 56751
==================================================================
Hits 43365 43365
Misses 13386 13386
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/diffusers/quantization/qwen_image_dmd2_sampler.py`:
- Around line 81-98: resolve_schedule() currently allows invalid schedules,
including sample_steps=0 and t_list values that are not strictly descending
toward 0.0. Update resolve_schedule() to validate the interface inputs up front:
reject non-positive sample_steps, require t_list to end at 0.0 and be strictly
monotonic from max_t to 0.0, and raise a clear ValueError before returning any
schedule. Keep the checks localized in resolve_schedule() so downstream
calibration and sampling code can trust the schedule.
In `@examples/diffusers/quantization/sanity_check_dmd2.py`:
- Line 139: The NumPy import in the sanity_check_dmd2 flow is happening too late
inside the script, which delays dependency failures until after setup and
inference. Move the NumPy import to module scope near the other imports in
sanity_check_dmd2 so the dependency is checked immediately and the script
follows the repo’s import ordering rule.
In `@examples/diffusers/quantization/utils.py`:
- Around line 228-232: The restore helper currently assumes the new
tiny-checkpoint schema via mto.load_modelopt_state() and then only applies
modelopt_state_weights, which breaks older checkpoints saved with mto.save().
Update the restore path in the helper that loads from path to detect the legacy
checkpoint shape containing modelopt_state and model_state_dict, and route those
cases through the previous full-restore flow before calling
restore_from_modelopt_state or set_quantizer_state_dict. Keep the new
quantizer-state handling for the modern format, but preserve backward
compatibility for existing artifacts used by quantize.py --restore-from and
sanity_check_dmd2.py.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 5c0994e8-a2fa-4138-9948-982da46a1f0b
📒 Files selected for processing (7)
- examples/diffusers/quantization/calibration.py
- examples/diffusers/quantization/models_utils.py
- examples/diffusers/quantization/pipeline_manager.py
- examples/diffusers/quantization/quantize.py
- examples/diffusers/quantization/qwen_image_dmd2_sampler.py
- examples/diffusers/quantization/sanity_check_dmd2.py
- examples/diffusers/quantization/utils.py
    if t_list is not None:
        schedule = [float(t) for t in t_list]
        if abs(schedule[-1]) > 1e-6:
            raise ValueError(
                f"t_list must end at 0.0 (got {schedule[-1]}); the final step lands on x_0."
            )
        if sample_steps is not None and len(schedule) != sample_steps + 1:
            raise ValueError(
                f"t_list must have sample_steps+1 entries "
                f"(got {len(schedule)} for sample_steps={sample_steps})."
            )
        return schedule

    if sample_steps == 1:
        return [float(max_t), 0.0]
    if sample_steps in (None, 4):
        return list(DEFAULT_T_LIST)
    return torch.linspace(float(max_t), 0.0, sample_steps + 1).tolist()
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Reject zero-step and non-monotonic schedules up front.
resolve_schedule() currently accepts sample_steps=0 and any t_list that merely ends in 0.0. That lets calibration run zero transformer forwards or march time in the wrong direction, which silently produces meaningless amax statistics and sanity outputs.
Suggested fix
 def resolve_schedule(
     t_list: list[float] | tuple[float, ...] | None,
     sample_steps: int | None,
     max_t: float = DEFAULT_MAX_T,
 ) -> list[float]:
@@
+    if sample_steps is not None and sample_steps < 1:
+        raise ValueError(f"sample_steps must be >= 1 (got {sample_steps}).")
+
     if t_list is not None:
         schedule = [float(t) for t in t_list]
+        if len(schedule) < 2:
+            raise ValueError("t_list must include at least one timestep plus a trailing 0.0.")
+        if any(t < 0.0 for t in schedule):
+            raise ValueError("t_list entries must be non-negative.")
+        if any(t_cur < t_next for t_cur, t_next in itertools.pairwise(schedule)):
+            raise ValueError("t_list must be monotonically non-increasing.")
         if abs(schedule[-1]) > 1e-6:
             raise ValueError(
                 f"t_list must end at 0.0 (got {schedule[-1]}); the final step lands on x_0."
             )
As per coding guidelines, "Validate external input once at the interface boundary; internal code can trust those checks and avoid redundant assertions."
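For readers who want to try the validation in isolation, the following is a dependency-free sketch of a `resolve_schedule` that mirrors the suggested checks. It is not the file's actual code: `DEFAULT_MAX_T = 1.0` is an assumed value, the evenly spaced fallback replaces `torch.linspace` with plain arithmetic, neighbours are compared with `zip` rather than `itertools.pairwise`, and `rejects` is a hypothetical helper for exercising the error paths.

```python
DEFAULT_T_LIST = (1.0, 0.9, 0.75, 0.5, 0.0)  # schedule quoted in the PR
DEFAULT_MAX_T = 1.0  # assumed value; the real constant is not shown in the review


def resolve_schedule(t_list=None, sample_steps=None, max_t=DEFAULT_MAX_T):
    """Return a validated list of timesteps that ends at 0.0."""
    if sample_steps is not None and sample_steps < 1:
        raise ValueError(f"sample_steps must be >= 1 (got {sample_steps}).")

    if t_list is not None:
        schedule = [float(t) for t in t_list]
        if len(schedule) < 2:
            raise ValueError("t_list must include at least one timestep plus a trailing 0.0.")
        if any(t < 0.0 for t in schedule):
            raise ValueError("t_list entries must be non-negative.")
        # Compare each timestep with its successor; equal neighbours are allowed.
        if any(cur < nxt for cur, nxt in zip(schedule, schedule[1:])):
            raise ValueError("t_list must be monotonically non-increasing.")
        if abs(schedule[-1]) > 1e-6:
            raise ValueError(f"t_list must end at 0.0 (got {schedule[-1]}).")
        if sample_steps is not None and len(schedule) != sample_steps + 1:
            raise ValueError(
                f"t_list must have sample_steps+1 entries "
                f"(got {len(schedule)} for sample_steps={sample_steps})."
            )
        return schedule

    if sample_steps == 1:
        return [float(max_t), 0.0]
    if sample_steps in (None, 4):
        return list(DEFAULT_T_LIST)
    # Evenly spaced schedule from max_t down to an exact trailing 0.0.
    step = float(max_t) / sample_steps
    return [float(max_t) - i * step for i in range(sample_steps)] + [0.0]


def rejects(**kwargs):
    """True if resolve_schedule raises ValueError for these arguments."""
    try:
        resolve_schedule(**kwargs)
    except ValueError:
        return True
    return False
```

Like the suggested fix, this accepts repeated timesteps; changing `cur < nxt` to `cur <= nxt` would enforce a strictly descending schedule instead.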
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    if sample_steps is not None and sample_steps < 1:
        raise ValueError(f"sample_steps must be >= 1 (got {sample_steps}).")
    if t_list is not None:
        schedule = [float(t) for t in t_list]
        if len(schedule) < 2:
            raise ValueError("t_list must include at least one timestep plus a trailing 0.0.")
        if any(t < 0.0 for t in schedule):
            raise ValueError("t_list entries must be non-negative.")
        if any(t_cur < t_next for t_cur, t_next in itertools.pairwise(schedule)):
            raise ValueError("t_list must be monotonically non-increasing.")
        if abs(schedule[-1]) > 1e-6:
            raise ValueError(
                f"t_list must end at 0.0 (got {schedule[-1]}); the final step lands on x_0."
            )
        if sample_steps is not None and len(schedule) != sample_steps + 1:
            raise ValueError(
                f"t_list must have sample_steps+1 entries "
                f"(got {len(schedule)} for sample_steps={sample_steps})."
            )
        return schedule
    if sample_steps == 1:
        return [float(max_t), 0.0]
    if sample_steps in (None, 4):
        return list(DEFAULT_T_LIST)
    return torch.linspace(float(max_t), 0.0, sample_steps + 1).tolist()
Source: Coding guidelines
    images = dmd2_sample(pipe, [args.prompt], decode=True, generator=gen, **pm.dmd_sampler_cfg)
    image = images[0]

    import numpy as np
🩺 Stability & Availability | 🟡 Minor
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Inspect the target file around the reported line and the top of the file.
file="examples/diffusers/quantization/sanity_check_dmd2.py"
wc -l "$file"
printf '\n--- top of file ---\n'
sed -n '1,80p' "$file"
printf '\n--- around line 139 ---\n'
sed -n '120,160p' "$file"
Repository: NVIDIA/Model-Optimizer
Length of output: 4859
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Check whether this script's NumPy import is module-level anywhere else in the file.
rg -n "import numpy as np|numpy as np|np\." examples/diffusers/quantization/sanity_check_dmd2.py
Repository: NVIDIA/Model-Optimizer
Length of output: 274
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# List nearby example scripts to compare import style if needed.
git ls-files examples/diffusers/quantization | sed -n '1,120p'
Repository: NVIDIA/Model-Optimizer
Length of output: 1214
Move the NumPy import to module scope.
examples/diffusers/quantization/sanity_check_dmd2.py:139 imports NumPy only after pipeline setup and inference, so a missing dependency fails late. A top-level import keeps the dependency check immediate and matches the repo import rule.
Source: Coding guidelines
    modelopt_state = mto.load_modelopt_state(str(path))
    quantizer_state = modelopt_state.pop("modelopt_state_weights", None)
    mto.restore_from_modelopt_state(model, modelopt_state)
    if quantizer_state is not None:
        set_quantizer_state_dict(model, quantizer_state)
🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick win
Preserve the legacy full-checkpoint fallback here.
This helper now hard-requires the new quantizer-state schema via mto.load_modelopt_state(). Older checkpoints saved through mto.save() carry modelopt_state plus model_state_dict, so both quantize.py --restore-from ... and sanity_check_dmd2.py will fail instead of taking the backward-compatible restore path promised for existing artifacts.
Please detect the legacy checkpoint shape and fall back to the previous full-restore flow before assuming the new tiny-checkpoint format.
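One possible shape for that detection, sketched without any ModelOpt or torch dependency: classify an already-loaded checkpoint dict by its top-level keys and dispatch to the matching restore routine. The legacy key names come from the comment above; `checkpoint_format`, `restore`, and the two callbacks are hypothetical names, and how the file is loaded or how each branch actually restores the model is left out.

```python
# Top-level keys that, per the review comment, identify a legacy full checkpoint.
LEGACY_KEYS = {"modelopt_state", "model_state_dict"}


def checkpoint_format(ckpt):
    """Classify a loaded checkpoint dict as 'legacy_full' or 'quantizer_state'."""
    if not isinstance(ckpt, dict):
        raise TypeError(f"expected a dict checkpoint, got {type(ckpt).__name__}")
    if LEGACY_KEYS.issubset(ckpt):
        return "legacy_full"
    return "quantizer_state"


def restore(ckpt, restore_full, restore_quantizer_only):
    """Dispatch to the matching restore callback based on the checkpoint shape."""
    if checkpoint_format(ckpt) == "legacy_full":
        return restore_full(ckpt)
    return restore_quantizer_only(ckpt)
```

Keeping the classification in one small function means the decision is made once, before any state is applied to the model, so neither branch has to undo a partial restore.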
What does this PR do?
Type of change: new example feature (+ a general improvement to the diffusers quantization example)
Two related changes to `examples/diffusers/quantization`:
1) Weight-free quantizer-state checkpoint (all models). `save_checkpoint` now stores only ModelOpt's quantization state (the recipe plus the quantizer buffers: amax, pre_quant_scale, …) and not the model weights, which live in the base HF/diffusers checkpoint and are reloaded there on restore. This uses ModelOpt's own idiom (mirrors `modelopt/torch/quantization/plugins/transformers_trainer.py`):
- save: `mto.modelopt_state(model)` + `get_quantizer_state_dict(model)` (bundled under the `modelopt_state_weights` key)
- restore: `mto.load_modelopt_state` → `mto.restore_from_modelopt_state` → `set_quantizer_state_dict`, applied on top of the freshly-loaded base weights.
The artifact drops from a full-model checkpoint to KBs–MBs (a 60-layer Qwen-Image student: 40.8 GB → 2.0 MB); amax and the enabled/disabled quantizer pattern round-trip bit-identically. Restore auto-falls-back to `mto.restore` for older full checkpoints.
2) `qwen-image-dmd2` model type for PTQ of DMD2-distilled few-step Qwen-Image students (trained by `examples/diffusers/fastgen`). Reuses the existing Qwen-Image quant stack (`filter_func_qwen_image`, the block-range recipe, `QwenImagePipeline` registration) and adds only what differs:
- pipeline manager: load the consolidated student transformer (+ optional EMA) and swap it into the base `QwenImagePipeline`;
- calibration: drive the few-step DMD sampler (default 4-step `[1.0, 0.9, 0.75, 0.5, 0.0]` ODE schedule) instead of the standard denoising loop, so collected amax matches inference;
- `qwen_image_dmd2_sampler.py`: a compact DMD unroll vendored from `fastgen/inference_dmd2_qwen_image.py`;
- `sanity_check_dmd2.py`: restore the checkpoint and run one few-step inference;
- `quantize.py`: import the ONNX export tooling lazily so calibration + save run without `onnx_graphsurgeon`.
Usage
Reload the saved quantizer state onto a model that holds its original weights:
Testing
Validated end-to-end on a single GB200 against a 60-layer Qwen-Image DMD2 student (FP8):
- … `Gustavosta/Stable-Diffusion-Prompts` dataset, 4-step ODE schedule `[1.0, 0.9, 0.75, 0.5, 0.0]`.
- `save` → 2.0 MB checkpoint → `restore` reproduces amax bit-identically (e.g. block 30 `attn.to_q` input amax `1.26e+03`) and the disabled-quantizer pattern (1528 disabled / 1792 enabled).
Before your PR is "Ready for review"
- … `mto.restore` for old full checkpoints.
- … `CONTRIBUTING.md`: ✅ (the DMD sampler is vendored from `examples/diffusers/fastgen/inference_dmd2_qwen_image.py`, same repo, Apache-2.0; no new dependencies).
- … `/claude review`.
Additional Information
Stacked on `feature/qwen-image-svdquant-nvfp4` (PR #1706, which adds the base Qwen-Image support this builds on); base can be retargeted to `main` once #1706 merges.
🤖 Generated with Claude Code
Summary by CodeRabbit
New Features
Bug Fixes