Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions examples/hf_ptq/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,11 +734,18 @@ def has_pack_quantized_config(config):
with init_empty_weights(include_buffers=True):
# When computing the device_map, assuming bfloat16 precision by default,
# unless specified by the hf_config.
torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
torch_dtype = (
getattr(hf_config, "dtype", None)
or getattr(hf_config, "torch_dtype", None)
or torch.bfloat16
)
if isinstance(torch_dtype, str):
torch_dtype = getattr(torch, torch_dtype)
model_kwargs2 = model_kwargs.copy()
if auto_model_module not in [AutoModelForCausalLM, AutoModel]:
model_kwargs2.pop("trust_remote_code", None)
model_kwargs2["dtype"] = torch_dtype
model_kwargs2["torch_dtype"] = torch_dtype
model_kwargs2.pop("dtype", None)
model_kwargs2.pop("max_memory", None)
model = from_config(hf_config, **model_kwargs2)

Expand Down
48 changes: 48 additions & 0 deletions tests/examples/hf_ptq/test_example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""

import json
from contextlib import nullcontext
from types import SimpleNamespace

import torch
Expand Down Expand Up @@ -194,3 +195,50 @@ def test_get_original_hf_quant_method_none_for_unquantized():
example_utils.get_original_hf_quant_method(SimpleNamespace(quantization_config=None))
is None
)


def test_get_model_empty_init_uses_torch_dtype_not_dtype(monkeypatch):
    """Check the kwargs ``get_model`` hands to ``from_config``.

    The fake config carries both ``dtype`` (float16) and ``torch_dtype``
    (bfloat16). The fake ``from_config`` asserts that it receives
    ``torch_dtype=torch.float16``, and that neither ``dtype`` nor
    ``max_memory`` is passed. After the call, the test verifies the returned
    object is the fake model, that ``eval()`` was invoked on it, and that
    ``trust_remote_code`` was passed through to ``from_config``.
    """
    # Shared recorder the fakes write into so the test can inspect calls later.
    recorded = {}
    hf_config = SimpleNamespace(
        architectures=["DeciLMForCausalLM"],
        dtype=torch.float16,
        model_type="llama",
        torch_dtype=torch.bfloat16,
    )

    class FakeModel:
        def eval(self):
            recorded["eval"] = True

    class FakeAutoModelForCausalLM:
        @staticmethod
        def from_config(config, **kwargs):
            recorded["from_config"] = kwargs
            assert config is hf_config
            assert "dtype" not in kwargs
            # hf_config.dtype (float16) must win over hf_config.torch_dtype.
            assert kwargs["torch_dtype"] is torch.float16
            assert "max_memory" not in kwargs
            return FakeModel()

        @staticmethod
        def from_pretrained(*args, **kwargs):
            recorded["from_pretrained"] = kwargs
            return FakeModel()

    # Config loading always yields the in-memory fake config.
    monkeypatch.setattr(
        example_utils.AutoConfig,
        "from_pretrained",
        lambda *args, **kwargs: hf_config,
    )

    # Module-level names replaced on example_utils, applied in listed order.
    module_stubs = {
        "AutoModelForCausalLM": FakeAutoModelForCausalLM,
        "is_nemotron_vl": lambda config: False,
        "is_speculative": lambda config: False,
        "init_empty_weights": lambda include_buffers: nullcontext(),
        "get_max_memory": lambda: {0: 1024},
        "infer_auto_device_map": lambda model, max_memory: {"": 0},
    }
    for attr_name, stub in module_stubs.items():
        monkeypatch.setattr(example_utils, attr_name, stub)

    model = example_utils.get_model("checkpoint", device="cpu", trust_remote_code=True)

    assert isinstance(model, FakeModel)
    assert recorded["eval"]
    assert recorded["from_config"]["trust_remote_code"] is True
Loading