diff --git a/examples/llm_autodeploy/run_auto_quantize.py b/examples/llm_autodeploy/run_auto_quantize.py
index db35e4841fb..15025c32030 100644
--- a/examples/llm_autodeploy/run_auto_quantize.py
+++ b/examples/llm_autodeploy/run_auto_quantize.py
@@ -31,6 +31,30 @@
 }
 
 
+def _causal_lm_sum_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    """Return the sum-reduced next-token cross-entropy loss.
+
+    A sum reduction keeps AutoQuantize scores additive across batches.
+
+    Args:
+        logits: Model logits; the last two dims are (sequence, vocab).
+        labels: Target token ids aligned with ``logits``; ``-100`` entries are ignored.
+
+    Returns:
+        Scalar float32 loss summed over all non-ignored shifted positions.
+    """
+    # Upcast first: a half-precision sum over many tokens can lose precision or overflow.
+    shift_logits = logits[..., :-1, :].float().contiguous()
+    # Keep the targets on the logits' device in case the batch lives elsewhere.
+    shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
+    return torch.nn.functional.cross_entropy(
+        shift_logits.view(-1, shift_logits.size(-1)),
+        shift_labels.view(-1),
+        ignore_index=-100,
+        reduction="sum",
+    )
+
+
 def update_weight_quantizer_amax_for_fusion(model: torch.nn.Module):
     """Group modules that take the same input and set amax to enable gemm fusion."""
     input_to_linear = defaultdict(list)
@@ -80,13 +104,28 @@ def auto_quantize(
     def loss_func(output, data):
         # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
         # which contains the loss attribute.
-        return output.loss
+        loss = getattr(output, "loss", None)
+        if loss is not None:
+            return loss
+        # The model reported no loss: rebuild the sum-reduced loss from logits and labels.
+        logits = getattr(output, "logits", None)
+        labels = data.get("labels") if isinstance(data, dict) else None
+        if logits is not None and labels is not None:
+            return _causal_lm_sum_loss(logits, labels)
+        raise ValueError(
+            "AutoQuantize needs a loss: output has no `loss`, and `logits` plus "
+            "batch `labels` are not both available to compute one."
+        )
 
+    def forward_step(model, batch):
+        # `num_items_in_batch=1` makes the HF loss a token sum instead of a mean.
+        return model(**{**batch, "num_items_in_batch": 1})
+
     model, _ = mtq.auto_quantize(
         model,
         constraints={"effective_bits": auto_quantize_bits},
         data_loader=calib_dataloader,
-        forward_step=lambda model, batch: model(**batch),
+        forward_step=forward_step,
         loss_func=loss_func,
         quantization_formats=[SUPPORT_QUANT_FORMAT[quant_format] for quant_format in qformat_list],
         num_calib_steps=len(calib_dataloader),
diff --git a/examples/llm_eval/quantization_utils.py b/examples/llm_eval/quantization_utils.py
index 80117bd1627..d545d8117f1 100644
--- a/examples/llm_eval/quantization_utils.py
+++ b/examples/llm_eval/quantization_utils.py
@@ -52,6 +52,30 @@
 }
 
 
+def _causal_lm_sum_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    """Return the sum-reduced next-token cross-entropy loss.
+
+    A sum reduction keeps AutoQuantize scores additive across batches.
+
+    Args:
+        logits: Model logits; the last two dims are (sequence, vocab).
+        labels: Target token ids aligned with ``logits``; ``-100`` entries are ignored.
+
+    Returns:
+        Scalar float32 loss summed over all non-ignored shifted positions.
+    """
+    # Upcast first: a half-precision sum over many tokens can lose precision or overflow.
+    shift_logits = logits[..., :-1, :].float().contiguous()
+    # Keep the targets on the logits' device in case the batch lives elsewhere.
+    shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
+    return torch.nn.functional.cross_entropy(
+        shift_logits.view(-1, shift_logits.size(-1)),
+        shift_labels.view(-1),
+        ignore_index=-100,
+        reduction="sum",
+    )
+
+
 def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, trust_remote_code=False):
     """Returns the tokenizer from the model ckpt_path."""
     print(f"Initializing tokenizer from {ckpt_path}")
@@ -96,11 +120,23 @@ def _quantize_model_with_dataset(
         if auto_quantize_method == "gradient":
             # For gradient-based method, return full output with loss
             def forward_step(model, batch):
-                return model(**batch)
+                # `num_items_in_batch=1` makes the HF loss a token sum instead of a mean.
+                return model(**{**batch, "num_items_in_batch": 1})
 
             def loss_func(output, data):
                 # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
                 # which contains the loss attribute.
-                return output.loss
+                loss = getattr(output, "loss", None)
+                if loss is not None:
+                    return loss
+                # The model reported no loss: rebuild the sum-reduced loss from logits and labels.
+                logits = getattr(output, "logits", None)
+                labels = data.get("labels") if isinstance(data, dict) else None
+                if logits is not None and labels is not None:
+                    return _causal_lm_sum_loss(logits, labels)
+                raise ValueError(
+                    "AutoQuantize needs a loss: output has no `loss`, and `logits` plus "
+                    "batch `labels` are not both available to compute one."
+                )
         elif auto_quantize_method == "kl_div":
             # For KL divergence method, return only logits
diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index 6d27aa593f6..bbfaaf553d9 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -288,6 +288,30 @@ def make_calib_dataloader(
     return calib_dataloader, first_text_speech_dataset
 
 
+def _causal_lm_sum_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    """Return the sum-reduced next-token cross-entropy loss.
+
+    A sum reduction keeps AutoQuantize scores additive across batches.
+
+    Args:
+        logits: Model logits; the last two dims are (sequence, vocab).
+        labels: Target token ids aligned with ``logits``; ``-100`` entries are ignored.
+
+    Returns:
+        Scalar float32 loss summed over all non-ignored shifted positions.
+    """
+    # Upcast first: a half-precision sum over many tokens can lose precision or overflow.
+    shift_logits = logits[..., :-1, :].float().contiguous()
+    # Keep the targets on the logits' device in case the batch lives elsewhere.
+    shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
+    return torch.nn.functional.cross_entropy(
+        shift_logits.view(-1, shift_logits.size(-1)),
+        shift_labels.view(-1),
+        ignore_index=-100,
+        reduction="sum",
+    )
+
+
 def auto_quantize(
     args: argparse.Namespace,
     language_model: torch.nn.Module,
@@ -349,22 +373,32 @@ def auto_quantize(
 
         def loss_func(output, data):
             logits = lm_head(output.last_hidden_state)
-            labels = data["labels"]
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            return torch.nn.functional.cross_entropy(
-                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
-            )
+            return _causal_lm_sum_loss(logits, data["labels"])
 
     else:
 
         def loss_func(output, data):
-            return output.loss
+            loss = getattr(output, "loss", None)
+            if loss is not None:
+                return loss
+            # The model reported no loss: rebuild the sum-reduced loss from logits and labels.
+            logits = getattr(output, "logits", None)
+            labels = data.get("labels") if isinstance(data, dict) else None
+            if logits is not None and labels is not None:
+                return _causal_lm_sum_loss(logits, labels)
+            raise ValueError(
+                "AutoQuantize needs a loss: output has no `loss`, and `logits` plus "
+                "batch `labels` are not both available to compute one."
+            )
 
     if auto_quantize_method == "gradient":
 
         def forward_step(model, batch):
-            inputs = {k: v for k, v in batch.items() if k != "labels"} if is_base_model else batch
+            if is_base_model:
+                inputs = {k: v for k, v in batch.items() if k != "labels"}
+            else:
+                # `num_items_in_batch=1` makes the HF loss a token sum instead of a mean.
+                inputs = {**batch, "num_items_in_batch": 1}
             return model(**inputs)
 
     elif auto_quantize_method == "kl_div":
diff --git a/tests/unit/torch/quantization/plugins/test_huggingface.py b/tests/unit/torch/quantization/plugins/test_huggingface.py
index ae638c42ee2..b93613e6aec 100644
--- a/tests/unit/torch/quantization/plugins/test_huggingface.py
+++ b/tests/unit/torch/quantization/plugins/test_huggingface.py
@@ -24,6 +24,7 @@
     create_tiny_llama_dir,
     get_tiny_gpt_oss,
     get_tiny_llama,
+    get_tiny_qwen3,
     get_tiny_qwen3_moe,
     tf_modelopt_state_and_output_tester,
 )
@@ -193,6 +194,113 @@ def forward_step(model, batch):
     )
 
 
+def _causal_lm_sum_loss(logits, labels):
+    """Reference sum-reduced next-token cross-entropy, computed in float32."""
+    logits = logits.float()
+    shift_logits = logits[..., :-1, :].contiguous()
+    shift_labels = labels[..., 1:].contiguous()
+    return torch.nn.functional.cross_entropy(
+        shift_logits.view(-1, shift_logits.size(-1)),
+        shift_labels.view(-1),
+        ignore_index=-100,
+        reduction="sum",
+    )
+
+
+def _qwen3_sample_ids(length, offset=0):
+    """Return ``length`` deterministic token ids in ``[1, 31]``."""
+    return (torch.arange(length, dtype=torch.long) + offset).remainder(31) + 1
+
+
+def _qwen3_calib_sample(length, offset=0):
+    """Build one unpadded batch-size-1 sample whose labels equal its input ids."""
+    input_ids = _qwen3_sample_ids(length, offset).unsqueeze(0)
+    return {
+        "input_ids": input_ids,
+        "attention_mask": torch.ones_like(input_ids),
+        "labels": input_ids.clone(),
+    }
+
+
+def _qwen3_padded_calib_pair():
+    """Batch the 16- and 32-token samples together, right-padding the short one.
+
+    Padded positions get input id 0, attention mask 0 and label -100.
+    """
+    sample_short = _qwen3_sample_ids(16)
+    sample_long = _qwen3_sample_ids(32, offset=7)
+    input_ids = torch.zeros((2, 32), dtype=torch.long)
+    labels = torch.full((2, 32), -100, dtype=torch.long)
+    attention_mask = torch.zeros((2, 32), dtype=torch.long)
+
+    input_ids[0, :16] = sample_short
+    labels[0, :16] = sample_short
+    attention_mask[0, :16] = 1
+    input_ids[1] = sample_long
+    labels[1] = sample_long
+    attention_mask[1] = 1
+
+    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+
+
+def _tiny_qwen3_autoquantize_candidate_stats(data_loader, num_steps):
+    """Run gradient-based AutoQuantize on a fresh tiny Qwen3 and return its candidate stats."""
+    model = get_tiny_qwen3(dtype=torch.float32, num_hidden_layers=1, max_position_embeddings=64)
+
+    def forward_step(model, batch):
+        return model(**{**batch, "num_items_in_batch": 1})
+
+    def loss_func(output, _batch):
+        return output.loss
+
+    with pytest.warns(
+        UserWarning,
+        match="AutoQuantize: Huggingface model detected - Enabling gradient checkpointing. ",
+    ):
+        _, search_history = mtq.auto_quantize(
+            model,
+            constraints={"effective_bits": 11.0},
+            quantization_formats=[mtq.INT8_DEFAULT_CFG],
+            data_loader=data_loader,
+            forward_step=forward_step,
+            loss_func=loss_func,
+            num_calib_steps=num_steps,
+            num_score_steps=num_steps,
+            verbose=False,
+            method="gradient",
+        )
+    return search_history["candidate_stats"]
+
+
+def test_autoquantize_huggingface_num_items_in_batch_uses_sum_loss():
+    """With ``num_items_in_batch=1`` the model loss must equal the sum-reduced reference."""
+    model = get_tiny_qwen3(dtype=torch.float32, num_hidden_layers=1, max_position_embeddings=64)
+    batch = _qwen3_padded_calib_pair()
+
+    output = model(**{**batch, "num_items_in_batch": 1})
+
+    torch.testing.assert_close(output.loss, _causal_lm_sum_loss(output.logits, batch["labels"]))
+
+
+def test_autoquantize_huggingface_scores_are_batch_size_invariant_with_padding():
+    """Two batch-size-1 steps and one padded batch-size-2 step must score candidates alike."""
+    stats_bs1 = _tiny_qwen3_autoquantize_candidate_stats(
+        [_qwen3_calib_sample(16), _qwen3_calib_sample(32, offset=7)], num_steps=2
+    )
+    stats_bs2 = _tiny_qwen3_autoquantize_candidate_stats([_qwen3_padded_calib_pair()], num_steps=1)
+
+    assert stats_bs1.keys() == stats_bs2.keys()
+    for name in stats_bs1:
+        assert stats_bs1[name]["formats"] == stats_bs2[name]["formats"]
+        torch.testing.assert_close(
+            torch.tensor(stats_bs1[name]["scores"]),
+            torch.tensor(stats_bs2[name]["scores"]),
+            rtol=1e-4,
+            atol=1e-5,
+            msg=f"Candidate scores differ for {name}",
+        )
+
+
 @pytest.mark.parametrize(
     ("model_cls", "quant_config"),
     [