From 71257f9ffa3bd55bb972792afbd2a19d4d6880bc Mon Sep 17 00:00:00 2001 From: realAsma Date: Tue, 23 Jun 2026 22:05:44 +0000 Subject: [PATCH 1/3] Fix AutoQuantize causal LM score scaling Signed-off-by: realAsma --- examples/llm_autodeploy/run_auto_quantize.py | 15 ++++ examples/llm_eval/quantization_utils.py | 15 ++++ examples/llm_ptq/hf_ptq.py | 24 ++++-- .../quantization/plugins/test_huggingface.py | 76 +++++++++++++++++++ 4 files changed, 123 insertions(+), 7 deletions(-) diff --git a/examples/llm_autodeploy/run_auto_quantize.py b/examples/llm_autodeploy/run_auto_quantize.py index db35e4841fb..86953beb2f8 100644 --- a/examples/llm_autodeploy/run_auto_quantize.py +++ b/examples/llm_autodeploy/run_auto_quantize.py @@ -31,6 +31,17 @@ } +def _causal_lm_sum_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """Sum-reduced next-token loss keeps AutoQuantize scores additive across batches.""" + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + return torch.nn.functional.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1), + reduction="sum", + ) + + def update_weight_quantizer_amax_for_fusion(model: torch.nn.Module): """Group modules that take the same input and set amax to enable gemm fusion.""" input_to_linear = defaultdict(list) @@ -80,6 +91,10 @@ def auto_quantize( def loss_func(output, data): # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast` # which contains the loss attribute. + logits = getattr(output, "logits", None) + labels = data.get("labels") if isinstance(data, dict) else None + if logits is not None and labels is not None: + return _causal_lm_sum_loss(logits, labels) return output.loss model, _ = mtq.auto_quantize( diff --git a/examples/llm_eval/quantization_utils.py b/examples/llm_eval/quantization_utils.py index 80117bd1627..85f9852ef4b 100644 --- a/examples/llm_eval/quantization_utils.py +++ b/examples/llm_eval/quantization_utils.py @@ -52,6 +52,17 @@ } +def _causal_lm_sum_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """Sum-reduced next-token loss keeps AutoQuantize scores additive across batches.""" + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + return torch.nn.functional.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1), + reduction="sum", + ) + + def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, trust_remote_code=False): """Returns the tokenizer from the model ckpt_path.""" print(f"Initializing tokenizer from {ckpt_path}") @@ -101,6 +112,10 @@ def forward_step(model, batch): def loss_func(output, data): # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast` # which contains the loss attribute. + logits = getattr(output, "logits", None) + labels = data.get("labels") if isinstance(data, dict) else None + if logits is not None and labels is not None: + return _causal_lm_sum_loss(logits, labels) return output.loss elif auto_quantize_method == "kl_div": # For KL divergence method, return only logits diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 6d27aa593f6..5a8052442d8 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -23,7 +23,6 @@ import numpy as np import torch -from accelerate.hooks import remove_hook_from_module from cast_mxfp4_to_nvfp4 import apply_to_model as apply_cast_mxfp4_to_nvfp4 from cast_mxfp4_to_nvfp4 import force_weight_quantizers_static from example_utils import ( @@ -55,6 +54,7 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq import modelopt.torch.sparsity as mts +from accelerate.hooks import remove_hook_from_module from modelopt.recipe import ModelOptPTQRecipe, load_recipe from modelopt.torch.export import ( export_hf_checkpoint, @@ -288,6 +288,17 @@ def make_calib_dataloader( return calib_dataloader, first_text_speech_dataset +def _causal_lm_sum_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """Sum-reduced next-token loss keeps AutoQuantize scores additive across batches.""" + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + return torch.nn.functional.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1), + reduction="sum", + ) + + def auto_quantize( args: argparse.Namespace, language_model: torch.nn.Module, @@ -349,16 +360,15 @@ def auto_quantize( def loss_func(output, data): logits = lm_head(output.last_hidden_state) - labels = data["labels"] - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - return torch.nn.functional.cross_entropy( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) - ) + return _causal_lm_sum_loss(logits, data["labels"]) else: def loss_func(output, data): + logits = getattr(output, "logits", None) + labels = data.get("labels") if isinstance(data, dict) else None + if logits is not None and labels is not None: + return _causal_lm_sum_loss(logits, labels) return output.loss if auto_quantize_method == "gradient": diff --git a/tests/unit/torch/quantization/plugins/test_huggingface.py b/tests/unit/torch/quantization/plugins/test_huggingface.py index ae638c42ee2..2c07f4470c6 100644 --- a/tests/unit/torch/quantization/plugins/test_huggingface.py +++ b/tests/unit/torch/quantization/plugins/test_huggingface.py @@ -24,6 +24,7 @@ create_tiny_llama_dir, get_tiny_gpt_oss, get_tiny_llama, + get_tiny_qwen3, get_tiny_qwen3_moe, tf_modelopt_state_and_output_tester, ) @@ -193,6 +194,81 @@ def forward_step(model, batch): ) +def _causal_lm_sum_loss(logits, labels): + logits = logits.float() + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + return torch.nn.functional.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1), + ignore_index=-100, + reduction="sum", + ) + + +def _qwen3_calib_sample(length): + input_ids = (torch.arange(1, length + 1, dtype=torch.long) % 31).unsqueeze(0) + return { + "input_ids": input_ids, + "attention_mask": torch.ones_like(input_ids), + "labels": input_ids.clone(), + } + + +def _qwen3_padded_calib_pair(): + sample_short = torch.arange(1, 17, dtype=torch.long) % 31 + sample_long = torch.arange(1, 33, dtype=torch.long) % 31 + input_ids = torch.zeros((2, 32), dtype=torch.long) + labels = torch.full((2, 32), -100, dtype=torch.long) + attention_mask = torch.zeros((2, 32), dtype=torch.long) + + input_ids[0, :16] = sample_short + labels[0, :16] = sample_short + attention_mask[0, :16] = 1 + input_ids[1] = sample_long + labels[1] = sample_long + attention_mask[1] = 1 + + return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels} + + +def _tiny_qwen3_autoquantize_candidate_stats(data_loader, num_steps): + model = get_tiny_qwen3(dtype=torch.float32, num_hidden_layers=1, max_position_embeddings=64) + + def forward_step(model, batch): + return model(**batch) + + def loss_func(output, batch): + return _causal_lm_sum_loss(output.logits, batch["labels"]) + + _, search_history = mtq.auto_quantize( + model, + constraints={"effective_bits": 11.0}, + quantization_formats=[mtq.INT8_DEFAULT_CFG], + data_loader=data_loader, + forward_step=forward_step, + loss_func=loss_func, + num_calib_steps=num_steps, + num_score_steps=num_steps, + verbose=False, + method="gradient", + ) + return search_history["candidate_stats"] + + +def test_autoquantize_huggingface_scores_are_batch_size_invariant_with_padding(): + stats_bs1 = _tiny_qwen3_autoquantize_candidate_stats( + [_qwen3_calib_sample(16), _qwen3_calib_sample(32)], num_steps=2 + ) + stats_bs2 = _tiny_qwen3_autoquantize_candidate_stats([_qwen3_padded_calib_pair()], num_steps=1) + + assert stats_bs1.keys() == stats_bs2.keys() + for name in stats_bs1: + assert stats_bs1[name]["scores"] == pytest.approx( + stats_bs2[name]["scores"], rel=1e-5, abs=1e-9 + ) + + @pytest.mark.parametrize( ("model_cls", "quant_config"), [ From ae34e303c4daffa598d416c4b6f3da76be488526 Mon Sep 17 00:00:00 2001 From: realAsma Date: Tue, 23 Jun 2026 22:37:06 +0000 Subject: [PATCH 2/3] Refine AutoQuantize batch-size regression test Signed-off-by: realAsma --- .../quantization/plugins/test_huggingface.py | 51 ++++++++++++------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/tests/unit/torch/quantization/plugins/test_huggingface.py b/tests/unit/torch/quantization/plugins/test_huggingface.py index 2c07f4470c6..e1674ea0c3e 100644 --- a/tests/unit/torch/quantization/plugins/test_huggingface.py +++ b/tests/unit/torch/quantization/plugins/test_huggingface.py @@ -206,8 +206,12 @@ def _causal_lm_sum_loss(logits, labels): ) -def _qwen3_calib_sample(length): - input_ids = (torch.arange(1, length + 1, dtype=torch.long) % 31).unsqueeze(0) +def _qwen3_sample_ids(length, offset=0): + return (torch.arange(length, dtype=torch.long) + offset).remainder(31) + 1 + + +def _qwen3_calib_sample(length, offset=0): + input_ids = _qwen3_sample_ids(length, offset).unsqueeze(0) return { "input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), @@ -216,8 +220,8 @@ def _qwen3_calib_sample(length): def _qwen3_padded_calib_pair(): - sample_short = torch.arange(1, 17, dtype=torch.long) % 31 - sample_long = torch.arange(1, 33, dtype=torch.long) % 31 + sample_short = _qwen3_sample_ids(16) + sample_long = _qwen3_sample_ids(32, offset=7) input_ids = torch.zeros((2, 32), dtype=torch.long) labels = torch.full((2, 32), -100, dtype=torch.long) attention_mask = torch.zeros((2, 32), dtype=torch.long) @@ -241,31 +245,40 @@ def forward_step(model, batch): def loss_func(output, batch): return _causal_lm_sum_loss(output.logits, batch["labels"]) - _, search_history = mtq.auto_quantize( - model, - constraints={"effective_bits": 11.0}, - quantization_formats=[mtq.INT8_DEFAULT_CFG], - data_loader=data_loader, - forward_step=forward_step, - loss_func=loss_func, - num_calib_steps=num_steps, - num_score_steps=num_steps, - verbose=False, - method="gradient", - ) + with pytest.warns( + UserWarning, + match="AutoQuantize: Huggingface model detected - Enabling gradient checkpointing. ", + ): + _, search_history = mtq.auto_quantize( + model, + constraints={"effective_bits": 11.0}, + quantization_formats=[mtq.INT8_DEFAULT_CFG], + data_loader=data_loader, + forward_step=forward_step, + loss_func=loss_func, + num_calib_steps=num_steps, + num_score_steps=num_steps, + verbose=False, + method="gradient", + ) return search_history["candidate_stats"] def test_autoquantize_huggingface_scores_are_batch_size_invariant_with_padding(): stats_bs1 = _tiny_qwen3_autoquantize_candidate_stats( - [_qwen3_calib_sample(16), _qwen3_calib_sample(32)], num_steps=2 + [_qwen3_calib_sample(16), _qwen3_calib_sample(32, offset=7)], num_steps=2 ) stats_bs2 = _tiny_qwen3_autoquantize_candidate_stats([_qwen3_padded_calib_pair()], num_steps=1) assert stats_bs1.keys() == stats_bs2.keys() for name in stats_bs1: - assert stats_bs1[name]["scores"] == pytest.approx( - stats_bs2[name]["scores"], rel=1e-5, abs=1e-9 + assert stats_bs1[name]["formats"] == stats_bs2[name]["formats"] + torch.testing.assert_close( + torch.tensor(stats_bs1[name]["scores"]), + torch.tensor(stats_bs2[name]["scores"]), + rtol=1e-4, + atol=1e-5, + msg=f"Candidate scores differ for {name}", ) From 2a2163e7c2a6d5c04565ea64db3dcd8a55a52fb3 Mon Sep 17 00:00:00 2001 From: realAsma Date: Tue, 23 Jun 2026 23:02:10 +0000 Subject: [PATCH 3/3] Use HF num_items_in_batch for AQ sum loss Signed-off-by: realAsma --- examples/llm_autodeploy/run_auto_quantize.py | 8 +++++++- examples/llm_eval/quantization_utils.py | 5 ++++- examples/llm_ptq/hf_ptq.py | 10 ++++++++-- .../quantization/plugins/test_huggingface.py | 15 ++++++++++++--- 4 files changed, 31 insertions(+), 7 deletions(-) diff --git a/examples/llm_autodeploy/run_auto_quantize.py b/examples/llm_autodeploy/run_auto_quantize.py index 86953beb2f8..15025c32030 100644 --- a/examples/llm_autodeploy/run_auto_quantize.py +++ b/examples/llm_autodeploy/run_auto_quantize.py @@ -91,17 +91,23 @@ def auto_quantize( def loss_func(output, data): # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast` # which contains the loss attribute. + loss = getattr(output, "loss", None) + if loss is not None: + return loss logits = getattr(output, "logits", None) labels = data.get("labels") if isinstance(data, dict) else None if logits is not None and labels is not None: return _causal_lm_sum_loss(logits, labels) return output.loss + def forward_step(model, batch): + return model(**{**batch, "num_items_in_batch": 1}) + model, _ = mtq.auto_quantize( model, constraints={"effective_bits": auto_quantize_bits}, data_loader=calib_dataloader, - forward_step=lambda model, batch: model(**batch), + forward_step=forward_step, loss_func=loss_func, quantization_formats=[SUPPORT_QUANT_FORMAT[quant_format] for quant_format in qformat_list], num_calib_steps=len(calib_dataloader), diff --git a/examples/llm_eval/quantization_utils.py b/examples/llm_eval/quantization_utils.py index 85f9852ef4b..d545d8117f1 100644 --- a/examples/llm_eval/quantization_utils.py +++ b/examples/llm_eval/quantization_utils.py @@ -107,11 +107,14 @@ def _quantize_model_with_dataset( if auto_quantize_method == "gradient": # For gradient-based method, return full output with loss def forward_step(model, batch): - return model(**batch) + return model(**{**batch, "num_items_in_batch": 1}) def loss_func(output, data): # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast` # which contains the loss attribute. + loss = getattr(output, "loss", None) + if loss is not None: + return loss logits = getattr(output, "logits", None) labels = data.get("labels") if isinstance(data, dict) else None if logits is not None and labels is not None: diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 5a8052442d8..bbfaaf553d9 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -23,6 +23,7 @@ import numpy as np import torch +from accelerate.hooks import remove_hook_from_module from cast_mxfp4_to_nvfp4 import apply_to_model as apply_cast_mxfp4_to_nvfp4 from cast_mxfp4_to_nvfp4 import force_weight_quantizers_static from example_utils import ( @@ -54,7 +55,6 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq import modelopt.torch.sparsity as mts -from accelerate.hooks import remove_hook_from_module from modelopt.recipe import ModelOptPTQRecipe, load_recipe from modelopt.torch.export import ( export_hf_checkpoint, @@ -365,6 +365,9 @@ def loss_func(output, data): else: def loss_func(output, data): + loss = getattr(output, "loss", None) + if loss is not None: + return loss logits = getattr(output, "logits", None) labels = data.get("labels") if isinstance(data, dict) else None if logits is not None and labels is not None: @@ -374,7 +377,10 @@ def loss_func(output, data): if auto_quantize_method == "gradient": def forward_step(model, batch): - inputs = {k: v for k, v in batch.items() if k != "labels"} if is_base_model else batch + if is_base_model: + inputs = {k: v for k, v in batch.items() if k != "labels"} + else: + inputs = {**batch, "num_items_in_batch": 1} return model(**inputs) elif auto_quantize_method == "kl_div": diff --git a/tests/unit/torch/quantization/plugins/test_huggingface.py b/tests/unit/torch/quantization/plugins/test_huggingface.py index e1674ea0c3e..b93613e6aec 100644 --- a/tests/unit/torch/quantization/plugins/test_huggingface.py +++ b/tests/unit/torch/quantization/plugins/test_huggingface.py @@ -240,10 +240,10 @@ def _tiny_qwen3_autoquantize_candidate_stats(data_loader, num_steps): model = get_tiny_qwen3(dtype=torch.float32, num_hidden_layers=1, max_position_embeddings=64) def forward_step(model, batch): - return model(**batch) + return model(**{**batch, "num_items_in_batch": 1}) - def loss_func(output, batch): - return _causal_lm_sum_loss(output.logits, batch["labels"]) + def loss_func(output, _batch): + return output.loss with pytest.warns( UserWarning, @@ -264,6 +264,15 @@ def loss_func(output, batch): return search_history["candidate_stats"] +def test_autoquantize_huggingface_num_items_in_batch_uses_sum_loss(): + model = get_tiny_qwen3(dtype=torch.float32, num_hidden_layers=1, max_position_embeddings=64) + batch = _qwen3_padded_calib_pair() + + output = model(**{**batch, "num_items_in_batch": 1}) + + torch.testing.assert_close(output.loss, _causal_lm_sum_loss(output.logits, batch["labels"])) + + def test_autoquantize_huggingface_scores_are_batch_size_invariant_with_padding(): stats_bs1 = _tiny_qwen3_autoquantize_candidate_stats( [_qwen3_calib_sample(16), _qwen3_calib_sample(32, offset=7)], num_steps=2