Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions modelopt/torch/export/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,77 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
return get_scaling_factor(weight_quantizer)


def _to_export_tensor(tensor: torch.Tensor) -> torch.Tensor:
if hasattr(tensor, "to_local"):
tensor = tensor.to_local()
return tensor.detach().float()


def _laq_amax_to_scale(
    amax: torch.Tensor, max_bound: float, min_value: float | torch.Tensor
) -> torch.Tensor:
    """Convert an LAQ amax tensor into a scale that is floored at ``min_value``.

    Args:
        amax: Amax tensor; it is first passed through ``_to_export_tensor``.
        max_bound: Divisor applied to the converted amax.
        min_value: Lower bound for the scale. Either a Python float or a tensor
            that broadcasts against the scale.

    Returns:
        ``amax / max_bound`` where every entry that is less than or equal to
        ``min_value`` has been replaced by ``min_value``.
    """
    raw_scale = _to_export_tensor(amax) / max_bound
    at_or_below_floor = raw_scale <= min_value
    return torch.where(at_or_below_floor, min_value, raw_scale)


def _laq_scale_factors(
    scale: torch.Tensor,
    quantizer: TensorQuantizer,
    quantize_scale: bool,
    expected_shape: torch.Size,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Split a per-block scale into a block-level factor and a second-level factor.

    Args:
        scale: Per-block scale; it is viewed as ``expected_shape``.
        quantizer: Quantizer whose ``_per_tensor_scale`` is read when
            ``quantize_scale`` is true.
        quantize_scale: Whether the block scale is stored as FP8 (E4M3).
        expected_shape: Shape the returned block scale must have.

    Returns:
        ``(block_scale, scale_2)``. Without scale quantization this is the
        reshaped input and a scalar ``1.0`` tensor on the same device. With
        scale quantization the block scale is normalized by the per-tensor
        scale, clamped, and cast to ``torch.float8_e4m3fn``; ``scale_2`` is the
        per-tensor scale divided by 448.
    """
    block_scale = scale.view(expected_shape)

    if quantize_scale:
        fp8_max = 448.0  # largest finite float8_e4m3fn value
        fp8_min = 2**-9  # smallest positive float8_e4m3fn value
        per_tensor_amax = _to_export_tensor(quantizer._per_tensor_scale)
        second_level_scale = per_tensor_amax / fp8_max
        normalized = block_scale * fp8_max / per_tensor_amax.view(-1)
        # Clamp into the representable FP8 range before the cast.
        fp8_block_scale = normalized.clamp(min=fp8_min, max=fp8_max)
        return fp8_block_scale.to(torch.float8_e4m3fn), second_level_scale

    return block_scale, torch.tensor(1.0, device=block_scale.device)


def get_laq_weight_scaling_factors(
    weight_quantizer: TensorQuantizer, weight: torch.Tensor, block_size: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return LAQ pre/post scale factors for NVFP4 export.

    The pre factors are used only to pack the FP4 weight codes. The post factors
    are serialized as the exported dequantization scales.

    Args:
        weight_quantizer: LAQ-enabled quantizer providing ``amax_pre``,
            ``amax_post``, ``_quant_max_bound`` and, when scales are quantized,
            ``_per_tensor_scale``.
        weight: Weight tensor; only its shape is used here.
        block_size: Block length along the last weight dimension.

    Returns:
        ``(pre_scale, pre_scale_2, post_scale, post_scale_2)`` where each block
        scale has shape ``(*weight.shape[:-1], weight.shape[-1] // block_size)``.

    Raises:
        AssertionError: If the quantizer is not in LAQ mode or the last weight
            dimension is not divisible by ``block_size``.
    """
    assert getattr(weight_quantizer, "_laq", False), "Expected an LAQ quantizer."
    assert weight.shape[-1] % block_size == 0, (
        "Weight shape is not divisible for block size for block quantization."
    )

    blocks_per_row = weight.shape[-1] // block_size
    expected_shape = torch.Size((*weight.shape[:-1], blocks_per_row))

    quantize_scales = getattr(weight_quantizer, "_quantize_scales", False)
    # The pre scale is quantized only when scale quantization is enabled at all
    # and the quantizer has not opted out for the pre scale.
    quantize_pre = quantize_scales and getattr(weight_quantizer, "_quantize_pre_scale", True)

    per_tensor_scale = None
    if quantize_scales:
        per_tensor_scale = _to_export_tensor(weight_quantizer._per_tensor_scale)
    max_bound = float(weight_quantizer._quant_max_bound)

    # Scale floors: a quantized scale is floored relative to the per-tensor
    # scale, an unquantized scale only by a tiny epsilon.
    unquantized_floor = 1e-8
    if per_tensor_scale is None:
        post_min = unquantized_floor
        pre_min = unquantized_floor
    else:
        quantized_floor = 0.002 * per_tensor_scale.view(-1)
        post_min = quantized_floor
        pre_min = quantized_floor if quantize_pre else unquantized_floor

    post_scale = _laq_amax_to_scale(weight_quantizer.amax_post, max_bound, post_min)
    pre_scale = _laq_amax_to_scale(weight_quantizer.amax_pre, max_bound, pre_min)

    pre_scale, pre_scale_2 = _laq_scale_factors(
        pre_scale, weight_quantizer, quantize_pre, expected_shape
    )
    post_scale, post_scale_2 = _laq_scale_factors(
        post_scale, weight_quantizer, quantize_scales, expected_shape
    )
    return pre_scale, pre_scale_2, post_scale, post_scale_2


def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight") -> torch.Tensor:
"""Returns the secondary weight scaling factor."""
weight_quantizer = getattr(module, quantizer_attr_names(weight_name).weight_quantizer, None)
Expand Down
43 changes: 34 additions & 9 deletions modelopt/torch/export/unified_export_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
fuse_prequant_layernorm,
fuse_prequant_to_linear,
get_activation_scaling_factor,
get_laq_weight_scaling_factors,
get_quant_config,
get_quantization_format,
get_weight_block_size,
Expand Down Expand Up @@ -537,6 +538,12 @@ def _export_quantized_weight(
output_quantizer: TensorQuantizer | SequentialQuantizer | None = getattr(
sub_module, quantizer_attrs.output_quantizer, None
)
is_laq_nvfp4 = getattr(weight_quantizer, "_laq", False) and quantization_format in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_W4A16_NVFP4,
]

if quantization_format == QUANTIZATION_FP8:
# Convert amax to float32
Expand Down Expand Up @@ -587,6 +594,8 @@ def _export_quantized_weight(
sub_module.register_buffer(quantizer_attrs.weight_scale, e8m0_scale)
if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None:
del weight_quantizer._scale
elif is_laq_nvfp4:
pass
else:
sub_module.register_buffer(
quantizer_attrs.weight_scale, get_weight_scaling_factor(sub_module, weight_name)
Expand All @@ -604,14 +613,18 @@ def _export_quantized_weight(
).squeeze(),
)

if quantization_format in [
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_NVFP4,
QUANTIZATION_W4A16_NVFP4,
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
]:
if (
quantization_format
in [
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_NVFP4,
QUANTIZATION_W4A16_NVFP4,
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
]
and not is_laq_nvfp4
):
# Register weight_scale_2
sub_module.register_buffer(
quantizer_attrs.weight_scale_2,
Expand Down Expand Up @@ -646,7 +659,11 @@ def _export_quantized_weight(
weight, is_bmm_expert_weight=is_bmm_expert_weight
)

if NVFP4QTensor._is_static_quantizer(weight_quantizer):
if is_laq_nvfp4:
weight_scale, weight_scale_2, export_weight_scale, export_weight_scale_2 = (
get_laq_weight_scaling_factors(weight_quantizer, weight, block_size)
)
elif NVFP4QTensor._is_static_quantizer(weight_quantizer):
weight_scale = NVFP4QTensor.get_weights_scaling_factor_from_quantizer(
weight_quantizer,
weight,
Expand All @@ -666,10 +683,18 @@ def _export_quantized_weight(
weight_scale_2,
block_size,
)
if is_laq_nvfp4:
weight_scale = export_weight_scale
weight_scale_2 = export_weight_scale_2

quantized_weight, weight_scale = maybe_transpose_expert_weight_dimensions(
quantized_weight, weight_scale, is_bmm_expert_weight=is_bmm_expert_weight
)
if is_laq_nvfp4:
assert weight_scale is not None
assert weight_scale_2 is not None
sub_module.register_buffer(quantizer_attrs.weight_scale, weight_scale)
sub_module.register_buffer(quantizer_attrs.weight_scale_2, weight_scale_2.squeeze())
elif quantization_format == QUANTIZATION_FP8_PC_PT and is_bmm_expert_weight:
# For FP8_PC_PT with BMM-style experts, transpose only the weight (not weight_scale)
weight, _ = maybe_transpose_expert_weight_dimensions(
Expand Down
12 changes: 12 additions & 0 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,9 @@ class LAQConfig(QuantizeAlgorithmConfig):
``tied_amax`` makes pre and post share a single tensor (requires both to
have the same learnable state, i.e. ``learnable_amax`` must be
``["pre", "post"]`` or ``[]``).

``quantize_pre_scale=False`` leaves the pre-quantization scale unquantized
while preserving the existing post-scale quantization behavior.
"""

method: Literal["laq"] = ModeloptField("laq")
Expand All @@ -1035,6 +1038,15 @@ class LAQConfig(QuantizeAlgorithmConfig):
),
)

quantize_pre_scale: bool = ModeloptField(
default=True,
title="FP8-quantize the LAQ pre-quantization scale.",
description=(
"If False, LAQ uses the raw pre-quantization scale while keeping post-scale "
"quantization controlled by the quantizer's block-scale settings."
),
)

scale_algorithm: dict | None = ModeloptField(
default=None,
title="Scale calibration algorithm to run first.",
Expand Down
3 changes: 3 additions & 0 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2016,6 +2016,7 @@ def laq(
scale_algorithm: dict | None = None,
learnable_amax: list | str = ("post",),
tied_amax: bool = False,
quantize_pre_scale: bool = True,
**kwargs,
):
"""Run scale calibration then convert to LAQ mode.
Expand All @@ -2032,6 +2033,7 @@ def laq(
learnable_amax: Which amax params are learnable: 'pre', 'post',
['pre', 'post'], or [].
tied_amax: If True, pre and post share a single tensor.
quantize_pre_scale: If False, skip FP8 quantization for the LAQ pre scale.
"""
_run_scale_calibration(model, forward_loop, scale_algorithm, "laq")

Expand All @@ -2047,4 +2049,5 @@ def laq(
quantize_scales,
learnable_amax=learnable_amax,
tied_amax=tied_amax,
quantize_pre_scale=quantize_pre_scale,
)
27 changes: 19 additions & 8 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1393,6 +1393,7 @@ class StaticBlockScaleQuantizer(TensorQuantizer):
_tied_amax: bool = False
_quant_max_bound: float = 6.0
_quantize_scales: bool = True
_quantize_pre_scale: bool = True

def _preserve_amax_in_fp32(self):
for name in ("_amax", "_global_amax", "_per_tensor_scale"):
Expand Down Expand Up @@ -1536,6 +1537,7 @@ def enable_laq(
quantize_scales: bool = True,
learnable_amax: list | str = ("post",),
tied_amax: bool = False,
quantize_pre_scale: bool = True,
):
"""LAQ mode with configurable learnable/frozen amax tensors.

Expand All @@ -1546,6 +1548,7 @@ def enable_laq(
learnable_amax: Which amax params are learnable: 'pre', 'post',
['pre', 'post'], or [].
tied_amax: If True, pre and post share a single tensor.
quantize_pre_scale: Whether to FP8-quantize the LAQ pre scale.
"""
if hasattr(self, "_amax"):
delattr(self, "_amax")
Expand All @@ -1572,6 +1575,7 @@ def enable_laq(
if per_tensor_scale is not None:
self.register_buffer("_per_tensor_scale", per_tensor_scale.float().clone().detach())
self._quantize_scales = quantize_scales
self._quantize_pre_scale = quantize_pre_scale
self._laq = True
self._learnable_amax = sorted(learn)
self._tied_amax = tied_amax
Expand All @@ -1592,22 +1596,29 @@ def _fake_quantize(self, inputs):
"""Fake quantization using two-level scaling with _amax and _global_amax."""
if self._laq:
# 0.002 ≈ smallest positive FP8 E4M3 value; clamps per-block scale floor
_scale_min = 0.002 * self._per_tensor_scale.view(-1) if self._quantize_scales else 1e-8
scale_min_post = (
0.002 * self._per_tensor_scale.view(-1) if self._quantize_scales else 1e-8
)
scale_min_pre = (
0.002 * self._per_tensor_scale.view(-1)
if self._quantize_scales and self._quantize_pre_scale
else 1e-8
)

scale_post = self._maybe_quantize_scale(
_amax_to_scale(
_to_local(self.amax_post),
self._quant_max_bound,
min_value=_scale_min,
min_value=scale_min_post,
)
)
scale_pre = self._maybe_quantize_scale(
_amax_to_scale(
_to_local(self.amax_pre),
self._quant_max_bound,
min_value=_scale_min,
)
scale_pre = _amax_to_scale(
_to_local(self.amax_pre),
self._quant_max_bound,
min_value=scale_min_pre,
)
if self._quantize_pre_scale:
scale_pre = self._maybe_quantize_scale(scale_pre)
quant_input = inputs.float() / scale_pre.float().view(-1, 1)
w_cast = self._cast_ste(quant_input)
return (w_cast * scale_post.view(-1, 1).to(w_cast.dtype)).to(inputs.dtype)
Expand Down
45 changes: 45 additions & 0 deletions tests/gpu/torch/quantization/test_fsdp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import modelopt.torch.quantization as mtq
from modelopt.torch.opt.dynamic import _pytorch_managed
from modelopt.torch.quantization.nn import StaticBlockScaleQuantizer, TensorQuantizer
from modelopt.torch.quantization.utils import (
enable_weight_access_and_writeback,
persistent_materialization,
Expand Down Expand Up @@ -136,6 +137,50 @@ def test_nested_fsdp2_backward(quant_cfg, dist_workers):
dist_workers.run(partial(_test_nested_fsdp2_backward, quant_cfg=quant_cfg))


class _LAQBf16Linear(nn.Module):
    """Tiny bf16 linear-style module whose weight quantizer runs in LAQ mode.

    Both the pre and post amax are configured as learnable and scale
    quantization is disabled.
    """

    def __init__(self, dim=16):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, dim, dtype=torch.bfloat16))

        # Configure a plain TensorQuantizer by hand, then convert it into a
        # StaticBlockScaleQuantizer.
        base_quantizer = TensorQuantizer()
        quantizer_settings = {
            "_num_bits": 4,
            "_unsigned": False,
            "_narrow_range": True,
            "_disabled": False,
            "_block_sizes": {-1: dim},
            "_pass_through_bwd": True,
        }
        for attr_name, attr_value in quantizer_settings.items():
            setattr(base_quantizer, attr_name, attr_value)
        base_quantizer.register_buffer("_amax", torch.ones(dim, dtype=torch.bfloat16))

        laq_quantizer = StaticBlockScaleQuantizer.from_tensor_quantizer(base_quantizer)
        self.weight_quantizer = laq_quantizer
        laq_quantizer.enable_laq(
            torch.ones(dim, dtype=torch.bfloat16),
            quantize_scales=False,
            learnable_amax=["pre", "post"],
        )

    def forward(self, inputs):
        """Fake-quantize the weight and apply it as a bias-free linear layer."""
        fake_quant_weight = self.weight_quantizer._fake_quantize(self.weight)
        return torch.nn.functional.linear(inputs, fake_quant_weight)


def _test_laq_bf16_learnable_amax_fsdp2(rank, size):
    """Per-rank worker: shard the bf16 LAQ module with FSDP2 and run forward/backward.

    Args:
        rank: Device index used for the model and the input tensor.
        size: Unused here; part of the worker signature.
    """
    torch.manual_seed(1)
    model = _LAQBf16Linear().cuda(rank)
    inputs = torch.randn(2, 16, device=rank, dtype=torch.bfloat16)
    synchronize_state_dict(model)

    # Every parameter must be bf16 before sharding.
    param_dtypes = {param.dtype for param in model.parameters()}
    assert param_dtypes == {torch.bfloat16}

    sharded_model = fully_shard(model)
    loss = sharded_model(inputs).float().sum()
    loss.backward()


def test_laq_bf16_learnable_amax_fsdp2(dist_workers):
    """Run the bf16 LAQ learnable-amax FSDP2 worker through ``dist_workers``."""
    worker_fn = _test_laq_bf16_learnable_amax_fsdp2
    dist_workers.run(worker_fn)


class _DecoderBlock(nn.Module):
"""Minimal decoder block for FSDP2 sequential tests."""

Expand Down
32 changes: 30 additions & 2 deletions tests/gpu/torch/quantization/test_laq_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,26 @@
},
}

# LAQ test config that sets ``quantize_pre_scale`` to False: only the post amax
# is learnable, and the "mse" scale algorithm is requested as the scale
# calibration step.
NVFP4_LAQ_SKIP_PRE_SCALE_MSE_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            # NOTE(review): (2, 1) and (4, 3) presumably denote (exponent, mantissa)
            # bits, i.e. FP4 E2M1 values with FP8 E4M3 block scales -- confirm
            # against the quantizer config documentation.
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        # Input (activation) quantization is turned off for this config.
        "*input_quantizer": {
            "enable": False,
        },
    },
    "algorithm": {
        "method": "laq",
        "learnable_amax": ["post"],
        # The option under test: do not quantize the LAQ pre scale.
        "quantize_pre_scale": False,
        "scale_algorithm": {"method": "mse", "fp8_scale_sweep": True},
    },
}


class SimpleModel(nn.Module):
"""Minimal model for LAQ testing."""
Expand All @@ -102,8 +122,13 @@ def forward_loop(m):

@pytest.mark.parametrize(
"config",
[NVFP4_LAQ_POST_MSE_CFG, NVFP4_LAQ_PRE_POST_MSE_CFG, NVFP4_LAQ_TIED_MSE_CFG],
ids=["post_only", "pre_and_post", "tied"],
[
NVFP4_LAQ_POST_MSE_CFG,
NVFP4_LAQ_PRE_POST_MSE_CFG,
NVFP4_LAQ_TIED_MSE_CFG,
NVFP4_LAQ_SKIP_PRE_SCALE_MSE_CFG,
],
ids=["post_only", "pre_and_post", "tied", "skip_pre_scale"],
)
def test_laq_quantize_e2e(config):
"""End-to-end: quantize a small model with LAQ + NVFP4 on GPU."""
Expand All @@ -112,6 +137,9 @@ def test_laq_quantize_e2e(config):
forward_loop = _make_forward_loop(model, device)

model = mtq.quantize(model, config, forward_loop=forward_loop)
assert model.linear.weight_quantizer._quantize_pre_scale is config["algorithm"].get(
"quantize_pre_scale", True
)

# Verify the model still produces output of the correct shape
x = torch.randn(2, 64, device=device)
Expand Down
Loading