From 2613ab9baca265f8d09d844aaff14b336dfe4c9b Mon Sep 17 00:00:00 2001 From: realAsma Date: Mon, 22 Jun 2026 23:10:51 +0000 Subject: [PATCH 1/3] Support INT block scale learning Signed-off-by: realAsma --- modelopt/torch/quantization/model_calib.py | 3 +- .../nn/modules/tensor_quantizer.py | 51 ++++++-- tests/unit/torch/quantization/test_laq.py | 120 +++++++++++++++++- 3 files changed, 160 insertions(+), 14 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 9c39e795c92..99f8ffbf858 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1948,8 +1948,7 @@ def _run_scale_calibration(model, forward_loop, scale_algorithm, caller_name): algo_kwargs = {k: v for k, v in algo_kwargs.items() if k in accepted} calib_func(model, forward_loop=forward_loop, **algo_kwargs) - if method == "max": - _convert_to_static_block_quantizers(model) + _convert_to_static_block_quantizers(model) def _compute_block_scales(quantizer): diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index caba2e0bc45..26adb58dd06 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -703,6 +703,31 @@ def _get_bias(self, inputs): raise ValueError(f"Unsupported bias type: {self.bias_type}") return bias + def _fake_dynamic_int_block_quantize(self, inputs, block_size, bias): + """Fake quantize INT blocks with a per-forward max scale.""" + if bias is not None: + raise NotImplementedError("Dynamic INT block quantization does not support bias.") + + original_last_dim = inputs.shape[-1] + if original_last_dim % block_size != 0: + pad_width = block_size - original_last_dim % block_size + inputs = F.pad(inputs, (0, pad_width), "constant", 0) + + blocked_shape = (*inputs.shape[:-1], -1, block_size) + blocked_inputs = inputs.reshape(blocked_shape) + amax = quant_utils.reduce_amax(blocked_inputs, axis=-1, keepdims=True).detach() + outputs = fake_tensor_quant( + blocked_inputs, + amax, + None, + self._num_bits, + self._unsigned, + self._narrow_range, + self._trt_high_precision_dtype, + self._pass_through_bwd, + ).reshape(inputs.shape) + return outputs[..., :original_last_dim] + def _is_real_quantize_support(self): """Check if real quantization is supported for this quant config.""" return ( @@ -817,17 +842,21 @@ def _fake_quantize(self, inputs): if block_size is None: raise ValueError("block size for dynamic quantization not found.") - outputs = dynamic_block_quant( - inputs, - block_size, - amax, - self._get_bias(inputs), - self._num_bits, - self.block_sizes.get("scale_bits", None), - getattr(self, "_trt_high_precision_dtype", None), - getattr(self, "_onnx_quantizer_type", None), - self._pass_through_bwd, - ) + bias = self._get_bias(inputs) + if isinstance(self._num_bits, int) and self.block_sizes.get("scale_bits") is None: + outputs = self._fake_dynamic_int_block_quantize(inputs, block_size, bias) + else: + outputs = dynamic_block_quant( + inputs, + block_size, + amax, + bias, + self._num_bits, + self.block_sizes.get("scale_bits", None), + getattr(self, "_trt_high_precision_dtype", None), + getattr(self, "_onnx_quantizer_type", None), + self._pass_through_bwd, + ) elif isinstance(self._num_bits, tuple): # Float-point quantization, e.g., FP8 E, M = self._num_bits # noqa: N806 diff --git a/tests/unit/torch/quantization/test_laq.py b/tests/unit/torch/quantization/test_laq.py index e1eed2659f2..e3b6933fd2e 100644 --- a/tests/unit/torch/quantization/test_laq.py +++ b/tests/unit/torch/quantization/test_laq.py @@ -13,12 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""CPU unit tests for the LAQ algorithm using INT4 quantization.""" +"""CPU unit tests for the LAQ algorithm using INT block quantization.""" import pytest import torch +from _test_utils.torch.quantization.models import SimpleLinear from torch import nn +import modelopt.torch.quantization as mtq from modelopt.torch.quantization.config import LAQConfig from modelopt.torch.quantization.nn.modules.tensor_quantizer import ( StaticBlockScaleQuantizer, @@ -203,3 +205,119 @@ def test_tied_shares_tensor(self): out = q._fake_quantize(x) out.sum().backward() assert q._amax_post.grad is not None + + +class TestQuantizeLAQIntBlock: + """End-to-end LAQ tests for low-bit INT static block weight quantization.""" + + @staticmethod + def _make_config(num_bits, scale_method, learnable_amax=("post",), tied_amax=False): + return { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + { + "quantizer_name": "*weight_quantizer", + "enable": True, + "cfg": { + "num_bits": num_bits, + "block_sizes": {-1: 16, "type": "static"}, + }, + }, + ], + "algorithm": { + "method": "laq", + "learnable_amax": list(learnable_amax), + "tied_amax": tied_amax, + "scale_algorithm": {"method": scale_method}, + }, + } + + @staticmethod + def _make_dynamic_max_config(num_bits): + return { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + { + "quantizer_name": "*weight_quantizer", + "enable": True, + "cfg": { + "num_bits": num_bits, + "block_sizes": {-1: 16, "type": "dynamic"}, + }, + }, + ], + "algorithm": "max", + } + + @staticmethod + def _forward_loop(model): + model(SimpleLinear.get_input()) + + @staticmethod + def _weight_quantizers(model): + return [ + module.weight_quantizer + for module in model.modules() + if hasattr(module, "weight_quantizer") + ] + + @pytest.mark.parametrize("num_bits", [3, 2]) + @pytest.mark.parametrize("scale_method", ["max", "mse"]) + def test_low_bit_initializers_enable_laq(self, num_bits, scale_method): + model = mtq.quantize( + SimpleLinear(), + self._make_config(num_bits, scale_method, learnable_amax=["pre", "post"]), + self._forward_loop, + ) + weight_quantizers = self._weight_quantizers(model) + + assert weight_quantizers + for quantizer in weight_quantizers: + assert isinstance(quantizer, StaticBlockScaleQuantizer) + assert quantizer._laq is True + assert quantizer.num_bits == num_bits + assert quantizer.block_sizes[-1] == 16 + + output = model(SimpleLinear.get_input()) + assert torch.isfinite(output).all() + + @pytest.mark.parametrize("num_bits", [3, 2]) + @pytest.mark.parametrize( + ("learnable_amax", "tied_amax", "expected_learnable_count"), + [ + (["pre", "post"], True, 1), + (["pre", "post"], False, 2), + ([], False, 0), + ], + ) + def test_low_bit_laq_variants( + self, num_bits, learnable_amax, tied_amax, expected_learnable_count + ): + model = mtq.quantize( + SimpleLinear(), + self._make_config(num_bits, "max", learnable_amax, tied_amax), + self._forward_loop, + ) + weight_quantizers = self._weight_quantizers(model) + + assert weight_quantizers + for quantizer in weight_quantizers: + learnable_params = [ + value + for value in (quantizer.amax_pre, quantizer.amax_post) + if isinstance(value, nn.Parameter) + ] + assert len(set(map(id, learnable_params))) == expected_learnable_count + assert quantizer._tied_amax is tied_amax + assert set(quantizer._learnable_amax) == set(learnable_amax) + + @pytest.mark.parametrize("num_bits", [3, 2]) + def test_low_bit_dynamic_max_weight_only_forward(self, num_bits): + model = mtq.quantize( + SimpleLinear(), + self._make_dynamic_max_config(num_bits), + self._forward_loop, + ) + output = model(SimpleLinear.get_input()) + + assert torch.isfinite(output).all() From 4f442526a5a35513faec4d06a95fa01cc4c9f6a6 Mon Sep 17 00:00:00 2001 From: realAsma Date: Mon, 22 Jun 2026 23:42:31 +0000 Subject: [PATCH 2/3] Address INT block review feedback Signed-off-by: realAsma --- .../nn/modules/tensor_quantizer.py | 37 ++++++------------- modelopt/torch/quantization/tensor_quant.py | 37 +++++++++++++++++++ 2 files changed, 48 insertions(+), 26 deletions(-) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 26adb58dd06..83063970a91 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -56,6 +56,7 @@ ) from ...tensor_quant import ( dynamic_block_quant, + dynamic_int_block_quant, fake_tensor_quant, fp4_cast_ste, int_cast_ste, @@ -703,31 +704,6 @@ def _get_bias(self, inputs): raise ValueError(f"Unsupported bias type: {self.bias_type}") return bias - def _fake_dynamic_int_block_quantize(self, inputs, block_size, bias): - """Fake quantize INT blocks with a per-forward max scale.""" - if bias is not None: - raise NotImplementedError("Dynamic INT block quantization does not support bias.") - - original_last_dim = inputs.shape[-1] - if original_last_dim % block_size != 0: - pad_width = block_size - original_last_dim % block_size - inputs = F.pad(inputs, (0, pad_width), "constant", 0) - - blocked_shape = (*inputs.shape[:-1], -1, block_size) - blocked_inputs = inputs.reshape(blocked_shape) - amax = quant_utils.reduce_amax(blocked_inputs, axis=-1, keepdims=True).detach() - outputs = fake_tensor_quant( - blocked_inputs, - amax, - None, - self._num_bits, - self._unsigned, - self._narrow_range, - self._trt_high_precision_dtype, - self._pass_through_bwd, - ).reshape(inputs.shape) - return outputs[..., :original_last_dim] - def _is_real_quantize_support(self): """Check if real quantization is supported for this quant config.""" return ( @@ -844,7 +820,16 @@ def _fake_quantize(self, inputs): bias = self._get_bias(inputs) if isinstance(self._num_bits, int) and self.block_sizes.get("scale_bits") is None: - outputs = self._fake_dynamic_int_block_quantize(inputs, block_size, bias) + outputs = dynamic_int_block_quant( + inputs, + block_size, + bias, + self._num_bits, + self._unsigned, + self._narrow_range, + self._trt_high_precision_dtype, + self._pass_through_bwd, + ) else: outputs = dynamic_block_quant( inputs, diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index 17809d50d01..44697aeddbe 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -18,6 +18,7 @@ import warnings import torch +import torch.nn.functional as F from torch.autograd import Function from torch.onnx import symbolic_helper @@ -25,6 +26,7 @@ from .config import QuantizerAttributeConfig from .extensions import get_cuda_ext, get_cuda_ext_fp8, get_cuda_ext_mx +from .utils import reduce_amax mx_format_map = { (4, 3): "E4M3", @@ -493,6 +495,41 @@ def _dynamic_block_quantize_forward( return outputs +def dynamic_int_block_quant( + inputs, + block_size, + bias, + num_bits, + unsigned, + narrow_range, + trt_high_precision_dtype=None, + pass_through_bwd=True, +): + """Fake quantize INT blocks with a per-forward max scale.""" + if bias is not None: + raise NotImplementedError("Dynamic INT block quantization does not support bias.") + + original_last_dim = inputs.shape[-1] + if original_last_dim % block_size != 0: + pad_width = block_size - original_last_dim % block_size + inputs = F.pad(inputs, (0, pad_width), "constant", 0) + + blocked_shape = (*inputs.shape[:-1], -1, block_size) + blocked_inputs = inputs.reshape(blocked_shape) + amax = reduce_amax(blocked_inputs, axis=-1, keepdims=True).detach() + outputs = fake_tensor_quant( + blocked_inputs, + amax, + None, + num_bits, + unsigned, + narrow_range, + trt_high_precision_dtype, + pass_through_bwd, + ).reshape(inputs.shape) + return outputs[..., :original_last_dim] + + class DynamicBlockQuantizationFunction(Function): """Dynamic block quantization functional.""" From 1a6fa2d34aa92e9e0a33ac3995a52490f038689e Mon Sep 17 00:00:00 2001 From: realAsma Date: Tue, 23 Jun 2026 01:26:17 +0000 Subject: [PATCH 3/3] Route dynamic INT blocks through tensor_quant Signed-off-by: realAsma --- .../nn/modules/tensor_quantizer.py | 36 ++++++------------ modelopt/torch/quantization/tensor_quant.py | 38 ++++++++++++------- 2 files changed, 36 insertions(+), 38 deletions(-) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 83063970a91..caba2e0bc45 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -56,7 +56,6 @@ ) from ...tensor_quant import ( dynamic_block_quant, - dynamic_int_block_quant, fake_tensor_quant, fp4_cast_ste, int_cast_ste, @@ -818,30 +817,17 @@ def _fake_quantize(self, inputs): if block_size is None: raise ValueError("block size for dynamic quantization not found.") - bias = self._get_bias(inputs) - if isinstance(self._num_bits, int) and self.block_sizes.get("scale_bits") is None: - outputs = dynamic_int_block_quant( - inputs, - block_size, - bias, - self._num_bits, - self._unsigned, - self._narrow_range, - self._trt_high_precision_dtype, - self._pass_through_bwd, - ) - else: - outputs = dynamic_block_quant( - inputs, - block_size, - amax, - bias, - self._num_bits, - self.block_sizes.get("scale_bits", None), - getattr(self, "_trt_high_precision_dtype", None), - getattr(self, "_onnx_quantizer_type", None), - self._pass_through_bwd, - ) + outputs = dynamic_block_quant( + inputs, + block_size, + amax, + self._get_bias(inputs), + self._num_bits, + self.block_sizes.get("scale_bits", None), + getattr(self, "_trt_high_precision_dtype", None), + getattr(self, "_onnx_quantizer_type", None), + self._pass_through_bwd, + ) elif isinstance(self._num_bits, tuple): # Float-point quantization, e.g., FP8 E, M = self._num_bits # noqa: N806 diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index 44697aeddbe..308ffe9ef73 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -466,6 +466,7 @@ def _dynamic_block_quantize_forward( inputs, block_size, amax, + bias, num_bits, scale_bits, trt_high_precision_dtype=None, @@ -473,6 +474,17 @@ def _dynamic_block_quantize_forward( pass_through_bwd=True, ): """Forward method.""" + if isinstance(num_bits, int) and scale_bits is None: + return _dynamic_int_block_quantize_forward( + ctx, + inputs, + block_size, + bias, + num_bits, + pass_through_bwd, + ) + + _save_for_backward_if_needed(ctx, pass_through_bwd, inputs, amax) if isinstance(num_bits, int): # special case for INT dynamic block quantization, e.g. MXINT8 exponent_bits = 0 @@ -495,14 +507,12 @@ def _dynamic_block_quantize_forward( return outputs -def dynamic_int_block_quant( +def _dynamic_int_block_quantize_forward( + ctx, inputs, block_size, bias, num_bits, - unsigned, - narrow_range, - trt_high_precision_dtype=None, pass_through_bwd=True, ): """Fake quantize INT blocks with a per-forward max scale.""" @@ -516,16 +526,18 @@ def dynamic_int_block_quant( blocked_shape = (*inputs.shape[:-1], -1, block_size) blocked_inputs = inputs.reshape(blocked_shape) - amax = reduce_amax(blocked_inputs, axis=-1, keepdims=True).detach() - outputs = fake_tensor_quant( + block_amax = reduce_amax(blocked_inputs, axis=-1, keepdims=True).detach() + if not pass_through_bwd: + backward_amax = block_amax.expand_as(blocked_inputs).reshape(inputs.shape) + ctx.save_for_backward( + inputs[..., :original_last_dim], backward_amax[..., :original_last_dim] + ) + outputs = _tensor_quant( blocked_inputs, - amax, - None, + block_amax, num_bits, - unsigned, - narrow_range, - trt_high_precision_dtype, - pass_through_bwd, + unsigned=False, + narrow_range=True, ).reshape(inputs.shape) return outputs[..., :original_last_dim] @@ -585,12 +597,12 @@ def forward( pass_through_bwd=True, ): """Forward method.""" - _save_for_backward_if_needed(ctx, pass_through_bwd, inputs, amax) return _dynamic_block_quantize_forward( ctx, inputs, block_size, amax, + bias, num_bits, scale_bits, trt_high_precision_dtype,