diff --git a/examples/hf_ptq/example_utils.py b/examples/hf_ptq/example_utils.py index dbb598038e5..9e8dea5f107 100755 --- a/examples/hf_ptq/example_utils.py +++ b/examples/hf_ptq/example_utils.py @@ -617,6 +617,25 @@ def _resolve_init_config(hf_config, auto_model_module, ckpt_path, config_kwargs) return hf_config +def _get_config_dtype(config): + config_dtype = ( + getattr(config, "dtype", None) or getattr(config, "torch_dtype", None) or torch.bfloat16 + ) + if isinstance(config_dtype, str): + config_dtype = getattr(torch, config_dtype) + return config_dtype + + +def _apply_dtype_to_config(model_kwargs, config_dtype, architecture, apply_config_dtype=False): + model_kwargs = model_kwargs.copy() + if "DeciLM" in architecture: + model_kwargs["torch_dtype"] = config_dtype + model_kwargs.pop("dtype", None) + elif apply_config_dtype: + model_kwargs["dtype"] = config_dtype + return model_kwargs + + def get_model( ckpt_path, device="cuda", @@ -750,7 +769,6 @@ def has_pack_quantized_config(config): auto_model_module = getattr(transformers, architecture) from_config = auto_model_module._from_config - is_decilm = "DeciLM" in architecture config_for_init = _resolve_init_config( hf_config, auto_model_module, ckpt_path, config_kwargs ) @@ -758,21 +776,12 @@ def has_pack_quantized_config(config): with init_empty_weights(include_buffers=True): # When computing the device_map, assuming bfloat16 precision by default, # unless specified by the hf_config. - config_dtype = ( - getattr(config_for_init, "dtype", None) - or getattr(config_for_init, "torch_dtype", None) - or torch.bfloat16 + config_dtype = _get_config_dtype(config_for_init) + model_kwargs2 = _apply_dtype_to_config( + model_kwargs, config_dtype, architecture, apply_config_dtype=True ) - if isinstance(config_dtype, str): - config_dtype = getattr(torch, config_dtype) - model_kwargs2 = model_kwargs.copy() if auto_model_module not in [AutoModelForCausalLM, AutoModel]: model_kwargs2.pop("trust_remote_code", None) - if is_decilm: - model_kwargs2["torch_dtype"] = config_dtype - model_kwargs2.pop("dtype", None) - else: - model_kwargs2["dtype"] = config_dtype model_kwargs2.pop("max_memory", None) model = from_config(config_for_init, **model_kwargs2) @@ -794,9 +803,7 @@ def has_pack_quantized_config(config): ) model_kwargs["max_memory"] = max_memory - model_kwargs2 = model_kwargs.copy() - if is_decilm: - model_kwargs2.pop("dtype", None) + model_kwargs2 = _apply_dtype_to_config(model_kwargs, config_dtype, architecture) model = auto_model_module.from_pretrained( ckpt_path, device_map=device_map, diff --git a/tests/examples/hf_ptq/test_example_utils.py b/tests/examples/hf_ptq/test_example_utils.py index ec3c6a31a5c..00621ec6125 100644 --- a/tests/examples/hf_ptq/test_example_utils.py +++ b/tests/examples/hf_ptq/test_example_utils.py @@ -278,7 +278,7 @@ def from_config(config, **kwargs): def from_pretrained(*args, **kwargs): calls["from_pretrained"] = kwargs assert "dtype" not in kwargs - assert "torch_dtype" not in kwargs + assert kwargs["torch_dtype"] is torch.float16 return FakeModel() class FakeLlamaForCausalLM(FakeAutoModelForCausalLM):