Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions examples/llm_autodeploy/run_auto_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@
}


def _causal_lm_sum_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""Sum-reduced next-token loss keeps AutoQuantize scores additive across batches."""
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
return torch.nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
reduction="sum",
)


def update_weight_quantizer_amax_for_fusion(model: torch.nn.Module):
"""Group modules that take the same input and set amax to enable gemm fusion."""
input_to_linear = defaultdict(list)
Expand Down Expand Up @@ -80,6 +91,10 @@ def auto_quantize(
def loss_func(output, data):
    """Return the loss used for AutoQuantize scoring.

    When the model output exposes ``logits`` and the batch is a mapping with a
    non-``None`` ``"labels"`` entry, the sum-reduced next-token loss is computed
    from them. Otherwise the loss attached to the output is returned.

    Args:
        output: Model output object; ``logits`` is read if present, ``loss`` is
            read on the fallback path.
        data: The batch that produced ``output``.

    Returns:
        The loss value to score with.
    """
    from collections.abc import Mapping

    # For transformers AutoModelForCausalLM models, the outputs are wrapped in
    # `CausalLMOutputWithPast`, which contains the loss attribute.
    logits = getattr(output, "logits", None)
    # Accept any mapping, not only `dict`: `UserDict`-based batch containers
    # fail an `isinstance(data, dict)` check and would silently take the
    # `output.loss` fallback below.
    labels = data.get("labels") if isinstance(data, Mapping) else None
    if logits is not None and labels is not None:
        return _causal_lm_sum_loss(logits, labels)
    return output.loss

model, _ = mtq.auto_quantize(
Expand Down
15 changes: 15 additions & 0 deletions examples/llm_eval/quantization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@
}


def _causal_lm_sum_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""Sum-reduced next-token loss keeps AutoQuantize scores additive across batches."""
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
return torch.nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
reduction="sum",
)


def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, trust_remote_code=False):
"""Returns the tokenizer from the model ckpt_path."""
print(f"Initializing tokenizer from {ckpt_path}")
Expand Down Expand Up @@ -101,6 +112,10 @@ def forward_step(model, batch):
def loss_func(output, data):
    """Return the loss used for AutoQuantize scoring.

    When the model output exposes ``logits`` and the batch is a mapping with a
    non-``None`` ``"labels"`` entry, the sum-reduced next-token loss is computed
    from them. Otherwise the loss attached to the output is returned.

    Args:
        output: Model output object; ``logits`` is read if present, ``loss`` is
            read on the fallback path.
        data: The batch that produced ``output``.

    Returns:
        The loss value to score with.
    """
    from collections.abc import Mapping

    # For transformers AutoModelForCausalLM models, the outputs are wrapped in
    # `CausalLMOutputWithPast`, which contains the loss attribute.
    logits = getattr(output, "logits", None)
    # Accept any mapping, not only `dict`: `UserDict`-based batch containers
    # fail an `isinstance(data, dict)` check and would silently take the
    # `output.loss` fallback below.
    labels = data.get("labels") if isinstance(data, Mapping) else None
    if logits is not None and labels is not None:
        return _causal_lm_sum_loss(logits, labels)
    return output.loss
elif auto_quantize_method == "kl_div":
# For KL divergence method, return only logits
Expand Down
24 changes: 17 additions & 7 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import numpy as np
import torch
from accelerate.hooks import remove_hook_from_module
from cast_mxfp4_to_nvfp4 import apply_to_model as apply_cast_mxfp4_to_nvfp4
from cast_mxfp4_to_nvfp4 import force_weight_quantizers_static
from example_utils import (
Expand Down Expand Up @@ -55,6 +54,7 @@
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
import modelopt.torch.sparsity as mts
from accelerate.hooks import remove_hook_from_module
from modelopt.recipe import ModelOptPTQRecipe, load_recipe
from modelopt.torch.export import (
export_hf_checkpoint,
Expand Down Expand Up @@ -288,6 +288,17 @@ def make_calib_dataloader(
return calib_dataloader, first_text_speech_dataset


def _causal_lm_sum_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""Sum-reduced next-token loss keeps AutoQuantize scores additive across batches."""
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
return torch.nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
reduction="sum",
)


Comment thread
realAsma marked this conversation as resolved.
def auto_quantize(
args: argparse.Namespace,
language_model: torch.nn.Module,
Expand Down Expand Up @@ -349,16 +360,15 @@ def auto_quantize(

def loss_func(output, data):
    """Return the sum-reduced next-token loss for a headless backbone output.

    The logits are obtained by applying ``lm_head`` (taken from the enclosing
    scope) to ``output.last_hidden_state`` and are scored against
    ``data["labels"]``.

    Args:
        output: Model output exposing ``last_hidden_state``.
        data: Batch mapping that must contain a ``"labels"`` entry.

    Returns:
        The value returned by ``_causal_lm_sum_loss`` for the projected logits.
    """
    logits = lm_head(output.last_hidden_state)
    # Single exit through the shared helper: the earlier inline mean-reduced
    # cross-entropy returned first and left this sum-reduced path unreachable.
    return _causal_lm_sum_loss(logits, data["labels"])

else:

def loss_func(output, data):
    """Return the loss used for AutoQuantize scoring.

    When the model output exposes ``logits`` and the batch is a mapping with a
    non-``None`` ``"labels"`` entry, the sum-reduced next-token loss is computed
    from them. Otherwise the loss attached to the output is returned.

    Args:
        output: Model output object; ``logits`` is read if present, ``loss`` is
            read on the fallback path.
        data: The batch that produced ``output``.

    Returns:
        The loss value to score with.
    """
    from collections.abc import Mapping

    logits = getattr(output, "logits", None)
    # Accept any mapping, not only `dict`: `UserDict`-based batch containers
    # fail an `isinstance(data, dict)` check and would silently take the
    # `output.loss` fallback below.
    labels = data.get("labels") if isinstance(data, Mapping) else None
    if logits is not None and labels is not None:
        return _causal_lm_sum_loss(logits, labels)
    return output.loss

if auto_quantize_method == "gradient":
Expand Down
76 changes: 76 additions & 0 deletions tests/unit/torch/quantization/plugins/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
create_tiny_llama_dir,
get_tiny_gpt_oss,
get_tiny_llama,
get_tiny_qwen3,
get_tiny_qwen3_moe,
tf_modelopt_state_and_output_tester,
)
Expand Down Expand Up @@ -193,6 +194,81 @@ def forward_step(model, batch):
)


def _causal_lm_sum_loss(logits, labels):
logits = logits.float()
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
return torch.nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100,
reduction="sum",
)


def _qwen3_calib_sample(length):
input_ids = (torch.arange(1, length + 1, dtype=torch.long) % 31).unsqueeze(0)
return {
"input_ids": input_ids,
"attention_mask": torch.ones_like(input_ids),
"labels": input_ids.clone(),
}


def _qwen3_padded_calib_pair():
sample_short = torch.arange(1, 17, dtype=torch.long) % 31
sample_long = torch.arange(1, 33, dtype=torch.long) % 31
input_ids = torch.zeros((2, 32), dtype=torch.long)
labels = torch.full((2, 32), -100, dtype=torch.long)
attention_mask = torch.zeros((2, 32), dtype=torch.long)

input_ids[0, :16] = sample_short
labels[0, :16] = sample_short
attention_mask[0, :16] = 1
input_ids[1] = sample_long
labels[1] = sample_long
attention_mask[1] = 1

return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


def _tiny_qwen3_autoquantize_candidate_stats(data_loader, num_steps):
    """Run gradient-based ``mtq.auto_quantize`` and return its candidate stats.

    A model is created with ``get_tiny_qwen3`` and searched with
    ``mtq.INT8_DEFAULT_CFG`` as the only quantization format.

    Args:
        data_loader: Iterable of batches passed to the model as keyword
            arguments; each batch must provide ``"labels"``.
        num_steps: Used for both ``num_calib_steps`` and ``num_score_steps``.

    Returns:
        The ``"candidate_stats"`` entry of the search history returned by
        ``mtq.auto_quantize``.
    """

    def forward_step(model, batch):
        return model(**batch)

    def loss_func(output, batch):
        return _causal_lm_sum_loss(output.logits, batch["labels"])

    tiny_model = get_tiny_qwen3(
        dtype=torch.float32, num_hidden_layers=1, max_position_embeddings=64
    )
    search_kwargs = {
        "constraints": {"effective_bits": 11.0},
        "quantization_formats": [mtq.INT8_DEFAULT_CFG],
        "data_loader": data_loader,
        "forward_step": forward_step,
        "loss_func": loss_func,
        "num_calib_steps": num_steps,
        "num_score_steps": num_steps,
        "verbose": False,
        "method": "gradient",
    }
    _, history = mtq.auto_quantize(tiny_model, **search_kwargs)
    return history["candidate_stats"]


def test_autoquantize_huggingface_scores_are_batch_size_invariant_with_padding():
    """Scores from two unpadded single-sample batches match one padded batch of two."""
    unbatched_loader = [_qwen3_calib_sample(n) for n in (16, 32)]
    stats_bs1 = _tiny_qwen3_autoquantize_candidate_stats(unbatched_loader, num_steps=2)

    batched_loader = [_qwen3_padded_calib_pair()]
    stats_bs2 = _tiny_qwen3_autoquantize_candidate_stats(batched_loader, num_steps=1)

    assert stats_bs1.keys() == stats_bs2.keys()
    for name, bs1_entry in stats_bs1.items():
        expected_scores = pytest.approx(stats_bs2[name]["scores"], rel=1e-5, abs=1e-9)
        assert bs1_entry["scores"] == expected_scores


@pytest.mark.parametrize(
("model_cls", "quant_config"),
[
Expand Down
Loading