Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 23 additions & 15 deletions examples/hf_ptq/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,25 @@ def _resolve_init_config(hf_config, auto_model_module, ckpt_path, config_kwargs)
return hf_config


def _get_config_dtype(config):
config_dtype = (
getattr(config, "dtype", None) or getattr(config, "torch_dtype", None) or torch.bfloat16
)
if isinstance(config_dtype, str):
config_dtype = getattr(torch, config_dtype)
return config_dtype


def _apply_dtype_to_config(model_kwargs, config_dtype, is_decilm, apply_config_dtype=False):
model_kwargs = model_kwargs.copy()
if is_decilm:
model_kwargs["torch_dtype"] = config_dtype
model_kwargs.pop("dtype", None)
elif apply_config_dtype:
model_kwargs["dtype"] = config_dtype
return model_kwargs


def get_model(
ckpt_path,
device="cuda",
Expand Down Expand Up @@ -758,21 +777,12 @@ def has_pack_quantized_config(config):
with init_empty_weights(include_buffers=True):
# When computing the device_map, assuming bfloat16 precision by default,
# unless specified by the hf_config.
config_dtype = (
getattr(config_for_init, "dtype", None)
or getattr(config_for_init, "torch_dtype", None)
or torch.bfloat16
config_dtype = _get_config_dtype(config_for_init)
model_kwargs2 = _apply_dtype_to_config(
model_kwargs, config_dtype, is_decilm, apply_config_dtype=True
Comment thread
realAsma marked this conversation as resolved.
Outdated
)
if isinstance(config_dtype, str):
config_dtype = getattr(torch, config_dtype)
model_kwargs2 = model_kwargs.copy()
if auto_model_module not in [AutoModelForCausalLM, AutoModel]:
model_kwargs2.pop("trust_remote_code", None)
if is_decilm:
model_kwargs2["torch_dtype"] = config_dtype
model_kwargs2.pop("dtype", None)
else:
model_kwargs2["dtype"] = config_dtype
model_kwargs2.pop("max_memory", None)
model = from_config(config_for_init, **model_kwargs2)

Expand All @@ -794,9 +804,7 @@ def has_pack_quantized_config(config):
)
model_kwargs["max_memory"] = max_memory

model_kwargs2 = model_kwargs.copy()
if is_decilm:
model_kwargs2.pop("dtype", None)
model_kwargs2 = _apply_dtype_to_config(model_kwargs, config_dtype, is_decilm)
model = auto_model_module.from_pretrained(
ckpt_path,
device_map=device_map,
Expand Down
2 changes: 1 addition & 1 deletion tests/examples/hf_ptq/test_example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def from_config(config, **kwargs):
def from_pretrained(*args, **kwargs):
calls["from_pretrained"] = kwargs
assert "dtype" not in kwargs
assert "torch_dtype" not in kwargs
assert kwargs["torch_dtype"] is torch.float16
return FakeModel()

class FakeLlamaForCausalLM(FakeAutoModelForCausalLM):
Expand Down
Loading