From 00096a99e52b021367f73ef956ad6829e2358f9a Mon Sep 17 00:00:00 2001 From: "Gao, Qun" Date: Wed, 15 Apr 2026 21:22:29 +0000 Subject: [PATCH 1/2] Fix activation scale inf issue for const_weight and const_scale Signed-off-by: Gao, Qun --- .../jax/quantization/layers_static.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/neural_compressor/jax/quantization/layers_static.py b/neural_compressor/jax/quantization/layers_static.py index 01c86bb9bda..29cf3900d66 100644 --- a/neural_compressor/jax/quantization/layers_static.py +++ b/neural_compressor/jax/quantization/layers_static.py @@ -243,6 +243,18 @@ def post_quantization_cleanup(self): None: Cleans up observers and sets quantized call. """ self._tracker.unlock() + if not self._is_quantized: + # Clean up observer only if it exists + if hasattr(self, "input_observer"): + if hasattr(self, "_layers") and self.input_observer in self._layers: + self._layers.remove(self.input_observer) + # Set call to pass-through/original + if hasattr(self, 'call'): + # pass through + pass + self._const_variables = [] + self._tracker.lock() + return if hasattr(self, "_layers") and hasattr(self, "input_observer"): if self.input_observer in self._layers: self._layers.remove(self.input_observer) @@ -447,6 +459,16 @@ def post_quantization_cleanup(self): None: Cleans up observers and original weights. """ self._tracker.unlock() + if not self._is_quantized: + if hasattr(self, "input_observer"): + if hasattr(self, "_layers") and self.input_observer in self._layers: + self._layers.remove(self.input_observer) + # Set call to pass-through/original + if hasattr(self, 'call'): + pass + self._const_variables = [] + self._tracker.lock() + return if hasattr(self, "_kernel") and self._kernel in self._trainable_variables: self._trainable_variables.remove(self._kernel) del self._kernel From 4db8110ed78fd7e22c6a7f86e1d250409efcb210 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Apr 2026 22:03:30 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/jax/quantization/layers_static.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/neural_compressor/jax/quantization/layers_static.py b/neural_compressor/jax/quantization/layers_static.py index 29cf3900d66..91d4e9a647b 100644 --- a/neural_compressor/jax/quantization/layers_static.py +++ b/neural_compressor/jax/quantization/layers_static.py @@ -249,12 +249,12 @@ def post_quantization_cleanup(self): if hasattr(self, "_layers") and self.input_observer in self._layers: self._layers.remove(self.input_observer) # Set call to pass-through/original - if hasattr(self, 'call'): + if hasattr(self, "call"): # pass through pass self._const_variables = [] self._tracker.lock() - return + return if hasattr(self, "_layers") and hasattr(self, "input_observer"): if self.input_observer in self._layers: self._layers.remove(self.input_observer) @@ -464,7 +464,7 @@ def post_quantization_cleanup(self): if hasattr(self, "_layers") and self.input_observer in self._layers: self._layers.remove(self.input_observer) # Set call to pass-through/original - if hasattr(self, 'call'): + if hasattr(self, "call"): pass self._const_variables = [] self._tracker.lock()