Fix HF PTQ empty-init dtype kwargs #1853
Codecov Report
✅ All modified and coverable lines are covered by tests.

@@           Coverage Diff           @@
##             main    #1853   +/-   ##
=======================================
  Coverage   77.40%   77.40%
=======================================
  Files         515      515
  Lines       57118    57118
=======================================
  Hits        44214    44214
  Misses      12904    12904
def _is_unexpected_dtype_kwarg_error(error):
    message = str(error)
    return "unexpected keyword argument" in message and (
        "'dtype'" in message or '"dtype"' in message
    )
No need for this helper method.
# unless specified by the hf_config.
torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
model_kwargs2 = model_kwargs.copy()
torch_dtype = _empty_model_init_dtype(hf_config)
model_kwargs2 = _empty_model_init_kwargs(model_kwargs, torch_dtype)
if auto_model_module not in [AutoModelForCausalLM, AutoModel]:
    model_kwargs2.pop("trust_remote_code", None)
model_kwargs2["dtype"] = torch_dtype
model_kwargs2.pop("max_memory", None)
model = from_config(hf_config, **model_kwargs2)
model = _from_config_for_empty_weights(
    from_config, hf_config, model_kwargs2, torch_dtype
)
This intermixes low-level and high-level code. Why not move all of this into _from_config_for_empty_weights?
6c67f65 to c133b8c
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Superseded by #1857 from branch |
Summary
Fixes NVBug 6359821:
hf_ptq.py can fail during the empty-weight device-map probe for remote/custom architectures like DeciLMForCausalLM because the probe forwards dtype into from_config(), and that kwarg can leak to the custom model constructor.

This change removes dtype-related kwargs from the temporary from_config() call and instead sets PyTorch's default dtype only around the empty-weight construction used for infer_auto_device_map.

NVBug: https://nvbugspro.nvidia.com/bug/6359821
Validation
pre-commit run --files examples/hf_ptq/example_utils.py tests/examples/hf_ptq/test_example_utils.py
pytest_pwd tests/examples/hf_ptq/test_example_utils.py (13 passed)