Add Qwen-Image DMD2 PTQ support; save quantizer state (amax) without weights #1827

Open
jingyu-ml wants to merge 1 commit into feature/qwen-image-svdquant-nvfp4 from feature/qwen-image-dmd2

@jingyu-ml jingyu-ml commented Jun 25, 2026

What does this PR do?

Type of change: new example feature (+ a general improvement to the diffusers quantization example)

Two related changes to examples/diffusers/quantization:

1) Weight-free quantizer-state checkpoint (all models).
save_checkpoint now stores only ModelOpt's quantization state — the recipe plus the quantizer buffers (amax, pre_quant_scale, …) — and not the model weights, which live in the base HF/diffusers checkpoint and are reloaded there on restore. This uses ModelOpt's own idiom (mirrors modelopt/torch/quantization/plugins/transformers_trainer.py):

  • save: mto.modelopt_state(model) + get_quantizer_state_dict(model) (bundled under the modelopt_state_weights key)
  • restore: mto.load_modelopt_state → mto.restore_from_modelopt_state → set_quantizer_state_dict, applied on top of the freshly-loaded base weights.

The artifact drops from a full-model checkpoint to KBs–MBs (a 60-layer Qwen-Image student: 40.8 GB → 2.0 MB); amax and the enabled/disabled quantizer pattern round-trip bit-identically. Restore auto-falls-back to mto.restore for older full checkpoints.

2) qwen-image-dmd2 model type for PTQ of DMD2-distilled few-step Qwen-Image students (trained by examples/diffusers/fastgen). Reuses the existing Qwen-Image quant stack (filter_func_qwen_image, the block-range recipe, QwenImagePipeline registration) and adds only what differs:

  • load the consolidated student transformer (+ optional EMA) and swap it into the base QwenImagePipeline;
  • drive the few-step DMD sampler for calibration (the 4-step shift=3 [1.0, 0.9, 0.75, 0.5, 0.0] ODE schedule) instead of the standard denoising loop, so collected amax matches inference;
  • qwen_image_dmd2_sampler.py: a compact DMD unroll vendored from fastgen/inference_dmd2_qwen_image.py;
  • sanity_check_dmd2.py: restore the checkpoint and run one few-step inference;
  • ONNX export tooling is imported lazily so the calibration + save path runs without onnx_graphsurgeon.
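The lazy-import pattern in the last bullet can be sketched as follows. This is a minimal illustration, not the code in quantize.py: the helper names `load_export_tooling` and `run` are hypothetical, and only the module name `onnx_graphsurgeon` comes from the description above.

```python
import importlib


def load_export_tooling(module_name: str = "onnx_graphsurgeon"):
    """Import the export dependency only when export is actually requested.

    Hypothetical helper for illustration; the real quantize.py may differ.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise RuntimeError(
            f"ONNX export requested but '{module_name}' is not installed."
        ) from exc


def run(export_onnx: bool, module_name: str = "onnx_graphsurgeon") -> list[str]:
    """Return the stages that ran; calibrate + save never touch the export import."""
    stages = ["calibrate", "save_checkpoint"]
    if export_onnx:
        load_export_tooling(module_name)  # the import happens here, not at module load
        stages.append("export_onnx")
    return stages
```

With this shape, a container that lacks the export dependency can still run calibration and save the checkpoint; the missing package only surfaces when export is requested.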

Usage

python quantize.py --model qwen-image-dmd2 --format fp8 --model-dtype BFloat16 \
  --calib-size 64 \
  --extra-param student_path=/.../epoch_4_step_15999/model/consolidated \
  --extra-param base_pipeline_path=/.../Qwen-Image \
  --extra-param sample_steps=4 --extra-param t_list=1.0,0.9,0.75,0.5,0.0 \
  --quantized-torch-ckpt-save-path /.../quant/

Reload the saved quantizer state onto a model that holds its original weights:

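import torch
from diffusers import QwenImageTransformer2DModel

import modelopt.torch.opt as mto
from modelopt.torch.quantization.utils.core_utils import set_quantizer_state_dict

# student_dir: path to the consolidated student transformer (original weights)
model = QwenImageTransformer2DModel.from_pretrained(student_dir, torch_dtype=torch.bfloat16)
state = mto.load_modelopt_state("transformer.pt")
weights = state.pop("modelopt_state_weights", None)  # quantizer buffers (amax, ...)
mto.restore_from_modelopt_state(model, state)        # rebuild the quantizers from the recipe
if weights is not None:
    set_quantizer_state_dict(model, weights)         # quantized model with calibrated amax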

Testing

Validated end-to-end on a single GB200 against a 60-layer Qwen-Image DMD2 student (FP8):

  • Full 64-sample calibration on the default Gustavosta/Stable-Diffusion-Prompts dataset, 4-step ODE schedule [1.0, 0.9, 0.75, 0.5, 0.0].
  • save → 2.0 MB checkpoint → restore reproduces amax bit-identically (e.g. block 30 attn.to_q input amax 1.26e+03) and the disabled-quantizer pattern (1528 disabled / 1792 enabled).
  • The restored quantized student renders a coherent 1024×1024 image.
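As a worked check of the schedule used above, the sketch below assumes that "shift=3" refers to the usual flow-matching time shift t' = s·t / (1 + (s − 1)·t) applied to a uniform grid; the PR does not spell out the formula, so treat this as an assumption. The function names are illustrative only.

```python
def shift_time(t: float, shift: float) -> float:
    """Assumed flow-matching time shift: t' = s*t / (1 + (s - 1)*t)."""
    return shift * t / (1.0 + (shift - 1.0) * t)


def shifted_schedule(sample_steps: int, shift: float) -> list[float]:
    """Uniform grid from 1.0 down to 0.0 with sample_steps intervals, then shifted."""
    uniform = [1.0 - i / sample_steps for i in range(sample_steps + 1)]
    return [shift_time(t, shift) for t in uniform]


schedule = shifted_schedule(sample_steps=4, shift=3.0)
print([round(t, 6) for t in schedule])  # [1.0, 0.9, 0.75, 0.5, 0.0]

# A schedule with sample_steps + 1 entries yields sample_steps (t_cur, t_next) pairs,
# i.e. one transformer forward per pair.
steps = list(zip(schedule[:-1], schedule[1:]))
print(len(steps))  # 4
```

Under that assumption the uniform grid [1.0, 0.75, 0.5, 0.25, 0.0] maps onto the listed values [1.0, 0.9, 0.75, 0.5, 0.0].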

Before your PR is "Ready for review"

  • Is this change backward compatible?: ✅ — new model type is additive; the example's torch-checkpoint format changes to weight-free, and restore falls back to mto.restore for old full checkpoints.
  • If you copied code or added a PIP dependency, did you follow CONTRIBUTING.md: ✅ — the DMD sampler is vendored from examples/diffusers/fastgen/inference_dmd2_qwen_image.py (same repo, Apache-2.0); no new dependencies.
  • Did you write any new necessary tests?: ❌ — an automated test needs the trained ~40 GB student checkpoint, which isn't available in CI; validated manually end-to-end on GPU (above). Happy to add a tiny-model unit test if desired.
  • Did you update Changelog?: N/A — example-only change.
  • Did you get Claude approval on this PR?: ❌ — pending /claude review.

Additional Information

Stacked on feature/qwen-image-svdquant-nvfp4 (PR #1706, which adds the base Qwen-Image support this builds on); base can be retargeted to main once #1706 merges.

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features

    • Added support for a new Qwen-Image DMD2 quantization workflow.
    • Introduced a few-step sampling path and a sanity-check script to verify restored checkpoints and generated output.
    • Added save/restore support for quantization state, enabling checkpoints to keep calibration details without full model weights.
  • Bug Fixes

    • Improved checkpoint handling so quantized models can be restored onto base weights more reliably.
    • Added validation for required sampling and configuration settings, with clearer failure behavior.

…weights

Two related changes to the diffusers quantization example.

1) Weight-free quantizer-state checkpoint (all models)
The torch checkpoint written by `--quantized-torch-ckpt-save-path` now stores
ONLY ModelOpt's quantization state -- the recipe plus the quantizer buffers
(amax, pre_quant_scale, ...) -- and NOT the model weights. The weights live in
the base HF/diffusers checkpoint and are reloaded there on restore. This uses
ModelOpt's own idiom (mirrors plugins/transformers_trainer.py):

    save:    modelopt_state = mto.modelopt_state(model)
             modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(model)
             torch.save(modelopt_state, path)
    restore: modelopt_state = mto.load_modelopt_state(path)
             weights = modelopt_state.pop("modelopt_state_weights", None)
             mto.restore_from_modelopt_state(model, modelopt_state)
             set_quantizer_state_dict(model, weights)

Wrapped as utils.save_quantizer_state / restore_quantizer_state and wired into
ExportManager.save_checkpoint / restore_checkpoint. Restore auto-applies on top
of the freshly-loaded base weights (the pipeline is created before restore).
Effect: the artifact drops from a full-model checkpoint to KBs-MBs (a 60-layer
Qwen-Image student: 40.8 GB -> 2.0 MB) while amax round-trips bit-identically.

2) qwen-image-dmd2 model type (DMD2 few-step Qwen-Image students)
For students distilled by examples/diffusers/fastgen. Reuses the existing
Qwen-Image quantization stack from the base branch (filter_func_qwen_image, the
block-range recipe, QwenImagePipeline registration) and adds only what differs:

- pipeline_manager: load the consolidated student transformer (+ optional EMA)
  and swap it into the base QwenImagePipeline; stash the few-step sampler config
  (defaults to the canonical 4-step shift=3 [1.0, 0.9, 0.75, 0.5, 0.0] ODE
  schedule, guidance_scale=1.0).
- calibration: drive the few-step DMD sampler instead of the standard denoising
  loop, so collected amax matches how the student is actually run.
- qwen_image_dmd2_sampler.py (new): vendored compact DMD unroll, bit-aligned with
  fastgen/inference_dmd2_qwen_image.py; transformer forwards only for calibration
  (decode=False), optional VAE decode for sanity inference (decode=True).
- sanity_check_dmd2.py (new): restore the quantizer-state checkpoint and run one
  few-step inference to validate the round trip.
- quantize.py: import the ONNX export tooling lazily so the calibration + save
  path runs without onnx_graphsurgeon (e.g. the diffusers/fastgen container).

Validated end-to-end (FP8, single GB200): calibrate -> save (2.0 MB) ->
restore_quantizer_state reproduces amax bit-identically (e.g. block 30
attn.to_q input amax 1.26e+03) and the restored quantized student renders a
finite, non-constant 1024x1024 image.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested a review from a team as a code owner June 25, 2026 19:43
@jingyu-ml jingyu-ml requested review from chadvoegele and removed request for a team June 25, 2026 19:43
@copy-pr-bot

copy-pr-bot Bot commented Jun 25, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented Jun 25, 2026

Walkthrough

Adds a QWEN_IMAGE_DMD2 model path with dedicated pipeline creation, DMD2 sampling, quantizer-state checkpointing, calibration support, and a sanity-check CLI.

Changes

Qwen-Image DMD2 quantization flow

| Layer / File(s) | Summary |
|---|---|
| Model routing and pipeline creation: examples/diffusers/quantization/models_utils.py, examples/diffusers/quantization/pipeline_manager.py | Extends ModelType and pipeline selection for QWEN_IMAGE_DMD2, builds the DMD2 Qwen-Image pipeline, and stores sampler config for later use. |
| Quantizer state checkpointing: examples/diffusers/quantization/utils.py, examples/diffusers/quantization/quantize.py | Adds quantizer-state save/restore helpers, switches checkpoint persistence to save and reapply quantization state, and makes ONNX export imports lazy. |
| Sampler schedule and conditioning: examples/diffusers/quantization/qwen_image_dmd2_sampler.py | Defines DMD2 schedule resolution, prompt conditioning, and latent setup for sampling. |
| Sampler unroll and decode: examples/diffusers/quantization/qwen_image_dmd2_sampler.py | Runs the few-step transformer loop, applies guidance, and returns either calibration output or decoded images. |
| Calibration and sanity check entrypoints: examples/diffusers/quantization/calibration.py, examples/diffusers/quantization/sanity_check_dmd2.py | Routes calibration through dmd2_sample and adds a CLI sanity check that restores quantization state and writes image/statistics outputs. |

Sequence Diagram(s)

sequenceDiagram
  participant PipelineManager
  participant Calibrator
  participant dmd2_sample
  participant QwenImagePipeline
  PipelineManager->>QwenImagePipeline: build QWEN_IMAGE_DMD2 pipeline
  Calibrator->>dmd2_sample: run calibration with decode=False
  dmd2_sample->>QwenImagePipeline: encode prompts and unroll transformer steps
  dmd2_sample-->>Calibrator: return without VAE decode

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • jenchen13
  • ChenhanYu
  • hychiang-git

Caution

Pre-merge checks failed

Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 error, 1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Security Anti-Patterns | ❌ Error | pipeline_manager.py loads ema_path with torch.load(...) and no inline safety comment; SECURITY.md forbids untrusted torch.load without justification. | Use torch.load(..., weights_only=True) for the EMA tensor dict, or add a documented inline comment proving the file is internally generated and trusted. |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 68.42% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately summarizes the two main changes: Qwen-Image DMD2 PTQ support and saving quantizer state without model weights. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

@codecov

codecov Bot commented Jun 25, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 76.41%. Comparing base (daf6d0f) to head (a689f31).

Additional details and impacted files
@@                        Coverage Diff                         @@
##           feature/qwen-image-svdquant-nvfp4    #1827   +/-   ##
==================================================================
  Coverage                              76.41%   76.41%           
==================================================================
  Files                                    511      511           
  Lines                                  56751    56751           
==================================================================
  Hits                                   43365    43365           
  Misses                                 13386    13386           
Flag Coverage Δ
unit 54.66% <ø> (ø)

Flags with carried forward coverage won't be shown. Click here to find out more.

@coderabbitai coderabbitai Bot left a comment

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/diffusers/quantization/qwen_image_dmd2_sampler.py`:
- Around line 81-98: resolve_schedule() currently allows invalid schedules,
including sample_steps=0 and t_list values that are not strictly descending
toward 0.0. Update resolve_schedule() to validate the interface inputs up front:
reject non-positive sample_steps, require t_list to end at 0.0 and be strictly
monotonic from max_t to 0.0, and raise a clear ValueError before returning any
schedule. Keep the checks localized in resolve_schedule() so downstream
calibration and sampling code can trust the schedule.

In `@examples/diffusers/quantization/sanity_check_dmd2.py`:
- Line 139: The NumPy import in the sanity_check_dmd2 flow is happening too late
inside the script, which delays dependency failures until after setup and
inference. Move the NumPy import to module scope near the other imports in
sanity_check_dmd2 so the dependency is checked immediately and the script
follows the repo’s import ordering rule.

In `@examples/diffusers/quantization/utils.py`:
- Around line 228-232: The restore helper currently assumes the new
tiny-checkpoint schema via mto.load_modelopt_state() and then only applies
modelopt_state_weights, which breaks older checkpoints saved with mto.save().
Update the restore path in the helper that loads from path to detect the legacy
checkpoint shape containing modelopt_state and model_state_dict, and route those
cases through the previous full-restore flow before calling
restore_from_modelopt_state or set_quantizer_state_dict. Keep the new
quantizer-state handling for the modern format, but preserve backward
compatibility for existing artifacts used by quantize.py --restore-from and
sanity_check_dmd2.py.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 5c0994e8-a2fa-4138-9948-982da46a1f0b

📥 Commits

Reviewing files that changed from the base of the PR and between daf6d0f and a689f31.

📒 Files selected for processing (7)
  • examples/diffusers/quantization/calibration.py
  • examples/diffusers/quantization/models_utils.py
  • examples/diffusers/quantization/pipeline_manager.py
  • examples/diffusers/quantization/quantize.py
  • examples/diffusers/quantization/qwen_image_dmd2_sampler.py
  • examples/diffusers/quantization/sanity_check_dmd2.py
  • examples/diffusers/quantization/utils.py

Comment on lines +81 to +98
    if t_list is not None:
        schedule = [float(t) for t in t_list]
        if abs(schedule[-1]) > 1e-6:
            raise ValueError(
                f"t_list must end at 0.0 (got {schedule[-1]}); the final step lands on x_0."
            )
        if sample_steps is not None and len(schedule) != sample_steps + 1:
            raise ValueError(
                f"t_list must have sample_steps+1 entries "
                f"(got {len(schedule)} for sample_steps={sample_steps})."
            )
        return schedule

    if sample_steps == 1:
        return [float(max_t), 0.0]
    if sample_steps in (None, 4):
        return list(DEFAULT_T_LIST)
    return torch.linspace(float(max_t), 0.0, sample_steps + 1).tolist()

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Reject zero-step and non-monotonic schedules up front.

resolve_schedule() currently accepts sample_steps=0 and any t_list that merely ends in 0.0. That lets calibration run zero transformer forwards or march time in the wrong direction, which silently produces meaningless amax statistics and sanity outputs.

Suggested fix
 def resolve_schedule(
     t_list: list[float] | tuple[float, ...] | None,
     sample_steps: int | None,
     max_t: float = DEFAULT_MAX_T,
 ) -> list[float]:
@@
+    if sample_steps is not None and sample_steps < 1:
+        raise ValueError(f"sample_steps must be >= 1 (got {sample_steps}).")
+
     if t_list is not None:
         schedule = [float(t) for t in t_list]
+        if len(schedule) < 2:
+            raise ValueError("t_list must include at least one timestep plus a trailing 0.0.")
+        if any(t < 0.0 for t in schedule):
+            raise ValueError("t_list entries must be non-negative.")
+        if any(t_cur < t_next for t_cur, t_next in itertools.pairwise(schedule)):
+            raise ValueError("t_list must be monotonically non-increasing.")
         if abs(schedule[-1]) > 1e-6:
             raise ValueError(
                 f"t_list must end at 0.0 (got {schedule[-1]}); the final step lands on x_0."
             )

As per coding guidelines, "Validate external input once at the interface boundary; internal code can trust those checks and avoid redundant assertions."
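To make the effect of these checks concrete, here is a standalone sketch of the validated function together with a few calls. It is an illustration under stated assumptions, not the code in qwen_image_dmd2_sampler.py: `DEFAULT_MAX_T = 1.0` is assumed, `torch.linspace` is replaced by plain arithmetic so the snippet runs without torch, and `zip` stands in for `itertools.pairwise`.

```python
DEFAULT_T_LIST = (1.0, 0.9, 0.75, 0.5, 0.0)  # schedule quoted in the PR description
DEFAULT_MAX_T = 1.0  # assumption: the real constant is not shown in this thread


def resolve_schedule(t_list=None, sample_steps=None, max_t=DEFAULT_MAX_T):
    """Return a descending timestep list that ends at 0.0, or raise ValueError."""
    if sample_steps is not None and sample_steps < 1:
        raise ValueError(f"sample_steps must be >= 1 (got {sample_steps}).")

    if t_list is not None:
        schedule = [float(t) for t in t_list]
        if len(schedule) < 2:
            raise ValueError("t_list must include at least one timestep plus a trailing 0.0.")
        if any(t < 0.0 for t in schedule):
            raise ValueError("t_list entries must be non-negative.")
        if any(cur < nxt for cur, nxt in zip(schedule, schedule[1:])):
            raise ValueError("t_list must be monotonically non-increasing.")
        if abs(schedule[-1]) > 1e-6:
            raise ValueError(f"t_list must end at 0.0 (got {schedule[-1]}).")
        if sample_steps is not None and len(schedule) != sample_steps + 1:
            raise ValueError(
                f"t_list must have sample_steps+1 entries "
                f"(got {len(schedule)} for sample_steps={sample_steps})."
            )
        return schedule

    if sample_steps == 1:
        return [float(max_t), 0.0]
    if sample_steps in (None, 4):
        return list(DEFAULT_T_LIST)
    # Pure-Python stand-in for torch.linspace(max_t, 0.0, sample_steps + 1).
    return [float(max_t) * (1.0 - i / sample_steps) for i in range(sample_steps + 1)]


def rejects(**kwargs) -> bool:
    """Small helper for the examples: True if resolve_schedule raises ValueError."""
    try:
        resolve_schedule(**kwargs)
    except ValueError:
        return True
    return False


print(resolve_schedule())                   # default 4-step schedule
print(resolve_schedule(sample_steps=2))     # [1.0, 0.5, 0.0]
print(rejects(sample_steps=0))              # True: zero transformer forwards
print(rejects(t_list=[0.5, 0.9, 0.0]))      # True: time moves in the wrong direction
```

Note that the check as suggested enforces a non-increasing schedule, so repeated timesteps would still pass; a strictly descending requirement would need `<=` in the pairwise comparison.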

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    if t_list is not None:
-        schedule = [float(t) for t in t_list]
-        if abs(schedule[-1]) > 1e-6:
-            raise ValueError(
-                f"t_list must end at 0.0 (got {schedule[-1]}); the final step lands on x_0."
-            )
-        if sample_steps is not None and len(schedule) != sample_steps + 1:
-            raise ValueError(
-                f"t_list must have sample_steps+1 entries "
-                f"(got {len(schedule)} for sample_steps={sample_steps})."
-            )
-        return schedule
-    if sample_steps == 1:
-        return [float(max_t), 0.0]
-    if sample_steps in (None, 4):
-        return list(DEFAULT_T_LIST)
-    return torch.linspace(float(max_t), 0.0, sample_steps + 1).tolist()
+    if sample_steps is not None and sample_steps < 1:
+        raise ValueError(f"sample_steps must be >= 1 (got {sample_steps}).")
+    if t_list is not None:
+        schedule = [float(t) for t in t_list]
+        if len(schedule) < 2:
+            raise ValueError("t_list must include at least one timestep plus a trailing 0.0.")
+        if any(t < 0.0 for t in schedule):
+            raise ValueError("t_list entries must be non-negative.")
+        if any(t_cur < t_next for t_cur, t_next in itertools.pairwise(schedule)):
+            raise ValueError("t_list must be monotonically non-increasing.")
+        if abs(schedule[-1]) > 1e-6:
+            raise ValueError(
+                f"t_list must end at 0.0 (got {schedule[-1]}); the final step lands on x_0."
+            )
+        if sample_steps is not None and len(schedule) != sample_steps + 1:
+            raise ValueError(
+                f"t_list must have sample_steps+1 entries "
+                f"(got {len(schedule)} for sample_steps={sample_steps})."
+            )
+        return schedule
+    if sample_steps == 1:
+        return [float(max_t), 0.0]
+    if sample_steps in (None, 4):
+        return list(DEFAULT_T_LIST)
+    return torch.linspace(float(max_t), 0.0, sample_steps + 1).tolist()

Source: Coding guidelines

images = dmd2_sample(pipe, [args.prompt], decode=True, generator=gen, **pm.dmd_sampler_cfg)
image = images[0]

import numpy as np

🩺 Stability & Availability | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Inspect the target file around the reported line and the top of the file.
file="examples/diffusers/quantization/sanity_check_dmd2.py"
wc -l "$file"
printf '\n--- top of file ---\n'
sed -n '1,80p' "$file"
printf '\n--- around line 139 ---\n'
sed -n '120,160p' "$file"

Repository: NVIDIA/Model-Optimizer

Length of output: 4859


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Check whether this script's NumPy import is module-level anywhere else in the file.
rg -n "import numpy as np|numpy as np|np\." examples/diffusers/quantization/sanity_check_dmd2.py

Repository: NVIDIA/Model-Optimizer

Length of output: 274


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# List nearby example scripts to compare import style if needed.
git ls-files examples/diffusers/quantization | sed -n '1,120p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1214


Move the NumPy import to module scope. examples/diffusers/quantization/sanity_check_dmd2.py:139 imports NumPy only after pipeline setup and inference, so a missing dependency fails late. Top-level import keeps the dependency check immediate and matches the repo import rule.

Source: Coding guidelines

Comment on lines +228 to +232
    modelopt_state = mto.load_modelopt_state(str(path))
    quantizer_state = modelopt_state.pop("modelopt_state_weights", None)
    mto.restore_from_modelopt_state(model, modelopt_state)
    if quantizer_state is not None:
        set_quantizer_state_dict(model, quantizer_state)

🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick win

Preserve the legacy full-checkpoint fallback here.

This helper now hard-requires the new quantizer-state schema via mto.load_modelopt_state(). Older checkpoints saved through mto.save() carry modelopt_state plus model_state_dict, so both quantize.py --restore-from ... and sanity_check_dmd2.py will fail instead of taking the backward-compatible restore path promised for existing artifacts.

Please detect the legacy checkpoint shape and fall back to the previous full-restore flow before assuming the new tiny-checkpoint format.
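One way to picture the requested detection step is the sketch below. It only inspects the top-level keys named in this comment (`modelopt_state` plus `model_state_dict` for the legacy full checkpoint) and operates on plain dicts; the function name `classify_checkpoint` and the returned labels are hypothetical, and the actual on-disk layout should be verified against ModelOpt before relying on it.

```python
def classify_checkpoint(obj) -> str:
    """Classify a loaded checkpoint object by its top-level keys.

    Hypothetical helper: the key names come from the review comment above,
    not from a verified ModelOpt schema.
    """
    if not isinstance(obj, dict):
        return "unknown"
    if "modelopt_state" in obj and "model_state_dict" in obj:
        # Legacy full checkpoint: route through the previous full-restore flow.
        return "legacy_full"
    # Otherwise assume the weight-free format; "modelopt_state_weights" is
    # optional because the restore path pops it with a default of None.
    return "quantizer_state"


legacy = {"modelopt_state": {}, "model_state_dict": {}}
modern = {"modelopt_state_weights": {"some.quantizer._amax": 1.0}}

print(classify_checkpoint(legacy))  # legacy_full
print(classify_checkpoint(modern))  # quantizer_state
```

A restore helper could branch on this label before calling either the legacy full-restore flow or the new quantizer-state path, so older artifacts keep loading.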
