Skip to content
20 changes: 17 additions & 3 deletions examples/hf_ptq/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,14 +731,25 @@ def has_pack_quantized_config(config):
auto_model_module = getattr(transformers, architecture)
from_config = auto_model_module._from_config

is_decilm = "DeciLM" in architecture

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to make this a general WAR instead of DiciLM specific?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 Bot comment.

Thanks Wei-Ming. I agree this may be generalizable to other older remote-code models with the same Transformers 5+ incompatibility, but finding and validating those models would be a broader follow-up. For this PR, I would like to keep the fix scoped to the observed Llama Nemotron / DeciLM failure since broader remote-code fallback support is lower value and would need dedicated coverage.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 Bot comment.

Thanks Wei-Ming. I agree this could probably be generalized to older remote-code models with the same constructor mismatch, but that would require identifying and validating the affected model set.

For this RC bug, I would keep the fix scoped to Llama Nemotron / DeciLM because that is the reported failure and the broader remote-code support case is lower value without dedicated coverage. I can follow up separately if we find more models with the same failure.

with init_empty_weights(include_buffers=True):
# When computing the device_map, assuming bfloat16 precision by default,
# unless specified by the hf_config.
torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
config_dtype = (
getattr(hf_config, "dtype", None)
or getattr(hf_config, "torch_dtype", None)
or torch.bfloat16
)
if isinstance(config_dtype, str):
config_dtype = getattr(torch, config_dtype)
model_kwargs2 = model_kwargs.copy()
if auto_model_module not in [AutoModelForCausalLM, AutoModel]:
model_kwargs2.pop("trust_remote_code", None)
model_kwargs2["dtype"] = torch_dtype
if is_decilm:
model_kwargs2["torch_dtype"] = config_dtype
model_kwargs2.pop("dtype", None)
else:
model_kwargs2["dtype"] = config_dtype
model_kwargs2.pop("max_memory", None)
model = from_config(hf_config, **model_kwargs2)

Expand All @@ -760,10 +771,13 @@ def has_pack_quantized_config(config):
)
model_kwargs["max_memory"] = max_memory

model_kwargs2 = model_kwargs.copy()
if is_decilm:
model_kwargs2.pop("dtype", None)
Comment on lines +797 to +799

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BB: can we use the same if is_decilm: model_kwargs2["torch_dtype"] = config_dtype model_kwargs2.pop("dtype", None) behavior here as well to be safe? we can create a helper to get the config with the correct dtype.

model = auto_model_module.from_pretrained(
ckpt_path,
device_map=device_map,
**model_kwargs,
**model_kwargs2,
)
model.eval()
if has_pack_quantized_config(hf_config):
Expand Down
101 changes: 101 additions & 0 deletions tests/examples/hf_ptq/test_example_utils.py
Comment thread
realAsma marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""

import json
from contextlib import nullcontext
from types import SimpleNamespace

import torch
Expand Down Expand Up @@ -194,3 +195,103 @@ def test_get_original_hf_quant_method_none_for_unquantized():
example_utils.get_original_hf_quant_method(SimpleNamespace(quantization_config=None))
is None
)


def test_get_model_uses_torch_dtype_only_for_decilm(monkeypatch):
calls = {}
hf_config = SimpleNamespace(
architectures=["DeciLMForCausalLM"],
dtype=torch.float16,
model_type="llama",
torch_dtype=torch.bfloat16,
)

class FakeModel:
def eval(self):
calls["eval"] = True

class FakeAutoModelForCausalLM:
@staticmethod
def from_config(config, **kwargs):
calls["from_config"] = kwargs
assert config is hf_config
assert "dtype" not in kwargs
assert kwargs["torch_dtype"] is torch.float16
assert "max_memory" not in kwargs
return FakeModel()

@staticmethod
def from_pretrained(*args, **kwargs):
calls["from_pretrained"] = kwargs
assert "dtype" not in kwargs
assert "torch_dtype" not in kwargs
return FakeModel()

monkeypatch.setattr(
example_utils.AutoConfig,
"from_pretrained",
lambda *args, **kwargs: hf_config,
)
monkeypatch.setattr(example_utils, "AutoModelForCausalLM", FakeAutoModelForCausalLM)
monkeypatch.setattr(example_utils, "is_nemotron_vl", lambda config: False)
monkeypatch.setattr(example_utils, "is_speculative", lambda config: False)
monkeypatch.setattr(example_utils, "init_empty_weights", lambda include_buffers: nullcontext())
monkeypatch.setattr(example_utils, "get_max_memory", lambda: {0: 1024})
monkeypatch.setattr(example_utils, "infer_auto_device_map", lambda model, max_memory: {"": 0})

model = example_utils.get_model("checkpoint", device="cpu", trust_remote_code=True)

assert isinstance(model, FakeModel)
assert calls["eval"]
assert calls["from_config"]["trust_remote_code"] is True
assert calls["from_pretrained"]["trust_remote_code"] is True


def test_get_model_uses_dtype_for_non_decilm(monkeypatch):
calls = {}
hf_config = SimpleNamespace(
architectures=["LlamaForCausalLM"],
dtype=torch.float16,
model_type="llama",
torch_dtype=torch.bfloat16,
)

class FakeModel:
def eval(self):
calls["eval"] = True

class FakeLlamaForCausalLM:
Comment thread
realAsma marked this conversation as resolved.
Outdated
@staticmethod
def _from_config(config, **kwargs):
calls["from_config"] = kwargs
assert config is hf_config
assert kwargs["dtype"] is torch.float16
assert "torch_dtype" not in kwargs
assert "max_memory" not in kwargs
return FakeModel()

@staticmethod
def from_pretrained(*args, **kwargs):
calls["from_pretrained"] = kwargs
assert kwargs["dtype"] == "auto"
assert "torch_dtype" not in kwargs
return FakeModel()

monkeypatch.setattr(
example_utils.AutoConfig,
"from_pretrained",
lambda *args, **kwargs: hf_config,
)
monkeypatch.setattr(example_utils.transformers, "LlamaForCausalLM", FakeLlamaForCausalLM)
monkeypatch.setattr(example_utils, "is_nemotron_vl", lambda config: False)
monkeypatch.setattr(example_utils, "is_speculative", lambda config: False)
monkeypatch.setattr(example_utils, "init_empty_weights", lambda include_buffers: nullcontext())
monkeypatch.setattr(example_utils, "get_max_memory", lambda: {0: 1024})
monkeypatch.setattr(example_utils, "infer_auto_device_map", lambda model, max_memory: {"": 0})

model = example_utils.get_model("checkpoint", device="cpu", trust_remote_code=True)

assert isinstance(model, FakeModel)
assert calls["eval"]
assert "trust_remote_code" not in calls["from_config"]
assert calls["from_pretrained"]["trust_remote_code"] is True