Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ Changelog
**Bug Fixes**

- Fix ``ShapeInferenceError`` during ONNX INT8 + FP16 quantization (``--high_precision_dtype fp16``) of weakly-typed models (e.g. TensorFlow exports) that carry stale rank-0 ``graph.output`` shapes or ops such as ``TopK`` that ONNX's static shape inference cannot resolve. ``clear_stale_value_info`` now reconciles stale output shapes via symbolic shape inference (keeping every output's shape field populated), and AutoCast runs ONNX shape inference in strict mode and falls back to schema-based standalone type inference when it fails, so unresolved ops no longer leave tensors untyped.
- Fused MoE expert auto-detection (``register_fused_experts_on_the_fly``) no longer requires an ``act_fn`` attribute. Some fused-expert modules (e.g. ``MiniMaxM3VLExperts``) apply a custom gated activation between the two ``F.linear`` calls instead of exposing ``act_fn``; they were silently skipped, leaving routed experts unquantized (an experts-only recipe matched nothing) and failing HF export with ``NotImplementedError``. ``_QuantFusedExperts`` is activation-agnostic (it only intercepts the two ``F.linear`` calls), so the requirement was unnecessary. This enables NVFP4/FP8 quantization and export for MiniMax-M2 / MiniMax-M3.

0.45 (2026-07-02)
^^^^^^^^^^^^^^^^^
Expand Down
257 changes: 257 additions & 0 deletions modelopt/torch/export/quant_aware_conversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Quantization-aware reverse weight conversion for unified HF export.

Background
----------
``transformers`` may apply a ``conversion_mapping`` when loading a model, so the
in-memory parameter names differ from the original model-hub checkpoint (e.g. fused
``mlp.gate_up_proj``, renamed MoE leaves, reordered ``model``/``language_model``
prefix). On save, ``transformers`` reverses this via ``revert_weight_conversion`` so
the on-disk names match the hub checkpoint again.

ModelOpt's unified export disables that reverse (it raises ``IndexError`` on 0-d
scalar scale tensors such as ``weight_scale_2``/``input_scale``), so a quantized
export emits the *in-memory* (post-conversion) names — violating the unified
checkpoint contract that names stay aligned with the original hub checkpoint.

This module performs the reverse in a quantization-aware way: it carries each
weight's companion scale tensors (``weight_scale``, ``weight_scale_2``,
``input_scale``, ``weight_scale_inv``, ``bias``) through the rename and un-fuse
operations.

Scope
-----
Two reverse primitives cover the common conversion_mapping cases:

* **Rename** — a key-level string substitution. Because a quantized linear stores
every tensor under ``<module>.<leaf>``, renaming the module substring rewrites the
weight and all its scale siblings together with no tensor manipulation.
* **Split** — un-fuse an output-dim concatenation (e.g. ``gate_up_proj`` ->
``gate_proj`` + ``up_proj``). ``weight``/``weight_scale``/``weight_scale_inv``/
``bias`` are chunked along the fused (output) dim; 0-d scalar ``weight_scale_2``/
``input_scale`` are duplicated to each part (they are per-tensor and shared).

The 3-D stacked-expert case (``MergeModulelist``, where per-expert weights are
stacked into ``experts.gate_up_proj`` with leading expert dim) is intentionally
*not* handled here: the stacked-scalar-scale layout cannot be validated against a
published checkpoint yet. Encountering it raises :class:`QuantConversionUnsupportedError`
so the caller can fall back to the legacy (in-memory-name) behavior rather than
emit a silently-wrong checkpoint. See the module TODO.
"""

import re
from dataclasses import dataclass

import torch

__all__ = [
"QuantConversionUnsupportedError",
"RenameRule",
"SplitRule",
"apply_reverse_rules",
"revert_weight_conversion_quant_aware",
]

# Tensor leaves that belong to a single quantized linear module. A rename of the
# parent module path applies uniformly to all of these.
_LEAF_SUFFIXES = (
".weight",
".weight_scale",
".weight_scale_2",
".weight_scale_inv",
".input_scale",
".bias",
)

# Leaves that are per-tensor scalars (0-d) and must be *duplicated*, not split, when
# a fused module is un-fused.
_SCALAR_LEAF_SUFFIXES = (".weight_scale_2", ".input_scale")


class QuantConversionUnsupportedError(Exception):
"""Raised when a conversion op cannot be reversed quant-aware (caller falls back)."""


@dataclass(frozen=True)
class RenameRule:
"""Reverse of a ``WeightRenaming``: ``re.sub(pattern, repl, key)`` on every key."""

pattern: str
repl: str


@dataclass(frozen=True)
class SplitRule:
"""Reverse of an output-dim ``Concatenate``: un-fuse one module into ``parts``.

Args:
fused_suffix: module suffix of the fused tensor, e.g. ``".gate_up_proj"``.
part_suffixes: ordered replacements, e.g. ``(".gate_proj", ".up_proj")``.
dim: the fused (output) dim along which ``weight``/``weight_scale``/``bias``
are chunked. NVFP4 ``weight`` is ``[out, in//2]`` and ``weight_scale`` is
``[out, in//block]`` so the output dim is ``0`` for both.
"""

fused_suffix: str
part_suffixes: tuple[str, ...]
dim: int = 0


def _split_leaf_tensor(leaf: str, tensor: torch.Tensor, n: int, idx: int, dim: int):
"""Return the ``idx``-th of ``n`` parts of ``tensor`` for tensor leaf ``leaf``."""
if leaf in _SCALAR_LEAF_SUFFIXES or tensor.dim() == 0:
# Per-tensor scalar shared across the fused parts -> duplicate.
return tensor.clone()
size = tensor.size(dim)
if size % n != 0:
raise QuantConversionUnsupportedError(
f"cannot split leaf '{leaf}' of size {size} along dim {dim} into {n} parts"
)
return tensor.chunk(n, dim=dim)[idx].clone()


def _apply_split_rule(state_dict: dict[str, torch.Tensor], rule: SplitRule) -> None:
"""Un-fuse all modules matching ``rule.fused_suffix`` in place."""
n = len(rule.part_suffixes)
# Collect (module_path, leaf, key) for every tensor under a fused module.
fused_keys: list[tuple[str, str, str]] = []
for key in state_dict:
for leaf in _LEAF_SUFFIXES:
if key.endswith(rule.fused_suffix + leaf):
module = key[: -len(leaf)][: -len(rule.fused_suffix)]
fused_keys.append((module, leaf, key))
break

for module, leaf, key in fused_keys:
tensor = state_dict.pop(key)
# A 3-D expert tensor here means stacked experts (MergeModulelist) — out of scope.
if leaf == ".weight" and tensor.dim() >= 3:
raise QuantConversionUnsupportedError(
f"stacked 3-D expert tensor '{key}' (ndim={tensor.dim()}) is not supported; "
"un-stacking experts + their scales is a follow-up"
)
for idx, part in enumerate(rule.part_suffixes):
state_dict[module + part + leaf] = _split_leaf_tensor(leaf, tensor, n, idx, rule.dim)


def apply_reverse_rules(
state_dict: dict[str, torch.Tensor],
split_rules: list[SplitRule],
rename_rules: list[RenameRule],
) -> dict[str, torch.Tensor]:
"""Apply quant-aware reverse conversion: splits first, then renames.

Splits run on the in-memory (post-conversion) names; renames then map the
resulting keys back to the original hub names. Renames are applied in order.
"""
out = dict(state_dict)
for rule in split_rules:
_apply_split_rule(out, rule)

compiled = [(re.compile(r.pattern), r.repl) for r in rename_rules]
renamed: dict[str, torch.Tensor] = {}
for key, value in out.items():
new_key = key
for pattern, repl in compiled:
new_key = pattern.sub(repl, new_key)
if new_key in renamed:
raise QuantConversionUnsupportedError(f"rename collision on '{new_key}'")
renamed[new_key] = value
return renamed


def revert_weight_conversion_quant_aware(model, state_dict: dict[str, torch.Tensor]):
"""Reverse a transformers conversion_mapping on a quantized state dict.

Builds reverse rules from the model's conversion mapping and applies them
carrying companion scale tensors. Raises :class:`QuantConversionUnsupportedError`
when the mapping uses an op that cannot be reversed quant-aware yet, so the
caller can fall back to the legacy behavior.
"""
split_rules, rename_rules = _build_reverse_rules(model)
if not split_rules and not rename_rules:
return state_dict
return apply_reverse_rules(state_dict, split_rules, rename_rules)


def _build_reverse_rules(model) -> tuple[list[SplitRule], list[RenameRule]]:
"""Best-effort: derive reverse rules from the model's transformers conversion mapping.

Returns empty rule lists when no mapping applies (then the export is unchanged).
Raises :class:`QuantConversionUnsupportedError` for ops not yet handled quant-aware
(e.g. stacked-expert ``MergeModulelist``), so the caller falls back safely.
"""
try:
conversions = getattr(model, "_weight_conversions", None)
if conversions is None:
from transformers.conversion_mapping import get_model_conversion_mapping

conversions = get_model_conversion_mapping(model, add_legacy=False)
except Exception as exc: # transformers without conversion_mapping, or API drift
raise QuantConversionUnsupportedError(f"could not read conversion mapping: {exc}") from exc

if not conversions:
return [], []

from transformers.core_model_loading import (
Concatenate,
MergeModulelist,
WeightConverter,
WeightRenaming,
)

split_rules: list[SplitRule] = []
rename_rules: list[RenameRule] = []
for conv in conversions:
if isinstance(conv, WeightRenaming):
# source -> target on load; reverse maps target -> source on save.
rename_rules.append(
RenameRule(pattern=re.escape(conv.target_patterns), repl=conv.source_patterns)
)
elif isinstance(conv, WeightConverter):
ops = list(conv.operations)
if any(isinstance(op, MergeModulelist) for op in ops):
raise QuantConversionUnsupportedError(
"stacked-expert MergeModulelist conversion is not yet reversible quant-aware"
)
if len(ops) == 1 and isinstance(ops[0], Concatenate):
split_rules.append(_concat_to_split_rule(conv, ops[0]))
else:
raise QuantConversionUnsupportedError(
f"unsupported converter operations: {[type(o).__name__ for o in ops]}"
)
else:
raise QuantConversionUnsupportedError(f"unsupported conversion entry: {type(conv).__name__}")
return split_rules, rename_rules


def _concat_to_split_rule(conv, concat) -> SplitRule:
"""Translate a fusing ``Concatenate`` converter into a :class:`SplitRule`."""
fused = _suffix(conv.target_patterns)
parts = tuple(_suffix(p) for p in conv.source_patterns)
return SplitRule(fused_suffix=fused, part_suffixes=parts, dim=concat.dim)


def _suffix(pattern: str) -> str:
"""Module suffix from a conversion pattern, e.g. ``.experts.*.w1.weight`` -> ``.w1``."""
p = pattern
for leaf in _LEAF_SUFFIXES:
if p.endswith(leaf):
p = p[: -len(leaf)]
break
leaf = p.rsplit(".", 1)[-1]
return "." + leaf
31 changes: 26 additions & 5 deletions modelopt/torch/export/unified_export_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@
from .model_utils import _reorder_canonical_first, get_language_model_from_vl, is_multimodal_model
from .moe_utils import _export_fused_experts
from .plugins import SpeculativeDecodingExporter, has_spec_opt, sanitize_hf_config_for_deployment
from .quant_aware_conversion import (
QuantConversionUnsupportedError,
revert_weight_conversion_quant_aware,
)
from .quant_utils import (
fuse_prequant_layernorm,
fuse_prequant_to_linear,
Expand Down Expand Up @@ -1476,19 +1480,36 @@ def export_hf_checkpoint(
if getattr(model, "hf_quantizer", None) is not None:
model.hf_quantizer = None

export_state_dict = {**post_state_dict, **(extra_state_dict or {})}

# transformers may have applied a load-time conversion_mapping (fused gate_up_proj,
# renamed MoE leaves, reordered model/language_model prefix), so the in-memory names
# differ from the original hub checkpoint. Reverse it quantization-aware so exported
# tensor names stay aligned with the hub checkpoint (the unified-checkpoint contract).
# transformers' own revert_weight_conversion errors on 0-d scalar scale tensors, so we
# do the reverse here; for any op we cannot reverse yet (e.g. stacked-expert fusion)
# we fall back to the in-memory names.
try:
export_state_dict = revert_weight_conversion_quant_aware(model, export_state_dict)
except QuantConversionUnsupportedError as exc:
warnings.warn(
f"Quant-aware reverse weight conversion skipped ({exc}); exported tensor "
"names may not match the original HF hub checkpoint."
)

# Save model
# Temporarily disable revert_weight_conversion if available — it doesn't handle
# quantized state dicts (scalar scale tensors have 0 dimensions, causing IndexError).
# We must patch both the source module and the importing module since
# modeling_utils does `from core_model_loading import revert_weight_conversion`.
# Keep transformers' own revert_weight_conversion disabled (the quant-aware reverse
# above replaces it): it doesn't handle quantized state dicts (0-d scalar scale
# tensors cause IndexError). Patch both the source module and the importing module
# since modeling_utils does `from core_model_loading import revert_weight_conversion`.
_patches = _patch_revert_weight_conversion()

_sanitize_generation_config_for_save(model)

try:
model.save_pretrained(
export_dir,
state_dict={**post_state_dict, **(extra_state_dict or {})},
state_dict=export_state_dict,
save_modelopt_state=save_modelopt_state,
max_shard_size=max_shard_size,
)
Expand Down
12 changes: 8 additions & 4 deletions modelopt/torch/quantization/plugins/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,20 +1606,24 @@ def register_sparse_moe_on_the_fly(model):
def _fused_experts_wrapper_class(module):
"""Return the _QuantFusedExperts subclass for a fused MoE expert container, or None.

Two 3-D fused layouts are recognized, both requiring ``num_experts`` + ``act_fn``
and a 3-D ``down_proj`` parameter:
Two 3-D fused layouts are recognized, both requiring ``num_experts`` and a
3-D ``down_proj`` parameter:

* gated (``_QuantFusedExperts``): a 3-D ``gate_up_proj`` fusing gate+up. Matches
``MixtralExperts``, ``Qwen2MoeExperts``, ``Qwen3MoeExperts``,
``Qwen3_5MoeExperts``, ``DeepseekV3NaiveMoe``, ``JambaExperts``,
``OlmoeExperts``, etc.
``OlmoeExperts``, ``MiniMaxM2Experts``, ``MiniMaxM3VLExperts``, etc.
* non-gated (``_QuantNonGatedFusedExperts``): a 3-D ``up_proj`` with no
``gate_proj`` and no ``gate_up_proj``. Matches NemotronH ``NemotronHExperts``.

Returns ``None`` for non-standard layouts (DBRX, GptOss, GraniteMoE,
Llama4TextExperts) which have their own explicit registrations.

``act_fn`` is not required: these wrappers only intercept the two ``F.linear``
calls, so modules with a custom gated activation (e.g. ``MiniMaxM3VLExperts``)
are still supported.
"""
if not hasattr(module, "num_experts") or not hasattr(module, "act_fn"):
if not hasattr(module, "num_experts"):
return None
down = getattr(module, "down_proj", None)
if not isinstance(down, (nn.Parameter, Tensor)) or down.dim() != 3:
Expand Down
Loading
Loading