Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ def _is_enabled_quantizer(quantizer):
return False


def _has_pre_quant_scale(module: nn.Module) -> bool:
    """Report whether ``module``'s input quantizer has a ``_pre_quant_scale`` attribute.

    Args:
        module: The module to inspect; it may or may not expose ``input_quantizer``.

    Returns:
        ``True`` only when ``module.input_quantizer`` exists, is not ``None``, and has
        a ``_pre_quant_scale`` attribute. ``False`` in every other case.
    """
    quantizer = getattr(module, "input_quantizer", None)
    if quantizer is None:
        # Missing quantizer and an explicit ``None`` are treated the same way.
        return False
    return hasattr(quantizer, "_pre_quant_scale")


def _save_component_state_dict_safetensors(
component: nn.Module,
component_export_dir: Path,
Expand Down Expand Up @@ -407,6 +412,7 @@ def _fuse_shared_input_modules(
and group_quant_format is not None
and group_quant_format != QUANTIZATION_NONE
and "awq" in group_quant_format
and all(_has_pre_quant_scale(module) for module in modules)
and tensor in output_to_layernorm
):
with fsdp2_aware_weight_update(model, output_to_layernorm[tensor]):
Expand Down
22 changes: 21 additions & 1 deletion tests/unit/torch/export/test_unified_export_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from collections import OrderedDict

import torch
from _test_utils.torch.export.utils import SmallQKVModel
from _test_utils.torch.quantization.tied_modules import (
make_tied_linear_pair,
wrap_in_parent_with_tied_keys,
Expand All @@ -29,7 +30,10 @@
_reorder_canonical_first,
)
from modelopt.torch.export.quant_utils import sync_tied_input_amax
from modelopt.torch.export.unified_export_hf import _export_quantized_weight
from modelopt.torch.export.unified_export_hf import (
_export_quantized_weight,
requantize_resmooth_fused_llm_layers,
)


def test_collect_canonical_tied_patterns_dict_style():
Expand Down Expand Up @@ -115,6 +119,22 @@ def test_sync_tied_input_amax_no_op_for_untied_modules():
assert torch.allclose(dec_q.amax, torch.tensor(5.0))


def test_requantize_resmooth_skips_layernorm_without_pre_quant_scale():
    """Resmoothing leaves the layernorm untouched when no projection has a pre-quant scale.

    The model is quantized with the INT4 blockwise weight-only config; the test first
    asserts that none of the q/k/v input quantizers has ``_pre_quant_scale``, then checks
    that ``requantize_resmooth_fused_llm_layers`` neither changes the input layernorm
    weight nor flags any projection as ``fused_with_prequant``.
    """
    model = SmallQKVModel(dim=4, device="cpu", apply_embed=True)
    mtq.quantize(model, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG)
    qkv_projections = (model.q_proj, model.k_proj, model.v_proj)
    # Snapshot the weight so any in-place folding by the export pass is detectable.
    weight_before = model.input_layernorm.weight.detach().clone()

    # Precondition: this config leaves no pre-quant scale on the input quantizers.
    for projection in qkv_projections:
        assert not hasattr(projection.input_quantizer, "_pre_quant_scale")

    requantize_resmooth_fused_llm_layers(model)

    assert torch.equal(model.input_layernorm.weight, weight_before)
    assert not any(
        getattr(projection, "fused_with_prequant", False) for projection in qkv_projections
    )


def _calibrate_through_both_children(parent):
"""Insert NVFP4 quantizers and run a one-shot forward through both children for calibration."""

Expand Down
Loading