Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1948,8 +1948,7 @@ def _run_scale_calibration(model, forward_loop, scale_algorithm, caller_name):
algo_kwargs = {k: v for k, v in algo_kwargs.items() if k in accepted}
calib_func(model, forward_loop=forward_loop, **algo_kwargs)

if method == "max":
_convert_to_static_block_quantizers(model)
_convert_to_static_block_quantizers(model)


def _compute_block_scales(quantizer):
Expand Down
51 changes: 40 additions & 11 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,31 @@ def _get_bias(self, inputs):
raise ValueError(f"Unsupported bias type: {self.bias_type}")
return bias

def _fake_dynamic_int_block_quantize(self, inputs, block_size, bias):
"""Fake quantize INT blocks with a per-forward max scale."""
if bias is not None:
raise NotImplementedError("Dynamic INT block quantization does not support bias.")

original_last_dim = inputs.shape[-1]
if original_last_dim % block_size != 0:
pad_width = block_size - original_last_dim % block_size
inputs = F.pad(inputs, (0, pad_width), "constant", 0)

blocked_shape = (*inputs.shape[:-1], -1, block_size)
blocked_inputs = inputs.reshape(blocked_shape)
amax = quant_utils.reduce_amax(blocked_inputs, axis=-1, keepdims=True).detach()
outputs = fake_tensor_quant(
blocked_inputs,
amax,
None,
self._num_bits,
self._unsigned,
self._narrow_range,
self._trt_high_precision_dtype,
self._pass_through_bwd,
).reshape(inputs.shape)
return outputs[..., :original_last_dim]

def _is_real_quantize_support(self):
"""Check if real quantization is supported for this quant config."""
return (
Expand Down Expand Up @@ -817,17 +842,21 @@ def _fake_quantize(self, inputs):
if block_size is None:
raise ValueError("block size for dynamic quantization not found.")

outputs = dynamic_block_quant(
inputs,
block_size,
amax,
self._get_bias(inputs),
self._num_bits,
self.block_sizes.get("scale_bits", None),
getattr(self, "_trt_high_precision_dtype", None),
getattr(self, "_onnx_quantizer_type", None),
self._pass_through_bwd,
)
bias = self._get_bias(inputs)
if isinstance(self._num_bits, int) and self.block_sizes.get("scale_bits") is None:
outputs = self._fake_dynamic_int_block_quantize(inputs, block_size, bias)
Comment thread
realAsma marked this conversation as resolved.
Outdated
else:
outputs = dynamic_block_quant(
inputs,
block_size,
amax,
bias,
self._num_bits,
self.block_sizes.get("scale_bits", None),
getattr(self, "_trt_high_precision_dtype", None),
getattr(self, "_onnx_quantizer_type", None),
self._pass_through_bwd,
)
elif isinstance(self._num_bits, tuple):
# Float-point quantization, e.g., FP8
E, M = self._num_bits # noqa: N806
Expand Down
120 changes: 119 additions & 1 deletion tests/unit/torch/quantization/test_laq.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""CPU unit tests for the LAQ algorithm using INT4 quantization."""
"""CPU unit tests for the LAQ algorithm using INT block quantization."""

import pytest
import torch
from _test_utils.torch.quantization.models import SimpleLinear
from torch import nn

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.config import LAQConfig
from modelopt.torch.quantization.nn.modules.tensor_quantizer import (
StaticBlockScaleQuantizer,
Expand Down Expand Up @@ -203,3 +205,119 @@ def test_tied_shares_tensor(self):
out = q._fake_quantize(x)
out.sum().backward()
assert q._amax_post.grad is not None


class TestQuantizeLAQIntBlock:
"""End-to-end LAQ tests for low-bit INT static block weight quantization."""

@staticmethod
def _make_config(num_bits, scale_method, learnable_amax=("post",), tied_amax=False):
return {
"quant_cfg": [
{"quantizer_name": "*", "enable": False},
{
"quantizer_name": "*weight_quantizer",
"enable": True,
"cfg": {
"num_bits": num_bits,
"block_sizes": {-1: 16, "type": "static"},
},
},
],
"algorithm": {
"method": "laq",
"learnable_amax": list(learnable_amax),
"tied_amax": tied_amax,
"scale_algorithm": {"method": scale_method},
},
}

@staticmethod
def _make_dynamic_max_config(num_bits):
return {
"quant_cfg": [
{"quantizer_name": "*", "enable": False},
{
"quantizer_name": "*weight_quantizer",
"enable": True,
"cfg": {
"num_bits": num_bits,
"block_sizes": {-1: 16, "type": "dynamic"},
},
},
],
"algorithm": "max",
}

@staticmethod
def _forward_loop(model):
model(SimpleLinear.get_input())

@staticmethod
def _weight_quantizers(model):
return [
module.weight_quantizer
for module in model.modules()
if hasattr(module, "weight_quantizer")
]

@pytest.mark.parametrize("num_bits", [3, 2])
@pytest.mark.parametrize("scale_method", ["max", "mse"])
def test_low_bit_initializers_enable_laq(self, num_bits, scale_method):
model = mtq.quantize(
SimpleLinear(),
self._make_config(num_bits, scale_method, learnable_amax=["pre", "post"]),
self._forward_loop,
)
weight_quantizers = self._weight_quantizers(model)

assert weight_quantizers
for quantizer in weight_quantizers:
assert isinstance(quantizer, StaticBlockScaleQuantizer)
assert quantizer._laq is True
assert quantizer.num_bits == num_bits
assert quantizer.block_sizes[-1] == 16

output = model(SimpleLinear.get_input())
assert torch.isfinite(output).all()

@pytest.mark.parametrize("num_bits", [3, 2])
@pytest.mark.parametrize(
("learnable_amax", "tied_amax", "expected_learnable_count"),
[
(["pre", "post"], True, 1),
(["pre", "post"], False, 2),
([], False, 0),
],
)
def test_low_bit_laq_variants(
self, num_bits, learnable_amax, tied_amax, expected_learnable_count
):
model = mtq.quantize(
SimpleLinear(),
self._make_config(num_bits, "max", learnable_amax, tied_amax),
self._forward_loop,
)
weight_quantizers = self._weight_quantizers(model)

assert weight_quantizers
for quantizer in weight_quantizers:
learnable_params = [
value
for value in (quantizer.amax_pre, quantizer.amax_post)
if isinstance(value, nn.Parameter)
]
assert len(set(map(id, learnable_params))) == expected_learnable_count
assert quantizer._tied_amax is tied_amax
assert set(quantizer._learnable_amax) == set(learnable_amax)

@pytest.mark.parametrize("num_bits", [3, 2])
def test_low_bit_dynamic_max_weight_only_forward(self, num_bits):
model = mtq.quantize(
SimpleLinear(),
self._make_dynamic_max_config(num_bits),
self._forward_loop,
)
output = model(SimpleLinear.get_input())

assert torch.isfinite(output).all()