diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 9c39e795c92..99f8ffbf858 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1948,8 +1948,7 @@ def _run_scale_calibration(model, forward_loop, scale_algorithm, caller_name): algo_kwargs = {k: v for k, v in algo_kwargs.items() if k in accepted} calib_func(model, forward_loop=forward_loop, **algo_kwargs) - if method == "max": - _convert_to_static_block_quantizers(model) + _convert_to_static_block_quantizers(model) def _compute_block_scales(quantizer): diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index 17809d50d01..308ffe9ef73 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -18,6 +18,7 @@ import warnings import torch +import torch.nn.functional as F from torch.autograd import Function from torch.onnx import symbolic_helper @@ -25,6 +26,7 @@ from .config import QuantizerAttributeConfig from .extensions import get_cuda_ext, get_cuda_ext_fp8, get_cuda_ext_mx +from .utils import reduce_amax mx_format_map = { (4, 3): "E4M3", @@ -464,6 +466,7 @@ def _dynamic_block_quantize_forward( inputs, block_size, amax, + bias, num_bits, scale_bits, trt_high_precision_dtype=None, @@ -471,6 +474,17 @@ def _dynamic_block_quantize_forward( pass_through_bwd=True, ): """Forward method.""" + if isinstance(num_bits, int) and scale_bits is None: + return _dynamic_int_block_quantize_forward( + ctx, + inputs, + block_size, + bias, + num_bits, + pass_through_bwd, + ) + + _save_for_backward_if_needed(ctx, pass_through_bwd, inputs, amax) if isinstance(num_bits, int): # special case for INT dynamic block quantization, e.g. MXINT8 exponent_bits = 0 @@ -493,6 +507,41 @@ def _dynamic_block_quantize_forward( return outputs +def _dynamic_int_block_quantize_forward( + ctx, + inputs, + block_size, + bias, + num_bits, + pass_through_bwd=True, +): + """Fake quantize INT blocks with a per-forward max scale.""" + if bias is not None: + raise NotImplementedError("Dynamic INT block quantization does not support bias.") + + original_last_dim = inputs.shape[-1] + if original_last_dim % block_size != 0: + pad_width = block_size - original_last_dim % block_size + inputs = F.pad(inputs, (0, pad_width), "constant", 0) + + blocked_shape = (*inputs.shape[:-1], -1, block_size) + blocked_inputs = inputs.reshape(blocked_shape) + block_amax = reduce_amax(blocked_inputs, axis=-1, keepdims=True).detach() + if not pass_through_bwd: + backward_amax = block_amax.expand_as(blocked_inputs).reshape(inputs.shape) + ctx.save_for_backward( + inputs[..., :original_last_dim], backward_amax[..., :original_last_dim] + ) + outputs = _tensor_quant( + blocked_inputs, + block_amax, + num_bits, + unsigned=False, + narrow_range=True, + ).reshape(inputs.shape) + return outputs[..., :original_last_dim] + + class DynamicBlockQuantizationFunction(Function): """Dynamic block quantization functional.""" @@ -548,12 +597,12 @@ def forward( pass_through_bwd=True, ): """Forward method.""" - _save_for_backward_if_needed(ctx, pass_through_bwd, inputs, amax) return _dynamic_block_quantize_forward( ctx, inputs, block_size, amax, + bias, num_bits, scale_bits, trt_high_precision_dtype, diff --git a/tests/unit/torch/quantization/test_laq.py b/tests/unit/torch/quantization/test_laq.py index e1eed2659f2..e3b6933fd2e 100644 --- a/tests/unit/torch/quantization/test_laq.py +++ b/tests/unit/torch/quantization/test_laq.py @@ -13,12 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""CPU unit tests for the LAQ algorithm using INT4 quantization.""" +"""CPU unit tests for the LAQ algorithm using INT block quantization.""" import pytest import torch +from _test_utils.torch.quantization.models import SimpleLinear from torch import nn +import modelopt.torch.quantization as mtq from modelopt.torch.quantization.config import LAQConfig from modelopt.torch.quantization.nn.modules.tensor_quantizer import ( StaticBlockScaleQuantizer, @@ -203,3 +205,119 @@ def test_tied_shares_tensor(self): out = q._fake_quantize(x) out.sum().backward() assert q._amax_post.grad is not None + + +class TestQuantizeLAQIntBlock: + """End-to-end LAQ tests for low-bit INT static block weight quantization.""" + + @staticmethod + def _make_config(num_bits, scale_method, learnable_amax=("post",), tied_amax=False): + return { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + { + "quantizer_name": "*weight_quantizer", + "enable": True, + "cfg": { + "num_bits": num_bits, + "block_sizes": {-1: 16, "type": "static"}, + }, + }, + ], + "algorithm": { + "method": "laq", + "learnable_amax": list(learnable_amax), + "tied_amax": tied_amax, + "scale_algorithm": {"method": scale_method}, + }, + } + + @staticmethod + def _make_dynamic_max_config(num_bits): + return { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + { + "quantizer_name": "*weight_quantizer", + "enable": True, + "cfg": { + "num_bits": num_bits, + "block_sizes": {-1: 16, "type": "dynamic"}, + }, + }, + ], + "algorithm": "max", + } + + @staticmethod + def _forward_loop(model): + model(SimpleLinear.get_input()) + + @staticmethod + def _weight_quantizers(model): + return [ + module.weight_quantizer + for module in model.modules() + if hasattr(module, "weight_quantizer") + ] + + @pytest.mark.parametrize("num_bits", [3, 2]) + @pytest.mark.parametrize("scale_method", ["max", "mse"]) + def test_low_bit_initializers_enable_laq(self, num_bits, scale_method): + model = mtq.quantize( + SimpleLinear(), + self._make_config(num_bits, scale_method, learnable_amax=["pre", "post"]), + self._forward_loop, + ) + weight_quantizers = self._weight_quantizers(model) + + assert weight_quantizers + for quantizer in weight_quantizers: + assert isinstance(quantizer, StaticBlockScaleQuantizer) + assert quantizer._laq is True + assert quantizer.num_bits == num_bits + assert quantizer.block_sizes[-1] == 16 + + output = model(SimpleLinear.get_input()) + assert torch.isfinite(output).all() + + @pytest.mark.parametrize("num_bits", [3, 2]) + @pytest.mark.parametrize( + ("learnable_amax", "tied_amax", "expected_learnable_count"), + [ + (["pre", "post"], True, 1), + (["pre", "post"], False, 2), + ([], False, 0), + ], + ) + def test_low_bit_laq_variants( + self, num_bits, learnable_amax, tied_amax, expected_learnable_count + ): + model = mtq.quantize( + SimpleLinear(), + self._make_config(num_bits, "max", learnable_amax, tied_amax), + self._forward_loop, + ) + weight_quantizers = self._weight_quantizers(model) + + assert weight_quantizers + for quantizer in weight_quantizers: + learnable_params = [ + value + for value in (quantizer.amax_pre, quantizer.amax_post) + if isinstance(value, nn.Parameter) + ] + assert len(set(map(id, learnable_params))) == expected_learnable_count + assert quantizer._tied_amax is tied_amax + assert set(quantizer._learnable_amax) == set(learnable_amax) + + @pytest.mark.parametrize("num_bits", [3, 2]) + def test_low_bit_dynamic_max_weight_only_forward(self, num_bits): + model = mtq.quantize( + SimpleLinear(), + self._make_dynamic_max_config(num_bits), + self._forward_loop, + ) + output = model(SimpleLinear.get_input()) + + assert torch.isfinite(output).all()