Skip to content
45 changes: 41 additions & 4 deletions examples/hf_ptq/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,25 @@ def get_original_hf_quant_method(config) -> str | None:
return None


def _resolve_init_config(hf_config, auto_model_module, ckpt_path, config_kwargs):
    """Pick the config used to instantiate the model for device-map inference.

    When a remote-code config (its class module lives under ``transformers_modules``)
    is paired with a built-in model class, the config is re-derived through
    ``AutoConfig.from_pretrained`` without ``trust_remote_code`` so that it matches the
    model definition's version. In every other case ``hf_config`` is returned as-is.

    Args:
        hf_config: Config object already loaded for the checkpoint.
        auto_model_module: Model class that will build the model. The generic
            ``AutoModelForCausalLM`` / ``AutoModel`` classes never trigger re-derivation.
        ckpt_path: Checkpoint location forwarded to ``AutoConfig.from_pretrained``.
        config_kwargs: Keyword arguments for the config load. ``trust_remote_code`` is
            filtered out of a copy before the reload; the mapping itself is not mutated.

    Returns:
        The re-derived config, or ``hf_config`` when re-derivation does not apply or
        fails (a warning is emitted on failure).
    """
    if auto_model_module in (AutoModelForCausalLM, AutoModel):
        return hf_config
    if not type(hf_config).__module__.startswith("transformers_modules"):
        return hf_config

    builtin_config_kwargs = {k: v for k, v in config_kwargs.items() if k != "trust_remote_code"}
    try:
        return AutoConfig.from_pretrained(ckpt_path, **builtin_config_kwargs)
    except Exception as e:
        # Deliberate best-effort: any failure to reload falls back to the remote-code
        # config instead of aborting. stacklevel=2 attributes the warning to the caller
        # rather than to this helper.
        warnings.warn(
            f"Could not re-derive a built-in config for {ckpt_path} ({e}); using the "
            "remote-code config for device-map inference.",
            stacklevel=2,
        )
        return hf_config


def get_model(
ckpt_path,
device="cuda",
Expand Down Expand Up @@ -731,16 +750,31 @@ def has_pack_quantized_config(config):
auto_model_module = getattr(transformers, architecture)
from_config = auto_model_module._from_config

is_decilm = "DeciLM" in architecture

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to make this a general workaround (WAR) instead of a DeciLM-specific one?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 Bot comment.

Thanks Wei-Ming. I agree this may be generalizable to other older remote-code models with the same Transformers 5+ incompatibility, but finding and validating those models would be a broader follow-up. For this PR, I would like to keep the fix scoped to the observed Llama Nemotron / DeciLM failure since broader remote-code fallback support is lower value and would need dedicated coverage.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 Bot comment.

Thanks Wei-Ming. I agree this could probably be generalized to older remote-code models with the same constructor mismatch, but that would require identifying and validating the affected model set.

For this RC bug, I would keep the fix scoped to Llama Nemotron / DeciLM because that is the reported failure and the broader remote-code support case is lower value without dedicated coverage. I can follow up separately if we find more models with the same failure.

config_for_init = _resolve_init_config(
hf_config, auto_model_module, ckpt_path, config_kwargs
)

with init_empty_weights(include_buffers=True):
# When computing the device_map, assuming bfloat16 precision by default,
# unless specified by the hf_config.
torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
config_dtype = (
getattr(config_for_init, "dtype", None)
or getattr(config_for_init, "torch_dtype", None)
or torch.bfloat16
)
if isinstance(config_dtype, str):
config_dtype = getattr(torch, config_dtype)
model_kwargs2 = model_kwargs.copy()
if auto_model_module not in [AutoModelForCausalLM, AutoModel]:
model_kwargs2.pop("trust_remote_code", None)
model_kwargs2["dtype"] = torch_dtype
if is_decilm:
model_kwargs2["torch_dtype"] = config_dtype
model_kwargs2.pop("dtype", None)
else:
model_kwargs2["dtype"] = config_dtype
model_kwargs2.pop("max_memory", None)
model = from_config(hf_config, **model_kwargs2)
model = from_config(config_for_init, **model_kwargs2)

max_memory = get_max_memory()
inferred_device_map = infer_auto_device_map(model, max_memory=max_memory)
Expand All @@ -760,10 +794,13 @@ def has_pack_quantized_config(config):
)
model_kwargs["max_memory"] = max_memory

model_kwargs2 = model_kwargs.copy()
if is_decilm:
model_kwargs2.pop("dtype", None)
Comment on lines +797 to +799

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BB: Can we use the same behavior here as well, to be safe — i.e. `if is_decilm:` then `model_kwargs2["torch_dtype"] = config_dtype` and `model_kwargs2.pop("dtype", None)`? We can create a helper to get the config with the correct dtype.

model = auto_model_module.from_pretrained(
ckpt_path,
device_map=device_map,
**model_kwargs,
**model_kwargs2,
)
model.eval()
if has_pack_quantized_config(hf_config):
Expand Down
122 changes: 122 additions & 0 deletions tests/examples/hf_ptq/test_example_utils.py
Comment thread
realAsma marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
"""

import json
from contextlib import nullcontext
from types import SimpleNamespace
from unittest.mock import patch

import pytest
import torch
from _test_utils.examples.hf_ptq_example_utils import example_utils
from safetensors.torch import save_file
Expand Down Expand Up @@ -194,3 +197,122 @@ def test_get_original_hf_quant_method_none_for_unquantized():
example_utils.get_original_hf_quant_method(SimpleNamespace(quantization_config=None))
is None
)


# ---------- _resolve_init_config ---------------------------------------------


def _remote_config():
# Config whose class module lives under "transformers_modules" (remote code).
cls = type("_RemoteConfig", (), {"__module__": "transformers_modules.ckpt.config"})
return cls()


def test_resolve_init_config_rederives_for_remote_config():
    """A remote-code config is swapped for the one ``AutoConfig.from_pretrained`` returns."""
    rederived_cfg = SimpleNamespace()
    from_pretrained_patch = patch.object(
        example_utils.AutoConfig, "from_pretrained", return_value=rederived_cfg
    )
    with from_pretrained_patch as from_pretrained_mock:
        resolved = example_utils._resolve_init_config(
            _remote_config(), object, "/ckpt", {"trust_remote_code": True}
        )
    assert resolved is rederived_cfg
    # Only the checkpoint path is forwarded: trust_remote_code has been stripped.
    from_pretrained_mock.assert_called_once_with("/ckpt")


def test_resolve_init_config_keeps_non_remote_config():
    """A config whose class is not remote code is returned untouched, with no reload."""
    # SimpleNamespace's module is "types", so it does not count as remote code.
    plain_cfg = SimpleNamespace()
    with patch.object(example_utils.AutoConfig, "from_pretrained") as from_pretrained_mock:
        resolved = example_utils._resolve_init_config(plain_cfg, object, "/ckpt", {})
        assert resolved is plain_cfg
    from_pretrained_mock.assert_not_called()


def test_resolve_init_config_falls_back_when_rederive_raises():
    """When re-deriving the config raises, the remote-code config is kept and a warning is emitted."""
    cfg = _remote_config()
    failing_reload = patch.object(
        example_utils.AutoConfig, "from_pretrained", side_effect=ValueError()
    )
    # The fallback path warns; assert it here so the warning is verified and does not
    # leak into the test session (which would fail under `-W error`).
    with failing_reload, pytest.warns(UserWarning, match="Could not re-derive a built-in config"):
        assert example_utils._resolve_init_config(cfg, object, "/ckpt", {}) is cfg


@pytest.mark.parametrize(
    (
        "architecture",
        "model_class_name",
        "expected_config_dtype_kwarg",
        "unexpected_config_dtype_kwarg",
    ),
    [
        ("DeciLMForCausalLM", "AutoModelForCausalLM", "torch_dtype", "dtype"),
        ("LlamaForCausalLM", "LlamaForCausalLM", "dtype", "torch_dtype"),
    ],
)
def test_get_model_uses_expected_dtype_kwarg(
    monkeypatch,
    architecture,
    model_class_name,
    expected_config_dtype_kwarg,
    unexpected_config_dtype_kwarg,
):
    """Check which dtype keyword ``get_model`` hands to ``from_config`` / ``from_pretrained``.

    Config loading, the model classes and the device-map helpers on ``example_utils``
    are replaced with fakes; the fakes record the kwargs they receive and assert on them.
    """
    recorded = {}
    hf_config = SimpleNamespace(
        architectures=[architecture],
        dtype=torch.float16,
        model_type="llama",
        torch_dtype=torch.bfloat16,
    )

    class FakeModel:
        def eval(self):
            recorded["eval"] = True

    class FakeAutoModelForCausalLM:
        @staticmethod
        def from_config(config, **kwargs):
            recorded["from_config"] = kwargs
            assert config is hf_config
            # The config's `dtype` (float16) must arrive under the expected keyword only.
            assert kwargs[expected_config_dtype_kwarg] is torch.float16
            assert unexpected_config_dtype_kwarg not in kwargs
            assert "max_memory" not in kwargs
            return FakeModel()

        @staticmethod
        def from_pretrained(*args, **kwargs):
            recorded["from_pretrained"] = kwargs
            assert "dtype" not in kwargs
            assert "torch_dtype" not in kwargs
            return FakeModel()

    class FakeLlamaForCausalLM(FakeAutoModelForCausalLM):
        # Concrete model classes are reached through `_from_config` (see get_model).
        _from_config = FakeAutoModelForCausalLM.from_config

        @staticmethod
        def from_pretrained(*args, **kwargs):
            recorded["from_pretrained"] = kwargs
            assert kwargs["dtype"] == "auto"
            assert "torch_dtype" not in kwargs
            return FakeModel()

    monkeypatch.setattr(
        example_utils.AutoConfig,
        "from_pretrained",
        lambda *args, **kwargs: hf_config,
    )
    if model_class_name == "AutoModelForCausalLM":
        # Make sure `transformers` exposes no class named after the architecture;
        # presumably get_model then uses AutoModelForCausalLM — verify against get_model.
        monkeypatch.setattr(example_utils, "AutoModelForCausalLM", FakeAutoModelForCausalLM)
        monkeypatch.delattr(example_utils.transformers, architecture, raising=False)
    else:
        monkeypatch.setattr(example_utils.transformers, model_class_name, FakeLlamaForCausalLM)

    # Neutralize model-type probes and the accelerate helpers used for device-map inference.
    helper_stubs = (
        ("is_nemotron_vl", lambda config: False),
        ("is_speculative", lambda config: False),
        ("init_empty_weights", lambda include_buffers: nullcontext()),
        ("get_max_memory", lambda: {0: 1024}),
        ("infer_auto_device_map", lambda model, max_memory: {"": 0}),
    )
    for helper_name, stub in helper_stubs:
        monkeypatch.setattr(example_utils, helper_name, stub)

    model = example_utils.get_model("checkpoint", device="cpu", trust_remote_code=True)

    assert isinstance(model, FakeModel)
    assert recorded["eval"]
    from_config_kwargs = recorded["from_config"]
    if expected_config_dtype_kwarg == "torch_dtype":
        assert from_config_kwargs["trust_remote_code"] is True
    else:
        assert "trust_remote_code" not in from_config_kwargs
    assert recorded["from_pretrained"]["trust_remote_code"] is True
Loading