Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions examples/hf_ptq/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,25 @@ def get_original_hf_quant_method(config) -> str | None:
return None


def _resolve_init_config(hf_config, auto_model_module, ckpt_path, config_kwargs):
"""Re-derive a built-in config when a remote-code config is used with a built-in model
class, so it matches the model definition's version; fall back to hf_config otherwise.
"""
if auto_model_module in [AutoModelForCausalLM, AutoModel]:
return hf_config
if not type(hf_config).__module__.startswith("transformers_modules"):
return hf_config
builtin_config_kwargs = {k: v for k, v in config_kwargs.items() if k != "trust_remote_code"}
try:
return AutoConfig.from_pretrained(ckpt_path, **builtin_config_kwargs)
except Exception as e:
warnings.warn(
f"Could not re-derive a built-in config for {ckpt_path} ({e}); using the "
"remote-code config for device-map inference."
)
return hf_config


def get_model(
ckpt_path,
device="cuda",
Expand Down Expand Up @@ -731,16 +750,20 @@ def has_pack_quantized_config(config):
auto_model_module = getattr(transformers, architecture)
from_config = auto_model_module._from_config

config_for_init = _resolve_init_config(
hf_config, auto_model_module, ckpt_path, config_kwargs
)

with init_empty_weights(include_buffers=True):
# When computing the device_map, assuming bfloat16 precision by default,
# unless specified by the hf_config.
torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
torch_dtype = getattr(config_for_init, "torch_dtype", torch.bfloat16)
model_kwargs2 = model_kwargs.copy()
if auto_model_module not in [AutoModelForCausalLM, AutoModel]:
model_kwargs2.pop("trust_remote_code", None)
model_kwargs2["dtype"] = torch_dtype
model_kwargs2.pop("max_memory", None)
model = from_config(hf_config, **model_kwargs2)
model = from_config(config_for_init, **model_kwargs2)

max_memory = get_max_memory()
inferred_device_map = infer_auto_device_map(model, max_memory=max_memory)
Expand Down
35 changes: 35 additions & 0 deletions tests/examples/hf_ptq/test_example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import json
from types import SimpleNamespace
from unittest.mock import patch

import torch
from _test_utils.examples.hf_ptq_example_utils import example_utils
Expand Down Expand Up @@ -194,3 +195,37 @@ def test_get_original_hf_quant_method_none_for_unquantized():
example_utils.get_original_hf_quant_method(SimpleNamespace(quantization_config=None))
is None
)


# ---------- _resolve_init_config ---------------------------------------------


def _remote_config():
# Config whose class module lives under "transformers_modules" (remote code).
cls = type("_RemoteConfig", (), {"__module__": "transformers_modules.ckpt.config"})
return cls()


def test_resolve_init_config_rederives_for_remote_config():
builtin_cfg = SimpleNamespace()
with patch.object(
example_utils.AutoConfig, "from_pretrained", return_value=builtin_cfg
) as mock:
out = example_utils._resolve_init_config(
_remote_config(), object, "/ckpt", {"trust_remote_code": True}
)
assert out is builtin_cfg
mock.assert_called_once_with("/ckpt") # trust_remote_code stripped


def test_resolve_init_config_keeps_non_remote_config():
cfg = SimpleNamespace() # module is "types", not remote
with patch.object(example_utils.AutoConfig, "from_pretrained") as mock:
assert example_utils._resolve_init_config(cfg, object, "/ckpt", {}) is cfg
mock.assert_not_called()


def test_resolve_init_config_falls_back_when_rederive_raises():
cfg = _remote_config()
with patch.object(example_utils.AutoConfig, "from_pretrained", side_effect=ValueError()):
assert example_utils._resolve_init_config(cfg, object, "/ckpt", {}) is cfg
Loading