From e08b9133924b1e3768ff8352de8af499ac2476de Mon Sep 17 00:00:00 2001
From: weimingc <17592131+meenchen@users.noreply.github.com>
Date: Thu, 25 Jun 2026 10:54:39 -0700
Subject: [PATCH] Fix weight-only prequant layernorm export

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
---
NOTE(review): the line structure of this patch was restored from a copy in
which every line break and all indentation had been flattened. No tokens
were changed. Indentation of the added Python lines follows from their
syntax; indentation of the context lines (most notably in the
_fuse_shared_input_modules hunk) is a best guess -- TODO confirm against the
target file, or apply with `git apply --ignore-whitespace`.
NOTE(review): the last hunk header declares 6 context lines but only 5 are
visible here, so the patch presumably continues past this point; if it does
not, `git apply --recount` will re-derive the hunk sizes.

 modelopt/torch/export/unified_export_hf.py    |  6 +++++
 .../torch/export/test_unified_export_hf.py    | 22 ++++++++++++++++++++++-
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index 8bc92ed5eb9..72b35ce84f3 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -121,6 +121,11 @@ def _is_enabled_quantizer(quantizer):
     return False
 
 
+def _has_pre_quant_scale(module: nn.Module) -> bool:
+    input_quantizer = getattr(module, "input_quantizer", None)
+    return input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale")
+
+
 def _save_component_state_dict_safetensors(
     component: nn.Module,
     component_export_dir: Path,
@@ -407,6 +412,7 @@ def _fuse_shared_input_modules(
             and group_quant_format is not None
             and group_quant_format != QUANTIZATION_NONE
             and "awq" in group_quant_format
+            and all(_has_pre_quant_scale(module) for module in modules)
             and tensor in output_to_layernorm
         ):
             with fsdp2_aware_weight_update(model, output_to_layernorm[tensor]):
diff --git a/tests/unit/torch/export/test_unified_export_hf.py b/tests/unit/torch/export/test_unified_export_hf.py
index 3032353f914..a4e1859aa60 100644
--- a/tests/unit/torch/export/test_unified_export_hf.py
+++ b/tests/unit/torch/export/test_unified_export_hf.py
@@ -18,6 +18,7 @@
 from collections import OrderedDict
 
 import torch
+from _test_utils.torch.export.utils import SmallQKVModel
 from _test_utils.torch.quantization.tied_modules import (
     make_tied_linear_pair,
     wrap_in_parent_with_tied_keys,
@@ -29,7 +30,10 @@
     _reorder_canonical_first,
 )
 from modelopt.torch.export.quant_utils import sync_tied_input_amax
-from modelopt.torch.export.unified_export_hf import _export_quantized_weight
+from modelopt.torch.export.unified_export_hf import (
+    _export_quantized_weight,
+    requantize_resmooth_fused_llm_layers,
+)
 
 
 def test_collect_canonical_tied_patterns_dict_style():
@@ -115,6 +119,22 @@ def test_sync_tied_input_amax_no_op_for_untied_modules():
     assert torch.allclose(dec_q.amax, torch.tensor(5.0))
 
 
+def test_requantize_resmooth_skips_layernorm_without_pre_quant_scale():
+    """INT4 blockwise weight-only export has no pre-quant scale to fold into layernorm."""
+    model = SmallQKVModel(dim=4, device="cpu", apply_embed=True)
+    mtq.quantize(model, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG)
+    modules = [model.q_proj, model.k_proj, model.v_proj]
+    layernorm_weight = model.input_layernorm.weight.detach().clone()
+
+    for module in modules:
+        assert not hasattr(module.input_quantizer, "_pre_quant_scale")
+
+    requantize_resmooth_fused_llm_layers(model)
+
+    assert torch.equal(model.input_layernorm.weight, layernorm_weight)
+    assert all(not getattr(module, "fused_with_prequant", False) for module in modules)
+
+
 def _calibrate_through_both_children(parent):
     """Insert NVFP4 quantizers and run a one-shot forward through both children for calibration."""