From 72ee53ff1ab5679c5ad1664fe32d391f790a59cf Mon Sep 17 00:00:00 2001 From: realAsma Date: Mon, 29 Jun 2026 22:25:46 +0000 Subject: [PATCH 1/9] Fix HF PTQ empty-init dtype fallback Signed-off-by: realAsma --- examples/hf_ptq/example_utils.py | 55 ++++++++++-- tests/examples/hf_ptq/test_example_utils.py | 93 +++++++++++++++++++++ 2 files changed, 143 insertions(+), 5 deletions(-) diff --git a/examples/hf_ptq/example_utils.py b/examples/hf_ptq/example_utils.py index c7a4d7a3b9a..8c51160c38f 100755 --- a/examples/hf_ptq/example_utils.py +++ b/examples/hf_ptq/example_utils.py @@ -598,6 +598,51 @@ def get_original_hf_quant_method(config) -> str | None: return None +def _empty_model_init_kwargs(model_kwargs, torch_dtype): + init_kwargs = model_kwargs.copy() + init_kwargs.pop("max_memory", None) + init_kwargs["dtype"] = torch_dtype + return init_kwargs + + +def _empty_model_init_dtype(hf_config): + torch_dtype = ( + getattr(hf_config, "dtype", None) + or getattr(hf_config, "torch_dtype", None) + or torch.bfloat16 + ) + if isinstance(torch_dtype, str): + torch_dtype = getattr(torch, torch_dtype) + return torch_dtype + + +def _is_unexpected_dtype_kwarg_error(error): + message = str(error) + return "unexpected keyword argument" in message and ( + "'dtype'" in message or '"dtype"' in message + ) + + +def _from_config_for_empty_weights(from_config, hf_config, model_kwargs, torch_dtype): + try: + return from_config(hf_config, **model_kwargs) + except TypeError as error: + if "dtype" not in model_kwargs or not _is_unexpected_dtype_kwarg_error(error): + raise + + model_kwargs_without_dtype = model_kwargs.copy() + model_kwargs_without_dtype.pop("dtype", None) + orig_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch_dtype) + try: + # Some remote-code constructors, such as DeciLMForCausalLM, reject the + # Transformers dtype kwarg even though bf16-sized empty weights are + # still needed for device-map inference. + return from_config(hf_config, **model_kwargs_without_dtype) + finally: + torch.set_default_dtype(orig_dtype) + + def get_model( ckpt_path, device="cuda", @@ -734,13 +779,13 @@ def has_pack_quantized_config(config): with init_empty_weights(include_buffers=True): # When computing the device_map, assuming bfloat16 precision by default, # unless specified by the hf_config. - torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) - model_kwargs2 = model_kwargs.copy() + torch_dtype = _empty_model_init_dtype(hf_config) + model_kwargs2 = _empty_model_init_kwargs(model_kwargs, torch_dtype) if auto_model_module not in [AutoModelForCausalLM, AutoModel]: model_kwargs2.pop("trust_remote_code", None) - model_kwargs2["dtype"] = torch_dtype - model_kwargs2.pop("max_memory", None) - model = from_config(hf_config, **model_kwargs2) + model = _from_config_for_empty_weights( + from_config, hf_config, model_kwargs2, torch_dtype + ) max_memory = get_max_memory() inferred_device_map = infer_auto_device_map(model, max_memory=max_memory) diff --git a/tests/examples/hf_ptq/test_example_utils.py b/tests/examples/hf_ptq/test_example_utils.py index d25da6e0ab2..a389b946b23 100644 --- a/tests/examples/hf_ptq/test_example_utils.py +++ b/tests/examples/hf_ptq/test_example_utils.py @@ -21,6 +21,7 @@ import json from types import SimpleNamespace +import pytest import torch from _test_utils.examples.hf_ptq_example_utils import example_utils from safetensors.torch import save_file @@ -194,3 +195,95 @@ def test_get_original_hf_quant_method_none_for_unquantized(): example_utils.get_original_hf_quant_method(SimpleNamespace(quantization_config=None)) is None ) + + +def test_empty_model_init_kwargs_keeps_dtype_for_general_from_config(): + kwargs = { + "dtype": "auto", + "torch_dtype": torch.float16, + "max_memory": {0: 1024}, + "trust_remote_code": True, + "attn_implementation": "flash_attention_2", + } + + init_kwargs = example_utils._empty_model_init_kwargs(kwargs, torch.bfloat16) + + assert init_kwargs == { + "dtype": torch.bfloat16, + "torch_dtype": torch.float16, + "trust_remote_code": True, + "attn_implementation": "flash_attention_2", + } + + +def test_empty_model_init_dtype_prefers_config_dtype(): + assert ( + example_utils._empty_model_init_dtype( + SimpleNamespace(dtype=torch.float16, torch_dtype=torch.bfloat16) + ) + is torch.float16 + ) + + +def test_empty_model_init_dtype_defaults_to_bfloat16(): + assert example_utils._empty_model_init_dtype(SimpleNamespace()) is torch.bfloat16 + + +def test_from_config_for_empty_weights_calls_with_dtype_first(): + calls = [] + + def from_config(config, **kwargs): + calls.append((config, kwargs, torch.get_default_dtype())) + return "model" + + hf_config = SimpleNamespace() + model_kwargs = {"dtype": torch.bfloat16, "trust_remote_code": True} + + assert ( + example_utils._from_config_for_empty_weights( + from_config, hf_config, model_kwargs, torch.bfloat16 + ) + == "model" + ) + + assert calls == [(hf_config, model_kwargs, torch.get_default_dtype())] + + +def test_from_config_for_empty_weights_retries_without_dtype_with_default_dtype(): + calls = [] + original_dtype = torch.get_default_dtype() + + def from_config(config, **kwargs): + calls.append((kwargs, torch.get_default_dtype())) + if "dtype" in kwargs: + raise TypeError("__init__() got an unexpected keyword argument 'dtype'") + return config + + try: + hf_config = SimpleNamespace() + model_kwargs = {"dtype": torch.bfloat16, "trust_remote_code": True} + + assert ( + example_utils._from_config_for_empty_weights( + from_config, hf_config, model_kwargs, torch.bfloat16 + ) + is hf_config + ) + + assert calls == [ + (model_kwargs, original_dtype), + ({"trust_remote_code": True}, torch.bfloat16), + ] + assert torch.get_default_dtype() is original_dtype + finally: + torch.set_default_dtype(original_dtype) + + +def test_from_config_for_empty_weights_reraises_other_type_error(): + def from_config(config, **kwargs): + raise TypeError("missing required positional argument: 'hidden_size'") + + with pytest.raises(TypeError, match="hidden_size"): + example_utils._from_config_for_empty_weights( + from_config, SimpleNamespace(), {"dtype": torch.bfloat16}, torch.bfloat16 + ) From 7d0bebde3ed29675d5d038b31cdfbcb68d735e41 Mon Sep 17 00:00:00 2001 From: realAsma Date: Mon, 29 Jun 2026 22:50:57 +0000 Subject: [PATCH 2/9] Simplify HF PTQ empty init dtype fix Signed-off-by: realAsma --- examples/hf_ptq/example_utils.py | 62 ++------- tests/examples/hf_ptq/test_example_utils.py | 131 +++++++------------- 2 files changed, 55 insertions(+), 138 deletions(-) diff --git a/examples/hf_ptq/example_utils.py b/examples/hf_ptq/example_utils.py index 8c51160c38f..63512e525c9 100755 --- a/examples/hf_ptq/example_utils.py +++ b/examples/hf_ptq/example_utils.py @@ -598,51 +598,6 @@ def get_original_hf_quant_method(config) -> str | None: return None -def _empty_model_init_kwargs(model_kwargs, torch_dtype): - init_kwargs = model_kwargs.copy() - init_kwargs.pop("max_memory", None) - init_kwargs["dtype"] = torch_dtype - return init_kwargs - - -def _empty_model_init_dtype(hf_config): - torch_dtype = ( - getattr(hf_config, "dtype", None) - or getattr(hf_config, "torch_dtype", None) - or torch.bfloat16 - ) - if isinstance(torch_dtype, str): - torch_dtype = getattr(torch, torch_dtype) - return torch_dtype - - -def _is_unexpected_dtype_kwarg_error(error): - message = str(error) - return "unexpected keyword argument" in message and ( - "'dtype'" in message or '"dtype"' in message - ) - - -def _from_config_for_empty_weights(from_config, hf_config, model_kwargs, torch_dtype): - try: - return from_config(hf_config, **model_kwargs) - except TypeError as error: - if "dtype" not in model_kwargs or not _is_unexpected_dtype_kwarg_error(error): - raise - - model_kwargs_without_dtype = model_kwargs.copy() - model_kwargs_without_dtype.pop("dtype", None) - orig_dtype = torch.get_default_dtype() - torch.set_default_dtype(torch_dtype) - try: - # Some remote-code constructors, such as DeciLMForCausalLM, reject the - # Transformers dtype kwarg even though bf16-sized empty weights are - # still needed for device-map inference. - return from_config(hf_config, **model_kwargs_without_dtype) - finally: - torch.set_default_dtype(orig_dtype) - - def get_model( ckpt_path, device="cuda", @@ -779,13 +734,20 @@ def has_pack_quantized_config(config): with init_empty_weights(include_buffers=True): # When computing the device_map, assuming bfloat16 precision by default, # unless specified by the hf_config. - torch_dtype = _empty_model_init_dtype(hf_config) - model_kwargs2 = _empty_model_init_kwargs(model_kwargs, torch_dtype) + torch_dtype = ( + getattr(hf_config, "dtype", None) + or getattr(hf_config, "torch_dtype", None) + or torch.bfloat16 + ) + if isinstance(torch_dtype, str): + torch_dtype = getattr(torch, torch_dtype) + model_kwargs2 = model_kwargs.copy() if auto_model_module not in [AutoModelForCausalLM, AutoModel]: model_kwargs2.pop("trust_remote_code", None) - model = _from_config_for_empty_weights( - from_config, hf_config, model_kwargs2, torch_dtype - ) + model_kwargs2["torch_dtype"] = torch_dtype + model_kwargs2.pop("dtype", None) + model_kwargs2.pop("max_memory", None) + model = from_config(hf_config, **model_kwargs2) max_memory = get_max_memory() inferred_device_map = infer_auto_device_map(model, max_memory=max_memory) diff --git a/tests/examples/hf_ptq/test_example_utils.py b/tests/examples/hf_ptq/test_example_utils.py index a389b946b23..76e83fcbfbe 100644 --- a/tests/examples/hf_ptq/test_example_utils.py +++ b/tests/examples/hf_ptq/test_example_utils.py @@ -19,9 +19,9 @@ """ import json +from contextlib import nullcontext from types import SimpleNamespace -import pytest import torch from _test_utils.examples.hf_ptq_example_utils import example_utils from safetensors.torch import save_file @@ -197,93 +197,48 @@ def test_get_original_hf_quant_method_none_for_unquantized(): ) -def test_empty_model_init_kwargs_keeps_dtype_for_general_from_config(): - kwargs = { - "dtype": "auto", - "torch_dtype": torch.float16, - "max_memory": {0: 1024}, - "trust_remote_code": True, - "attn_implementation": "flash_attention_2", - } - - init_kwargs = example_utils._empty_model_init_kwargs(kwargs, torch.bfloat16) - - assert init_kwargs == { - "dtype": torch.bfloat16, - "torch_dtype": torch.float16, - "trust_remote_code": True, - "attn_implementation": "flash_attention_2", - } - - -def test_empty_model_init_dtype_prefers_config_dtype(): - assert ( - example_utils._empty_model_init_dtype( - SimpleNamespace(dtype=torch.float16, torch_dtype=torch.bfloat16) - ) - is torch.float16 +def test_get_model_empty_init_uses_torch_dtype_not_dtype(monkeypatch): + calls = {} + hf_config = SimpleNamespace( + architectures=["DeciLMForCausalLM"], + dtype=torch.float16, + model_type="llama", + torch_dtype=torch.bfloat16, ) - -def test_empty_model_init_dtype_defaults_to_bfloat16(): - assert example_utils._empty_model_init_dtype(SimpleNamespace()) is torch.bfloat16 - - -def test_from_config_for_empty_weights_calls_with_dtype_first(): - calls = [] - - def from_config(config, **kwargs): - calls.append((config, kwargs, torch.get_default_dtype())) - return "model" - - hf_config = SimpleNamespace() - model_kwargs = {"dtype": torch.bfloat16, "trust_remote_code": True} - - assert ( - example_utils._from_config_for_empty_weights( - from_config, hf_config, model_kwargs, torch.bfloat16 - ) - == "model" + class FakeModel: + def eval(self): + calls["eval"] = True + + class FakeAutoModelForCausalLM: + @staticmethod + def from_config(config, **kwargs): + calls["from_config"] = kwargs + assert config is hf_config + assert "dtype" not in kwargs + assert kwargs["torch_dtype"] is torch.float16 + assert "max_memory" not in kwargs + return FakeModel() + + @staticmethod + def from_pretrained(*args, **kwargs): + calls["from_pretrained"] = kwargs + return FakeModel() + + monkeypatch.setattr( + example_utils.AutoConfig, + "from_pretrained", + lambda *args, **kwargs: hf_config, ) - - assert calls == [(hf_config, model_kwargs, torch.get_default_dtype())] - - -def test_from_config_for_empty_weights_retries_without_dtype_with_default_dtype(): - calls = [] - original_dtype = torch.get_default_dtype() - - def from_config(config, **kwargs): - calls.append((kwargs, torch.get_default_dtype())) - if "dtype" in kwargs: - raise TypeError("__init__() got an unexpected keyword argument 'dtype'") - return config - - try: - hf_config = SimpleNamespace() - model_kwargs = {"dtype": torch.bfloat16, "trust_remote_code": True} - - assert ( - example_utils._from_config_for_empty_weights( - from_config, hf_config, model_kwargs, torch.bfloat16 - ) - is hf_config - ) - - assert calls == [ - (model_kwargs, original_dtype), - ({"trust_remote_code": True}, torch.bfloat16), - ] - assert torch.get_default_dtype() is original_dtype - finally: - torch.set_default_dtype(original_dtype) - - -def test_from_config_for_empty_weights_reraises_other_type_error(): - def from_config(config, **kwargs): - raise TypeError("missing required positional argument: 'hidden_size'") - - with pytest.raises(TypeError, match="hidden_size"): - example_utils._from_config_for_empty_weights( - from_config, SimpleNamespace(), {"dtype": torch.bfloat16}, torch.bfloat16 - ) + monkeypatch.setattr(example_utils, "AutoModelForCausalLM", FakeAutoModelForCausalLM) + monkeypatch.setattr(example_utils, "is_nemotron_vl", lambda config: False) + monkeypatch.setattr(example_utils, "is_speculative", lambda config: False) + monkeypatch.setattr(example_utils, "init_empty_weights", lambda include_buffers: nullcontext()) + monkeypatch.setattr(example_utils, "get_max_memory", lambda: {0: 1024}) + monkeypatch.setattr(example_utils, "infer_auto_device_map", lambda model, max_memory: {"": 0}) + + model = example_utils.get_model("checkpoint", device="cpu", trust_remote_code=True) + + assert isinstance(model, FakeModel) + assert calls["eval"] + assert calls["from_config"]["trust_remote_code"] is True From 6cb2243976697f631839d5d58be4879b17cfd896 Mon Sep 17 00:00:00 2001 From: realAsma Date: Tue, 30 Jun 2026 02:11:42 +0000 Subject: [PATCH 3/9] Fix HF PTQ real-load dtype kwarg Signed-off-by: realAsma --- examples/hf_ptq/example_utils.py | 4 +++- tests/examples/hf_ptq/test_example_utils.py | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/hf_ptq/example_utils.py b/examples/hf_ptq/example_utils.py index 63512e525c9..b306abfad9b 100755 --- a/examples/hf_ptq/example_utils.py +++ b/examples/hf_ptq/example_utils.py @@ -767,10 +767,12 @@ def has_pack_quantized_config(config): ) model_kwargs["max_memory"] = max_memory + model_kwargs2 = model_kwargs.copy() + model_kwargs2["torch_dtype"] = model_kwargs2.pop("dtype", "auto") model = auto_model_module.from_pretrained( ckpt_path, device_map=device_map, - **model_kwargs, + **model_kwargs2, ) model.eval() if has_pack_quantized_config(hf_config): diff --git a/tests/examples/hf_ptq/test_example_utils.py b/tests/examples/hf_ptq/test_example_utils.py index 76e83fcbfbe..eb1f9a8b082 100644 --- a/tests/examples/hf_ptq/test_example_utils.py +++ b/tests/examples/hf_ptq/test_example_utils.py @@ -223,6 +223,8 @@ def from_config(config, **kwargs): @staticmethod def from_pretrained(*args, **kwargs): calls["from_pretrained"] = kwargs + assert "dtype" not in kwargs + assert kwargs["torch_dtype"] == "auto" return FakeModel() monkeypatch.setattr( @@ -242,3 +244,4 @@ def from_pretrained(*args, **kwargs): assert isinstance(model, FakeModel) assert calls["eval"] assert calls["from_config"]["trust_remote_code"] is True + assert calls["from_pretrained"]["trust_remote_code"] is True From 2641497a1926ab01f66771d19dc61fa62ce489d7 Mon Sep 17 00:00:00 2001 From: realAsma Date: Tue, 30 Jun 2026 02:48:02 +0000 Subject: [PATCH 4/9] Drop dtype from HF PTQ final load Signed-off-by: realAsma --- examples/hf_ptq/example_utils.py | 2 +- tests/examples/hf_ptq/test_example_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/hf_ptq/example_utils.py b/examples/hf_ptq/example_utils.py index b306abfad9b..a2ef8b5ad57 100755 --- a/examples/hf_ptq/example_utils.py +++ b/examples/hf_ptq/example_utils.py @@ -768,7 +768,7 @@ def has_pack_quantized_config(config): model_kwargs["max_memory"] = max_memory model_kwargs2 = model_kwargs.copy() - model_kwargs2["torch_dtype"] = model_kwargs2.pop("dtype", "auto") + model_kwargs2.pop("dtype", None) model = auto_model_module.from_pretrained( ckpt_path, device_map=device_map, diff --git a/tests/examples/hf_ptq/test_example_utils.py b/tests/examples/hf_ptq/test_example_utils.py index eb1f9a8b082..ac74e02c9a3 100644 --- a/tests/examples/hf_ptq/test_example_utils.py +++ b/tests/examples/hf_ptq/test_example_utils.py @@ -197,7 +197,7 @@ def test_get_original_hf_quant_method_none_for_unquantized(): ) -def test_get_model_empty_init_uses_torch_dtype_not_dtype(monkeypatch): +def test_get_model_drops_dtype_from_final_load(monkeypatch): calls = {} hf_config = SimpleNamespace( architectures=["DeciLMForCausalLM"], @@ -224,7 +224,7 @@ def from_config(config, **kwargs): def from_pretrained(*args, **kwargs): calls["from_pretrained"] = kwargs assert "dtype" not in kwargs - assert kwargs["torch_dtype"] == "auto" + assert "torch_dtype" not in kwargs return FakeModel() monkeypatch.setattr( From 97bb15cd63d8a715b08e2382ed2ead44e0c7842d Mon Sep 17 00:00:00 2001 From: realAsma Date: Tue, 30 Jun 2026 04:30:11 +0000 Subject: [PATCH 5/9] Scope HF PTQ dtype workaround to DeciLM Signed-off-by: realAsma --- examples/hf_ptq/example_utils.py | 17 ++++--- tests/examples/hf_ptq/test_example_utils.py | 52 ++++++++++++++++++++- 2 files changed, 62 insertions(+), 7 deletions(-) diff --git a/examples/hf_ptq/example_utils.py b/examples/hf_ptq/example_utils.py index a2ef8b5ad57..a3b1ab5beee 100755 --- a/examples/hf_ptq/example_utils.py +++ b/examples/hf_ptq/example_utils.py @@ -731,21 +731,25 @@ def has_pack_quantized_config(config): auto_model_module = getattr(transformers, architecture) from_config = auto_model_module._from_config + is_decilm = "DeciLM" in architecture with init_empty_weights(include_buffers=True): # When computing the device_map, assuming bfloat16 precision by default, # unless specified by the hf_config. - torch_dtype = ( + config_dtype = ( getattr(hf_config, "dtype", None) or getattr(hf_config, "torch_dtype", None) or torch.bfloat16 ) - if isinstance(torch_dtype, str): - torch_dtype = getattr(torch, torch_dtype) + if isinstance(config_dtype, str): + config_dtype = getattr(torch, config_dtype) model_kwargs2 = model_kwargs.copy() if auto_model_module not in [AutoModelForCausalLM, AutoModel]: model_kwargs2.pop("trust_remote_code", None) - model_kwargs2["torch_dtype"] = torch_dtype - model_kwargs2.pop("dtype", None) + if is_decilm: + model_kwargs2["torch_dtype"] = config_dtype + model_kwargs2.pop("dtype", None) + else: + model_kwargs2["dtype"] = config_dtype model_kwargs2.pop("max_memory", None) model = from_config(hf_config, **model_kwargs2) @@ -768,7 +772,8 @@ def has_pack_quantized_config(config): model_kwargs["max_memory"] = max_memory model_kwargs2 = model_kwargs.copy() - model_kwargs2.pop("dtype", None) + if is_decilm: + model_kwargs2.pop("dtype", None) model = auto_model_module.from_pretrained( ckpt_path, device_map=device_map, diff --git a/tests/examples/hf_ptq/test_example_utils.py b/tests/examples/hf_ptq/test_example_utils.py index ac74e02c9a3..f78f08a25aa 100644 --- a/tests/examples/hf_ptq/test_example_utils.py +++ b/tests/examples/hf_ptq/test_example_utils.py @@ -197,7 +197,7 @@ def test_get_original_hf_quant_method_none_for_unquantized(): ) -def test_get_model_drops_dtype_from_final_load(monkeypatch): +def test_get_model_uses_torch_dtype_only_for_decilm(monkeypatch): calls = {} hf_config = SimpleNamespace( architectures=["DeciLMForCausalLM"], @@ -245,3 +245,53 @@ def from_pretrained(*args, **kwargs): assert calls["eval"] assert calls["from_config"]["trust_remote_code"] is True assert calls["from_pretrained"]["trust_remote_code"] is True + + +def test_get_model_uses_dtype_for_non_decilm(monkeypatch): + calls = {} + hf_config = SimpleNamespace( + architectures=["LlamaForCausalLM"], + dtype=torch.float16, + model_type="llama", + torch_dtype=torch.bfloat16, + ) + + class FakeModel: + def eval(self): + calls["eval"] = True + + class FakeLlamaForCausalLM: + @staticmethod + def _from_config(config, **kwargs): + calls["from_config"] = kwargs + assert config is hf_config + assert kwargs["dtype"] is torch.float16 + assert "torch_dtype" not in kwargs + assert "max_memory" not in kwargs + return FakeModel() + + @staticmethod + def from_pretrained(*args, **kwargs): + calls["from_pretrained"] = kwargs + assert kwargs["dtype"] == "auto" + assert "torch_dtype" not in kwargs + return FakeModel() + + monkeypatch.setattr( + example_utils.AutoConfig, + "from_pretrained", + lambda *args, **kwargs: hf_config, + ) + monkeypatch.setattr(example_utils.transformers, "LlamaForCausalLM", FakeLlamaForCausalLM) + monkeypatch.setattr(example_utils, "is_nemotron_vl", lambda config: False) + monkeypatch.setattr(example_utils, "is_speculative", lambda config: False) + monkeypatch.setattr(example_utils, "init_empty_weights", lambda include_buffers: nullcontext()) + monkeypatch.setattr(example_utils, "get_max_memory", lambda: {0: 1024}) + monkeypatch.setattr(example_utils, "infer_auto_device_map", lambda model, max_memory: {"": 0}) + + model = example_utils.get_model("checkpoint", device="cpu", trust_remote_code=True) + + assert isinstance(model, FakeModel) + assert calls["eval"] + assert "trust_remote_code" not in calls["from_config"] + assert calls["from_pretrained"]["trust_remote_code"] is True From acb6e702e0877f08bbb4c76564be219601d3092c Mon Sep 17 00:00:00 2001 From: realAsma Date: Tue, 30 Jun 2026 14:11:56 +0000 Subject: [PATCH 6/9] Fold HF PTQ dtype test cases Signed-off-by: realAsma --- tests/examples/hf_ptq/test_example_utils.py | 74 +++++++++------------ 1 file changed, 33 insertions(+), 41 deletions(-) diff --git a/tests/examples/hf_ptq/test_example_utils.py b/tests/examples/hf_ptq/test_example_utils.py index f78f08a25aa..de60ecec850 100644 --- a/tests/examples/hf_ptq/test_example_utils.py +++ b/tests/examples/hf_ptq/test_example_utils.py @@ -22,6 +22,7 @@ from contextlib import nullcontext from types import SimpleNamespace +import pytest import torch from _test_utils.examples.hf_ptq_example_utils import example_utils from safetensors.torch import save_file @@ -197,10 +198,28 @@ def test_get_original_hf_quant_method_none_for_unquantized(): ) -def test_get_model_uses_torch_dtype_only_for_decilm(monkeypatch): +@pytest.mark.parametrize( + ( + "architecture", + "model_class_name", + "expected_config_dtype_kwarg", + "unexpected_config_dtype_kwarg", + ), + [ + ("DeciLMForCausalLM", "AutoModelForCausalLM", "torch_dtype", "dtype"), + ("LlamaForCausalLM", "LlamaForCausalLM", "dtype", "torch_dtype"), + ], +) +def test_get_model_uses_expected_dtype_kwarg( + monkeypatch, + architecture, + model_class_name, + expected_config_dtype_kwarg, + unexpected_config_dtype_kwarg, +): calls = {} hf_config = SimpleNamespace( - architectures=["DeciLMForCausalLM"], + architectures=[architecture], dtype=torch.float16, model_type="llama", torch_dtype=torch.bfloat16, @@ -215,8 +234,8 @@ class FakeAutoModelForCausalLM: def from_config(config, **kwargs): calls["from_config"] = kwargs assert config is hf_config - assert "dtype" not in kwargs - assert kwargs["torch_dtype"] is torch.float16 + assert kwargs[expected_config_dtype_kwarg] is torch.float16 + assert unexpected_config_dtype_kwarg not in kwargs assert "max_memory" not in kwargs return FakeModel() @@ -227,46 +246,13 @@ def from_pretrained(*args, **kwargs): assert "torch_dtype" not in kwargs return FakeModel() - monkeypatch.setattr( - example_utils.AutoConfig, - "from_pretrained", - lambda *args, **kwargs: hf_config, - ) - monkeypatch.setattr(example_utils, "AutoModelForCausalLM", FakeAutoModelForCausalLM) - monkeypatch.setattr(example_utils, "is_nemotron_vl", lambda config: False) - monkeypatch.setattr(example_utils, "is_speculative", lambda config: False) - monkeypatch.setattr(example_utils, "init_empty_weights", lambda include_buffers: nullcontext()) - monkeypatch.setattr(example_utils, "get_max_memory", lambda: {0: 1024}) - monkeypatch.setattr(example_utils, "infer_auto_device_map", lambda model, max_memory: {"": 0}) - - model = example_utils.get_model("checkpoint", device="cpu", trust_remote_code=True) - - assert isinstance(model, FakeModel) - assert calls["eval"] - assert calls["from_config"]["trust_remote_code"] is True - assert calls["from_pretrained"]["trust_remote_code"] is True - - -def test_get_model_uses_dtype_for_non_decilm(monkeypatch): - calls = {} - hf_config = SimpleNamespace( - architectures=["LlamaForCausalLM"], - dtype=torch.float16, - model_type="llama", - torch_dtype=torch.bfloat16, - ) - - class FakeModel: - def eval(self): - calls["eval"] = True - class FakeLlamaForCausalLM: @staticmethod def _from_config(config, **kwargs): calls["from_config"] = kwargs assert config is hf_config - assert kwargs["dtype"] is torch.float16 - assert "torch_dtype" not in kwargs + assert kwargs[expected_config_dtype_kwarg] is torch.float16 + assert unexpected_config_dtype_kwarg not in kwargs assert "max_memory" not in kwargs return FakeModel() @@ -282,7 +268,10 @@ def from_pretrained(*args, **kwargs): "from_pretrained", lambda *args, **kwargs: hf_config, ) - monkeypatch.setattr(example_utils.transformers, "LlamaForCausalLM", FakeLlamaForCausalLM) + if model_class_name == "AutoModelForCausalLM": + monkeypatch.setattr(example_utils, "AutoModelForCausalLM", FakeAutoModelForCausalLM) + else: + monkeypatch.setattr(example_utils.transformers, model_class_name, FakeLlamaForCausalLM) monkeypatch.setattr(example_utils, "is_nemotron_vl", lambda config: False) monkeypatch.setattr(example_utils, "is_speculative", lambda config: False) monkeypatch.setattr(example_utils, "init_empty_weights", lambda include_buffers: nullcontext()) @@ -293,5 +282,8 @@ def from_pretrained(*args, **kwargs): assert isinstance(model, FakeModel) assert calls["eval"] - assert "trust_remote_code" not in calls["from_config"] + if expected_config_dtype_kwarg == "torch_dtype": + assert calls["from_config"]["trust_remote_code"] is True + else: + assert "trust_remote_code" not in calls["from_config"] assert calls["from_pretrained"]["trust_remote_code"] is True From d768850a0d56e0fcaeaef5d99dfca43c434ea0b1 Mon Sep 17 00:00:00 2001 From: realAsma Date: Tue, 30 Jun 2026 16:52:24 +0000 Subject: [PATCH 7/9] Reuse HF PTQ fake model config assertions Signed-off-by: realAsma --- tests/examples/hf_ptq/test_example_utils.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tests/examples/hf_ptq/test_example_utils.py b/tests/examples/hf_ptq/test_example_utils.py index de60ecec850..0052bd9929f 100644 --- a/tests/examples/hf_ptq/test_example_utils.py +++ b/tests/examples/hf_ptq/test_example_utils.py @@ -246,15 +246,8 @@ def from_pretrained(*args, **kwargs): assert "torch_dtype" not in kwargs return FakeModel() - class FakeLlamaForCausalLM: - @staticmethod - def _from_config(config, **kwargs): - calls["from_config"] = kwargs - assert config is hf_config - assert kwargs[expected_config_dtype_kwarg] is torch.float16 - assert unexpected_config_dtype_kwarg not in kwargs - assert "max_memory" not in kwargs - return FakeModel() + class FakeLlamaForCausalLM(FakeAutoModelForCausalLM): + _from_config = FakeAutoModelForCausalLM.from_config @staticmethod def from_pretrained(*args, **kwargs): From 22b37ac19b70a1c737c3d05fa93b0357aa17219d Mon Sep 17 00:00:00 2001 From: realAsma Date: Tue, 30 Jun 2026 17:29:06 +0000 Subject: [PATCH 8/9] Force DeciLM dtype fallback in HF PTQ test Signed-off-by: realAsma --- tests/examples/hf_ptq/test_example_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/examples/hf_ptq/test_example_utils.py b/tests/examples/hf_ptq/test_example_utils.py index 0052bd9929f..5d1b698fce2 100644 --- a/tests/examples/hf_ptq/test_example_utils.py +++ b/tests/examples/hf_ptq/test_example_utils.py @@ -263,6 +263,7 @@ def from_pretrained(*args, **kwargs): ) if model_class_name == "AutoModelForCausalLM": monkeypatch.setattr(example_utils, "AutoModelForCausalLM", FakeAutoModelForCausalLM) + monkeypatch.delattr(example_utils.transformers, architecture, raising=False) else: monkeypatch.setattr(example_utils.transformers, model_class_name, FakeLlamaForCausalLM) monkeypatch.setattr(example_utils, "is_nemotron_vl", lambda config: False) From 11bae59915715069e7d6ad159970eca77aaf0814 Mon Sep 17 00:00:00 2001 From: realAsma Date: Tue, 30 Jun 2026 18:19:18 +0000 Subject: [PATCH 9/9] Resolve HF PTQ PR merge conflict Signed-off-by: realAsma --- examples/hf_ptq/example_utils.py | 29 +++++++++++++++-- tests/examples/hf_ptq/test_example_utils.py | 35 +++++++++++++++++++++ 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/examples/hf_ptq/example_utils.py b/examples/hf_ptq/example_utils.py index a3b1ab5beee..dbb598038e5 100755 --- a/examples/hf_ptq/example_utils.py +++ b/examples/hf_ptq/example_utils.py @@ -598,6 +598,25 @@ def get_original_hf_quant_method(config) -> str | None: return None +def _resolve_init_config(hf_config, auto_model_module, ckpt_path, config_kwargs): + """Re-derive a built-in config when a remote-code config is used with a built-in model + class, so it matches the model definition's version; fall back to hf_config otherwise. + """ + if auto_model_module in [AutoModelForCausalLM, AutoModel]: + return hf_config + if not type(hf_config).__module__.startswith("transformers_modules"): + return hf_config + builtin_config_kwargs = {k: v for k, v in config_kwargs.items() if k != "trust_remote_code"} + try: + return AutoConfig.from_pretrained(ckpt_path, **builtin_config_kwargs) + except Exception as e: + warnings.warn( + f"Could not re-derive a built-in config for {ckpt_path} ({e}); using the " + "remote-code config for device-map inference." + ) + return hf_config + + def get_model( ckpt_path, device="cuda", @@ -732,12 +751,16 @@ def has_pack_quantized_config(config): from_config = auto_model_module._from_config is_decilm = "DeciLM" in architecture + config_for_init = _resolve_init_config( + hf_config, auto_model_module, ckpt_path, config_kwargs + ) + with init_empty_weights(include_buffers=True): # When computing the device_map, assuming bfloat16 precision by default, # unless specified by the hf_config. config_dtype = ( - getattr(hf_config, "dtype", None) - or getattr(hf_config, "torch_dtype", None) + getattr(config_for_init, "dtype", None) + or getattr(config_for_init, "torch_dtype", None) or torch.bfloat16 ) if isinstance(config_dtype, str): @@ -751,7 +774,7 @@ def has_pack_quantized_config(config): else: model_kwargs2["dtype"] = config_dtype model_kwargs2.pop("max_memory", None) - model = from_config(hf_config, **model_kwargs2) + model = from_config(config_for_init, **model_kwargs2) max_memory = get_max_memory() inferred_device_map = infer_auto_device_map(model, max_memory=max_memory) diff --git a/tests/examples/hf_ptq/test_example_utils.py b/tests/examples/hf_ptq/test_example_utils.py index 5d1b698fce2..ec3c6a31a5c 100644 --- a/tests/examples/hf_ptq/test_example_utils.py +++ b/tests/examples/hf_ptq/test_example_utils.py @@ -21,6 +21,7 @@ import json from contextlib import nullcontext from types import SimpleNamespace +from unittest.mock import patch import pytest import torch @@ -198,6 +199,40 @@ def test_get_original_hf_quant_method_none_for_unquantized(): ) +# ---------- _resolve_init_config --------------------------------------------- + + +def _remote_config(): + # Config whose class module lives under "transformers_modules" (remote code). + cls = type("_RemoteConfig", (), {"__module__": "transformers_modules.ckpt.config"}) + return cls() + + +def test_resolve_init_config_rederives_for_remote_config(): + builtin_cfg = SimpleNamespace() + with patch.object( + example_utils.AutoConfig, "from_pretrained", return_value=builtin_cfg + ) as mock: + out = example_utils._resolve_init_config( + _remote_config(), object, "/ckpt", {"trust_remote_code": True} + ) + assert out is builtin_cfg + mock.assert_called_once_with("/ckpt") # trust_remote_code stripped + + +def test_resolve_init_config_keeps_non_remote_config(): + cfg = SimpleNamespace() # module is "types", not remote + with patch.object(example_utils.AutoConfig, "from_pretrained") as mock: + assert example_utils._resolve_init_config(cfg, object, "/ckpt", {}) is cfg + mock.assert_not_called() + + +def test_resolve_init_config_falls_back_when_rederive_raises(): + cfg = _remote_config() + with patch.object(example_utils.AutoConfig, "from_pretrained", side_effect=ValueError()): + assert example_utils._resolve_init_config(cfg, object, "/ckpt", {}) is cfg + + @pytest.mark.parametrize( ( "architecture",