diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index d8ddf442924..1dbf93b3ab4 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -309,6 +309,77 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> return get_scaling_factor(weight_quantizer) +def _to_export_tensor(tensor: torch.Tensor) -> torch.Tensor: + if hasattr(tensor, "to_local"): + tensor = tensor.to_local() + return tensor.detach().float() + + +def _laq_amax_to_scale( + amax: torch.Tensor, max_bound: float, min_value: float | torch.Tensor +) -> torch.Tensor: + scale = _to_export_tensor(amax) / max_bound + return torch.where(scale <= min_value, min_value, scale) + + +def _laq_scale_factors( + scale: torch.Tensor, + quantizer: TensorQuantizer, + quantize_scale: bool, + expected_shape: torch.Size, +) -> tuple[torch.Tensor, torch.Tensor]: + scale = scale.view(expected_shape) + if not quantize_scale: + return scale, torch.tensor(1.0, device=scale.device) + + scale_amax = _to_export_tensor(quantizer._per_tensor_scale) + scale_2 = scale_amax / 448.0 + scaled = (scale * 448.0 / scale_amax.view(-1)).clamp(min=2**-9, max=448.0) + return scaled.to(torch.float8_e4m3fn), scale_2 + + +def get_laq_weight_scaling_factors( + weight_quantizer: TensorQuantizer, weight: torch.Tensor, block_size: int +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Return LAQ pre/post scale factors for NVFP4 export. + + The pre factors are used only to pack the FP4 weight codes. The post factors + are serialized as the exported dequantization scales. + """ + assert getattr(weight_quantizer, "_laq", False), "Expected an LAQ quantizer." + assert weight.shape[-1] % block_size == 0, ( + "Weight shape is not divisible for block size for block quantization." + ) + + expected_shape = torch.Size((*weight.shape[:-1], weight.shape[-1] // block_size)) + quantize_scales = getattr(weight_quantizer, "_quantize_scales", False) + per_tensor_scale = ( + _to_export_tensor(weight_quantizer._per_tensor_scale) if quantize_scales else None + ) + max_bound = float(weight_quantizer._quant_max_bound) + + post_min = 0.002 * per_tensor_scale.view(-1) if per_tensor_scale is not None else 1e-8 + pre_min = ( + 0.002 * per_tensor_scale.view(-1) + if per_tensor_scale is not None and getattr(weight_quantizer, "_quantize_pre_scale", True) + else 1e-8 + ) + + post_scale = _laq_amax_to_scale(weight_quantizer.amax_post, max_bound, post_min) + pre_scale = _laq_amax_to_scale(weight_quantizer.amax_pre, max_bound, pre_min) + + pre_scale, pre_scale_2 = _laq_scale_factors( + pre_scale, + weight_quantizer, + quantize_scales and getattr(weight_quantizer, "_quantize_pre_scale", True), + expected_shape, + ) + post_scale, post_scale_2 = _laq_scale_factors( + post_scale, weight_quantizer, quantize_scales, expected_shape + ) + return pre_scale, pre_scale_2, post_scale, post_scale_2 + + def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight") -> torch.Tensor: """Returns the secondary weight scaling factor.""" weight_quantizer = getattr(module, quantizer_attr_names(weight_name).weight_quantizer, None) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 00e4a7008a9..b144206bb53 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -95,6 +95,7 @@ fuse_prequant_layernorm, fuse_prequant_to_linear, get_activation_scaling_factor, + get_laq_weight_scaling_factors, get_quant_config, get_quantization_format, get_weight_block_size, @@ -537,6 +538,12 @@ def _export_quantized_weight( output_quantizer: TensorQuantizer | SequentialQuantizer | None = getattr( sub_module, quantizer_attrs.output_quantizer, None ) + is_laq_nvfp4 = getattr(weight_quantizer, "_laq", False) and quantization_format in [ + QUANTIZATION_NVFP4, + QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_NVFP4_SVDQUANT, + QUANTIZATION_W4A16_NVFP4, + ] if quantization_format == QUANTIZATION_FP8: # Convert amax to float32 @@ -587,6 +594,8 @@ def _export_quantized_weight( sub_module.register_buffer(quantizer_attrs.weight_scale, e8m0_scale) if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None: del weight_quantizer._scale + elif is_laq_nvfp4: + pass else: sub_module.register_buffer( quantizer_attrs.weight_scale, get_weight_scaling_factor(sub_module, weight_name) @@ -604,14 +613,18 @@ def _export_quantized_weight( ).squeeze(), ) - if quantization_format in [ - QUANTIZATION_NVFP4_AWQ, - QUANTIZATION_NVFP4_SVDQUANT, - QUANTIZATION_NVFP4, - QUANTIZATION_W4A16_NVFP4, - QUANTIZATION_W4A8_AWQ, - QUANTIZATION_W4A8_NVFP4_FP8, - ]: + if ( + quantization_format + in [ + QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_NVFP4_SVDQUANT, + QUANTIZATION_NVFP4, + QUANTIZATION_W4A16_NVFP4, + QUANTIZATION_W4A8_AWQ, + QUANTIZATION_W4A8_NVFP4_FP8, + ] + and not is_laq_nvfp4 + ): # Register weight_scale_2 sub_module.register_buffer( quantizer_attrs.weight_scale_2, @@ -646,7 +659,11 @@ def _export_quantized_weight( weight, is_bmm_expert_weight=is_bmm_expert_weight ) - if NVFP4QTensor._is_static_quantizer(weight_quantizer): + if is_laq_nvfp4: + weight_scale, weight_scale_2, export_weight_scale, export_weight_scale_2 = ( + get_laq_weight_scaling_factors(weight_quantizer, weight, block_size) + ) + elif NVFP4QTensor._is_static_quantizer(weight_quantizer): weight_scale = NVFP4QTensor.get_weights_scaling_factor_from_quantizer( weight_quantizer, weight, @@ -666,10 +683,18 @@ def _export_quantized_weight( weight_scale_2, block_size, ) + if is_laq_nvfp4: + weight_scale = export_weight_scale + weight_scale_2 = export_weight_scale_2 quantized_weight, weight_scale = maybe_transpose_expert_weight_dimensions( quantized_weight, weight_scale, is_bmm_expert_weight=is_bmm_expert_weight ) + if is_laq_nvfp4: + assert weight_scale is not None + assert weight_scale_2 is not None + sub_module.register_buffer(quantizer_attrs.weight_scale, weight_scale) + sub_module.register_buffer(quantizer_attrs.weight_scale_2, weight_scale_2.squeeze()) elif quantization_format == QUANTIZATION_FP8_PC_PT and is_bmm_expert_weight: # For FP8_PC_PT with BMM-style experts, transpose only the weight (not weight_scale) weight, _ = maybe_transpose_expert_weight_dimensions( diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 118752df8f9..7731976ab39 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1012,6 +1012,9 @@ class LAQConfig(QuantizeAlgorithmConfig): ``tied_amax`` makes pre and post share a single tensor (requires both to have the same learnable state, i.e. ``learnable_amax`` must be ``["pre", "post"]`` or ``[]``). + + ``quantize_pre_scale=False`` leaves the pre-quantization scale unquantized + while preserving the existing post-scale quantization behavior. """ method: Literal["laq"] = ModeloptField("laq") @@ -1035,6 +1038,15 @@ class LAQConfig(QuantizeAlgorithmConfig): ), ) + quantize_pre_scale: bool = ModeloptField( + default=True, + title="FP8-quantize the LAQ pre-quantization scale.", + description=( + "If False, LAQ uses the raw pre-quantization scale while keeping post-scale " + "quantization controlled by the quantizer's block-scale settings." + ), + ) + scale_algorithm: dict | None = ModeloptField( default=None, title="Scale calibration algorithm to run first.", diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 9c39e795c92..763f36b7479 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -2016,6 +2016,7 @@ def laq( scale_algorithm: dict | None = None, learnable_amax: list | str = ("post",), tied_amax: bool = False, + quantize_pre_scale: bool = True, **kwargs, ): """Run scale calibration then convert to LAQ mode. @@ -2032,6 +2033,7 @@ def laq( learnable_amax: Which amax params are learnable: 'pre', 'post', ['pre', 'post'], or []. tied_amax: If True, pre and post share a single tensor. + quantize_pre_scale: If False, skip FP8 quantization for the LAQ pre scale. """ _run_scale_calibration(model, forward_loop, scale_algorithm, "laq") @@ -2047,4 +2049,5 @@ def laq( quantize_scales, learnable_amax=learnable_amax, tied_amax=tied_amax, + quantize_pre_scale=quantize_pre_scale, ) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 5e08d277ee0..590a1529bf3 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -1393,6 +1393,7 @@ class StaticBlockScaleQuantizer(TensorQuantizer): _tied_amax: bool = False _quant_max_bound: float = 6.0 _quantize_scales: bool = True + _quantize_pre_scale: bool = True def _preserve_amax_in_fp32(self): for name in ("_amax", "_global_amax", "_per_tensor_scale"): @@ -1536,6 +1537,7 @@ def enable_laq( quantize_scales: bool = True, learnable_amax: list | str = ("post",), tied_amax: bool = False, + quantize_pre_scale: bool = True, ): """LAQ mode with configurable learnable/frozen amax tensors. @@ -1546,6 +1548,7 @@ def enable_laq( learnable_amax: Which amax params are learnable: 'pre', 'post', ['pre', 'post'], or []. tied_amax: If True, pre and post share a single tensor. + quantize_pre_scale: Whether to FP8-quantize the LAQ pre scale. """ if hasattr(self, "_amax"): delattr(self, "_amax") @@ -1572,6 +1575,7 @@ def enable_laq( if per_tensor_scale is not None: self.register_buffer("_per_tensor_scale", per_tensor_scale.float().clone().detach()) self._quantize_scales = quantize_scales + self._quantize_pre_scale = quantize_pre_scale self._laq = True self._learnable_amax = sorted(learn) self._tied_amax = tied_amax @@ -1592,22 +1596,29 @@ def _fake_quantize(self, inputs): """Fake quantization using two-level scaling with _amax and _global_amax.""" if self._laq: # 0.002 ≈ smallest positive FP8 E4M3 value; clamps per-block scale floor - _scale_min = 0.002 * self._per_tensor_scale.view(-1) if self._quantize_scales else 1e-8 + scale_min_post = ( + 0.002 * self._per_tensor_scale.view(-1) if self._quantize_scales else 1e-8 + ) + scale_min_pre = ( + 0.002 * self._per_tensor_scale.view(-1) + if self._quantize_scales and self._quantize_pre_scale + else 1e-8 + ) scale_post = self._maybe_quantize_scale( _amax_to_scale( _to_local(self.amax_post), self._quant_max_bound, - min_value=_scale_min, + min_value=scale_min_post, ) ) - scale_pre = self._maybe_quantize_scale( - _amax_to_scale( - _to_local(self.amax_pre), - self._quant_max_bound, - min_value=_scale_min, - ) + scale_pre = _amax_to_scale( + _to_local(self.amax_pre), + self._quant_max_bound, + min_value=scale_min_pre, ) + if self._quantize_pre_scale: + scale_pre = self._maybe_quantize_scale(scale_pre) quant_input = inputs.float() / scale_pre.float().view(-1, 1) w_cast = self._cast_ste(quant_input) return (w_cast * scale_post.view(-1, 1).to(w_cast.dtype)).to(inputs.dtype) diff --git a/tests/gpu/torch/quantization/test_fsdp2.py b/tests/gpu/torch/quantization/test_fsdp2.py index 55648ac26e2..21c6a87ec2d 100644 --- a/tests/gpu/torch/quantization/test_fsdp2.py +++ b/tests/gpu/torch/quantization/test_fsdp2.py @@ -27,6 +27,7 @@ import modelopt.torch.quantization as mtq from modelopt.torch.opt.dynamic import _pytorch_managed +from modelopt.torch.quantization.nn import StaticBlockScaleQuantizer, TensorQuantizer from modelopt.torch.quantization.utils import ( enable_weight_access_and_writeback, persistent_materialization, @@ -136,6 +137,50 @@ def test_nested_fsdp2_backward(quant_cfg, dist_workers): dist_workers.run(partial(_test_nested_fsdp2_backward, quant_cfg=quant_cfg)) +class _LAQBf16Linear(nn.Module): + """Minimal bf16 module with LAQ learnable amax parameters.""" + + def __init__(self, dim=16): + super().__init__() + self.weight = nn.Parameter(torch.randn(dim, dim, dtype=torch.bfloat16)) + + tq = TensorQuantizer() + tq._num_bits = 4 + tq._unsigned = False + tq._narrow_range = True + tq._disabled = False + tq._block_sizes = {-1: dim} + tq._pass_through_bwd = True + tq.register_buffer("_amax", torch.ones(dim, dtype=torch.bfloat16)) + self.weight_quantizer = StaticBlockScaleQuantizer.from_tensor_quantizer(tq) + self.weight_quantizer.enable_laq( + torch.ones(dim, dtype=torch.bfloat16), + quantize_scales=False, + learnable_amax=["pre", "post"], + ) + + def forward(self, inputs): + weight = self.weight_quantizer._fake_quantize(self.weight) + return torch.nn.functional.linear(inputs, weight) + + +def _test_laq_bf16_learnable_amax_fsdp2(rank, size): + torch.manual_seed(1) + model = _LAQBf16Linear().cuda(rank) + inputs = torch.randn(2, 16, device=rank, dtype=torch.bfloat16) + synchronize_state_dict(model) + + assert {p.dtype for p in model.parameters()} == {torch.bfloat16} + + model = fully_shard(model) + output = model(inputs) + output.float().sum().backward() + + +def test_laq_bf16_learnable_amax_fsdp2(dist_workers): + dist_workers.run(_test_laq_bf16_learnable_amax_fsdp2) + + class _DecoderBlock(nn.Module): """Minimal decoder block for FSDP2 sequential tests.""" diff --git a/tests/gpu/torch/quantization/test_laq_cuda.py b/tests/gpu/torch/quantization/test_laq_cuda.py index 256bcc25dc1..23a04b56a14 100644 --- a/tests/gpu/torch/quantization/test_laq_cuda.py +++ b/tests/gpu/torch/quantization/test_laq_cuda.py @@ -79,6 +79,26 @@ }, } +NVFP4_LAQ_SKIP_PRE_SCALE_MSE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + }, + "algorithm": { + "method": "laq", + "learnable_amax": ["post"], + "quantize_pre_scale": False, + "scale_algorithm": {"method": "mse", "fp8_scale_sweep": True}, + }, +} + class SimpleModel(nn.Module): """Minimal model for LAQ testing.""" @@ -102,8 +122,13 @@ def forward_loop(m): @pytest.mark.parametrize( "config", - [NVFP4_LAQ_POST_MSE_CFG, NVFP4_LAQ_PRE_POST_MSE_CFG, NVFP4_LAQ_TIED_MSE_CFG], - ids=["post_only", "pre_and_post", "tied"], + [ + NVFP4_LAQ_POST_MSE_CFG, + NVFP4_LAQ_PRE_POST_MSE_CFG, + NVFP4_LAQ_TIED_MSE_CFG, + NVFP4_LAQ_SKIP_PRE_SCALE_MSE_CFG, + ], + ids=["post_only", "pre_and_post", "tied", "skip_pre_scale"], ) def test_laq_quantize_e2e(config): """End-to-end: quantize a small model with LAQ + NVFP4 on GPU.""" @@ -112,6 +137,9 @@ def test_laq_quantize_e2e(config): forward_loop = _make_forward_loop(model, device) model = mtq.quantize(model, config, forward_loop=forward_loop) + assert model.linear.weight_quantizer._quantize_pre_scale is config["algorithm"].get( + "quantize_pre_scale", True + ) # Verify the model still produces output of the correct shape x = torch.randn(2, 64, device=device) diff --git a/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py b/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py index bf8ba6f09b2..e240c5dd5f2 100644 --- a/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py +++ b/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py @@ -149,6 +149,85 @@ def test_load_state_dict_preserves_fp32_scale_state(self, device): torch.testing.assert_close(target.amax, source.amax) torch.testing.assert_close(target.global_amax, source.global_amax) + def test_laq_dtype_cast_handles_learnable_params_and_frozen_buffers(self, device): + """Casting should update learnable amax params but keep frozen scale buffers fp32.""" + cfg = QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: 16, "type": "static", "scale_bits": (4, 3)}, + ) + quantizer = NVFP4StaticQuantizer(quant_attribute_cfg=cfg).to(device) + expected = torch.tensor( + [1.0001, 1.0002, 1.0003, 1.0004], device=device, dtype=torch.float32 + ) + quantizer.enable_laq( + expected, + per_tensor_scale=torch.tensor(1.0, device=device), + quantize_scales=True, + learnable_amax=["pre", "post"], + ) + + quantizer = quantizer.half() + + assert quantizer._amax_pre.dtype == torch.float16 + assert quantizer._amax_post.dtype == torch.float16 + torch.testing.assert_close(quantizer._amax_pre, expected.to(torch.float16)) + torch.testing.assert_close(quantizer._amax_post, expected.to(torch.float16)) + + quantizer = quantizer.to(dtype=torch.bfloat16) + + assert quantizer._amax_pre.dtype == torch.bfloat16 + assert quantizer._amax_post.dtype == torch.bfloat16 + torch.testing.assert_close(quantizer._amax_pre, expected.to(torch.bfloat16)) + torch.testing.assert_close(quantizer._amax_post, expected.to(torch.bfloat16)) + + frozen_quantizer = NVFP4StaticQuantizer(quant_attribute_cfg=cfg).to(device) + frozen_quantizer.enable_laq( + expected.to(torch.bfloat16), + per_tensor_scale=torch.tensor(1.0, device=device, dtype=torch.bfloat16), + quantize_scales=True, + learnable_amax=[], + ) + frozen_quantizer = frozen_quantizer.half() + + assert frozen_quantizer._amax_pre.dtype == torch.float32 + assert frozen_quantizer._amax_post.dtype == torch.float32 + assert frozen_quantizer._per_tensor_scale.dtype == torch.float32 + expected_frozen = expected.to(torch.bfloat16).float() + torch.testing.assert_close(frozen_quantizer._amax_pre, expected_frozen) + torch.testing.assert_close(frozen_quantizer._amax_post, expected_frozen) + + def test_laq_load_state_dict_preserves_fp32_scale_state(self, device): + """Loading lower-precision state should keep LAQ scale state fp32.""" + cfg = QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: 16, "type": "static", "scale_bits": (4, 3)}, + ) + expected = torch.arange(1, 5, device=device, dtype=torch.float32) + source = NVFP4StaticQuantizer(quant_attribute_cfg=cfg).to(device) + source.enable_laq( + expected, + per_tensor_scale=torch.tensor(1.0, device=device), + quantize_scales=True, + learnable_amax=["pre", "post"], + ) + state_dict = source.state_dict() + state_dict["_amax_pre"] = state_dict["_amax_pre"].to(dtype=torch.float16) + state_dict["_amax_post"] = state_dict["_amax_post"].to(dtype=torch.float16) + + target = NVFP4StaticQuantizer(quant_attribute_cfg=cfg).to(device) + target.enable_laq( + torch.zeros(4, device=device), + per_tensor_scale=torch.tensor(0.0, device=device), + quantize_scales=True, + learnable_amax=["pre", "post"], + ) + target.load_state_dict(state_dict) + + assert target._amax_pre.dtype == torch.float32 + assert target._amax_post.dtype == torch.float32 + torch.testing.assert_close(target._amax_pre, source._amax_pre) + torch.testing.assert_close(target._amax_post, source._amax_post) + def test_modelopt_state_restore_uses_fp32_scale_metadata(self, device): """ModelOpt metadata restore should use the saved fp32 static scale metadata.""" cfg = QuantizerAttributeConfig( diff --git a/tests/unit/torch/quantization/test_laq.py b/tests/unit/torch/quantization/test_laq.py index 3adc335690e..1be880681ec 100644 --- a/tests/unit/torch/quantization/test_laq.py +++ b/tests/unit/torch/quantization/test_laq.py @@ -36,6 +36,7 @@ def test_default_config(self): assert cfg.method == "laq" assert cfg.learnable_amax == ["post"] assert cfg.tied_amax is False + assert cfg.quantize_pre_scale is True assert cfg.scale_algorithm is None @pytest.mark.parametrize( @@ -128,6 +129,15 @@ def test_old_amax_deleted(self): q.enable_laq(torch.ones(8), quantize_scales=False) assert not hasattr(q, "_amax") + def test_can_skip_pre_scale_quantization(self): + q = self._make_quantizer() + q.enable_laq( + torch.ones(8), + quantize_scales=False, + quantize_pre_scale=False, + ) + assert q._quantize_pre_scale is False + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_learnable_amax_uses_input_dtype(self, dtype): q = self._make_quantizer() @@ -258,3 +268,44 @@ def test_tied_shares_tensor(self): out = q._fake_quantize(x) out.sum().backward() assert q._amax_post.grad is not None + + def test_skip_pre_scale_quantization_still_quantizes_post(self, monkeypatch): + q = self._make_laq_quantizer() + q._quantize_scales = True + q._quantize_pre_scale = False + q.register_buffer("_per_tensor_scale", torch.tensor(1.0)) + calls = [] + + def spy_maybe_quantize_scale(scale_raw): + calls.append(scale_raw) + return scale_raw + + monkeypatch.setattr(q, "_maybe_quantize_scale", spy_maybe_quantize_scale) + + out = q._fake_quantize(torch.randn(4, 16)) + + assert out.shape == (4, 16) + assert len(calls) == 1 + + def test_skip_pre_scale_quantization_uses_raw_scale_floor(self, monkeypatch): + q = self._make_laq_quantizer() + q._quantize_scales = True + q._quantize_pre_scale = False + q.register_buffer("_per_tensor_scale", torch.tensor(1.0)) + min_values = [] + + def fake_amax_to_scale(amax, maxbound, min_value=None): + min_values.append(min_value) + return torch.ones_like(amax) + + monkeypatch.setattr( + "modelopt.torch.quantization.nn.modules.tensor_quantizer._amax_to_scale", + fake_amax_to_scale, + ) + monkeypatch.setattr(q, "_maybe_quantize_scale", lambda scale_raw: scale_raw) + + out = q._fake_quantize(torch.randn(4, 16)) + + assert out.shape == (4, 16) + assert torch.equal(min_values[0], torch.tensor([0.002])) + assert min_values[1] == 1e-8 diff --git a/tests/unit/torch/quantization/test_nvfp4_static_export_cpu.py b/tests/unit/torch/quantization/test_nvfp4_static_export_cpu.py index dfb776a0484..01bf188497f 100644 --- a/tests/unit/torch/quantization/test_nvfp4_static_export_cpu.py +++ b/tests/unit/torch/quantization/test_nvfp4_static_export_cpu.py @@ -20,7 +20,12 @@ import pytest import torch -from modelopt.torch.export.quant_utils import QUANTIZATION_NVFP4, to_quantized_weight +from modelopt.torch.export.quant_utils import ( + QUANTIZATION_NVFP4, + get_laq_weight_scaling_factors, + to_quantized_weight, +) +from modelopt.torch.export.unified_export_hf import _export_quantized_weight from modelopt.torch.quantization.config import QuantizerAttributeConfig from modelopt.torch.quantization.nn import NVFP4StaticQuantizer from modelopt.torch.quantization.qtensor import NVFP4QTensor @@ -70,6 +75,90 @@ def _export_round_trip( return weight_scale, weight_scale_2, dequant +def _make_laq_quantizer( + learnable_amax: list[str], + tied_amax: bool, + quantize_pre_scale: bool, +) -> NVFP4StaticQuantizer: + cfg = QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: BLOCK_SIZE, "type": "static", "scale_bits": (4, 3)}, + ) + q = NVFP4StaticQuantizer(quant_attribute_cfg=cfg) + q.enable_laq( + torch.ones(8), + per_tensor_scale=torch.tensor(1.0), + quantize_scales=True, + learnable_amax=learnable_amax, + tied_amax=tied_amax, + quantize_pre_scale=quantize_pre_scale, + ) + with torch.no_grad(): + q.amax_post.copy_(torch.full_like(q.amax_post, 3.0)) + if not tied_amax: + q.amax_pre.copy_(torch.full_like(q.amax_pre, 1.5)) + return q + + +@pytest.mark.parametrize( + ("learnable_amax", "tied_amax", "quantize_pre_scale"), + [ + (["post"], False, True), + (["pre"], False, True), + (["pre", "post"], False, True), + (["pre", "post"], True, True), + ([], False, True), + (["post"], False, False), + ], +) +def test_laq_nvfp4_export_uses_pre_scale_for_packing_and_post_scale_for_dequant( + learnable_amax, tied_amax, quantize_pre_scale +): + weight = torch.linspace(-4.0, 4.0, 4 * 32, dtype=torch.float32).view(4, 32) + q = _make_laq_quantizer(learnable_amax, tied_amax, quantize_pre_scale) + + pre_scale, pre_scale_2, post_scale, post_scale_2 = get_laq_weight_scaling_factors( + q, weight, BLOCK_SIZE + ) + packed = to_quantized_weight(weight, pre_scale, QUANTIZATION_NVFP4, pre_scale_2, BLOCK_SIZE) + + dequant = NVFP4QTensor(weight.shape, weight.dtype, packed).dequantize( + scale=post_scale, + double_scale=post_scale_2, + block_sizes={-1: BLOCK_SIZE}, + dtype=weight.dtype, + ) + + assert packed.dtype == torch.uint8 + assert post_scale.shape == (4, 2) + assert torch.isfinite(dequant).all() + if not tied_amax: + wrong_packed = to_quantized_weight( + weight, post_scale, QUANTIZATION_NVFP4, post_scale_2, BLOCK_SIZE + ) + assert not torch.equal(packed, wrong_packed) + + +def test_laq_nvfp4_export_quantized_weight_registers_post_scale_and_packs_with_pre_scale(): + weight = torch.linspace(-4.0, 4.0, 4 * 32, dtype=torch.float32).view(4, 32) + module = torch.nn.Module() + module.weight = torch.nn.Parameter(weight.clone()) + module.weight_quantizer = _make_laq_quantizer(["post"], False, quantize_pre_scale=False) + + pre_scale, pre_scale_2, post_scale, post_scale_2 = get_laq_weight_scaling_factors( + module.weight_quantizer, weight, BLOCK_SIZE + ) + expected_packed = to_quantized_weight( + weight, pre_scale, QUANTIZATION_NVFP4, pre_scale_2, BLOCK_SIZE + ) + + _export_quantized_weight(module, torch.float32) + + torch.testing.assert_close(module.weight, expected_packed) + torch.testing.assert_close(module.weight_scale, post_scale) + torch.testing.assert_close(module.weight_scale_2, post_scale_2.squeeze()) + + def _layer1_routed_expert_like( out_dim: int, in_dim: int, *, n_outliers: int, seed: int ) -> torch.Tensor: