[JAX] Add support for per-layer config customization and composable configs #2481
bkowalskiINTEL wants to merge 23 commits into
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Pull request overview
Adds JAX quantization support for (1) per-layer filtering via include/exclude patterns and (2) composing multiple quantization configs, while refactoring quantize flow to centralize model wrapping and improve (de)serialization behavior.
Changes:
- Add `include`/`exclude` layer filters to `StaticQuantConfig` and `DynamicQuantConfig`, and apply them when generating model info.
- Add ComposableConfig support in quantization config JSON serialization/deserialization and in `quantize_model()` mapping construction.
- Refactor JAX static/dynamic algorithms to operate on per-layer config mappings and move wrapper application into `quantize_model()`.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| test/jax/test_config_on_vit.py | Adds a ViT config-filter demonstration (currently executes at import time). |
| neural_compressor/jax/quantization/saving.py | Adds ComposableConfig (de)serialization and updates deserialization preparation to handle multiple sub-configs and include/exclude. |
| neural_compressor/jax/quantization/quantize.py | Adds ComposableConfig mapping merge logic, enforces static-first algorithm application, and centralizes model wrapping. |
| neural_compressor/jax/quantization/config.py | Introduces include/exclude filtering and switches model info to use layer paths (with filtering). |
| neural_compressor/jax/algorithms/static.py | Updates static quantization to prepare only layers selected by configs_mapping and to support per-layer params. |
| neural_compressor/jax/algorithms/dynamic.py | Updates dynamic quantization to prepare only layers selected by configs_mapping and to support per-layer params. |
| examples/jax/keras/gemma/quantization.py | Updates Gemma quantization example (currently contains an early exit()). |
def _matches(pattern: str) -> bool:
    if pattern == class_name:
        return True
    return re.search(pattern, layer_id) is not None
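To make the matching rule in this snippet easier to reason about, here is a minimal standalone sketch of how include/exclude filtering could sit on top of it. The `is_selected` helper, its precedence rule (exclude wins, an unset include selects everything), and the sample layer paths are assumptions for illustration, not code from this PR.

```python
import re
from typing import List, Optional


def layer_matches(pattern: str, class_name: str, layer_id: str) -> bool:
    """Exact match on the class name, otherwise a regex search on the layer path."""
    if pattern == class_name:
        return True
    return re.search(pattern, layer_id) is not None


def is_selected(
    class_name: str,
    layer_id: str,
    include: Optional[List[str]] = None,
    exclude: Optional[List[str]] = None,
) -> bool:
    """Hypothetical semantics: exclude wins; an unset include selects everything."""
    if exclude and any(layer_matches(p, class_name, layer_id) for p in exclude):
        return False
    if include is None:
        return True
    return any(layer_matches(p, class_name, layer_id) for p in include)


# Made-up (class name, layer path) pairs.
layers = [
    ("Dense", "encoder/block_0/mlp/dense_1"),
    ("Dense", "encoder/block_1/mlp/dense_1"),
    ("EinsumDense", "encoder/block_0/attention/query"),
]
selected = [
    path
    for cls, path in layers
    if is_selected(cls, path, include=["Dense"], exclude=[r"block_1"])
]
```

Note that the class-name branch is an exact comparison, so `"Dense"` does not select `EinsumDense` here unless the layer path itself happens to contain the pattern.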
# Build configs_mapping - handle ComposableConfig by calling sub-configs individually
if isinstance(quant_config, ComposableConfig):
    configs_mapping = _build_configs_mapping_composable(model, quant_config)
else:
    model_info = quant_config.get_model_info(model)
    configs_mapping = quant_config.to_config_mapping(model_info=model_info)
anko-intel left a comment
Some comments from offline review
Returns:
    keras.Model: The quantized model wrapped for inference.
    keras.Model: The quantized model.
Previous comment seems to be more accurate
causal_lm_make_replace_generate_function(model)

# Execute algorithms - static first to ensure calibration runs on original FP32 model
algo_order = sorted(algos_mapping.keys(), key=lambda name: (0 if name == STATIC_QUANT else 1))
Can we use a priority or something like that? I am not sure, but I think I saw such a list of algos.
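One way to read this suggestion is a small priority table in place of the inline lambda. The sketch below is only an illustration: `ALGO_PRIORITY`, `DEFAULT_PRIORITY`, `order_algorithms`, and the string names `"static_quant"` and `"dynamic_quant"` are placeholders, and whether the library already ships such a list is not confirmed here.

```python
from typing import Iterable, List

# Hypothetical priority table; lower values run first.
# The string keys are placeholders for the library's algorithm name constants.
ALGO_PRIORITY = {"static_quant": 0, "dynamic_quant": 1}
DEFAULT_PRIORITY = 100  # unknown algorithms run last


def order_algorithms(names: Iterable[str]) -> List[str]:
    """Sort algorithm names by priority; sorted() is stable, so ties keep input order."""
    return sorted(names, key=lambda name: ALGO_PRIORITY.get(name, DEFAULT_PRIORITY))


algo_order = order_algorithms(["dynamic_quant", "static_quant"])
```

Compared with the inline lambda, adding a third algorithm only means adding one entry to the table.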
iterate_over_layers(qmodel, operations, filter_function=lambda c: c in static_quant_mapping)
# Phase 1: Prepare layers and add observers
for layer in qmodel._flatten_layers():
    if layer.__class__ not in static_quant_mapping:
Maybe we can filter whether a class can be quantized at an earlier stage.
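As a rough illustration of filtering earlier, the class check could move into a small generator so the phase loops only ever see quantizable layers. Everything below (the toy layer classes, the toy `static_quant_mapping`, and `quantizable_layers`) is a stand-in written for this sketch, not the PR's code.

```python
from typing import Dict, Iterable, Iterator


# Toy stand-ins for Keras layers and their quantized counterparts.
class Dense: ...
class LayerNorm: ...
class QDense: ...


# Toy version of a mapping from float layer class to quantized layer class.
static_quant_mapping: Dict[type, type] = {Dense: QDense}


def quantizable_layers(layers: Iterable[object], mapping: Dict[type, type]) -> Iterator[object]:
    """Yield only layers whose exact class has a quantized counterpart."""
    for layer in layers:
        if type(layer) in mapping:
            yield layer


model_layers = [Dense(), LayerNorm(), Dense()]
selected = list(quantizable_layers(model_layers, static_quant_mapping))
```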
2c557c7 to
c45031b
Co-authored-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Fixes missing import
if quant_type == "composable":
    sub_configs = [quant_config_from_json_object(cfg) for cfg in json_obj["configs"]]
    result = sub_configs[0]
    for cfg in sub_configs[1:]:
        result = result + cfg
    return result
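The fold over `+` in this snippet can also be written with `functools.reduce`, with an explicit guard for an empty `configs` list (where `sub_configs[0]` would raise `IndexError`). The sketch below uses toy classes, `ToyConfig` and `ToyComposable`, that only mimic the `+` behaviour. They are assumptions for illustration and do not reflect the real `ComposableConfig` API.

```python
import operator
from dataclasses import dataclass, field
from functools import reduce
from typing import List


@dataclass
class ToyConfig:
    """Toy stand-in for a single quantization config."""
    name: str

    def __add__(self, other):
        return ToyComposable([self]) + other


@dataclass
class ToyComposable:
    """Toy stand-in for a composed config that keeps sub-configs in order."""
    configs: List[ToyConfig] = field(default_factory=list)

    def __add__(self, other):
        extra = other.configs if isinstance(other, ToyComposable) else [other]
        return ToyComposable(self.configs + extra)


def compose(sub_configs):
    """Fold sub-configs with +, failing loudly on an empty list."""
    if not sub_configs:
        raise ValueError("a composable config needs at least one sub-config")
    return reduce(operator.add, sub_configs)


single = compose([ToyConfig("static")])
combined = compose([ToyConfig("static"), ToyConfig("dynamic"), ToyConfig("extra")])
```

With a single sub-config the fold returns that config unchanged rather than a composed wrapper, which matches what the loop in the snippet does.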
white_list = self.white_list
if white_list is None:
    white_list = []
elif white_list == DEFAULT_WHITE_LIST:
    from neural_compressor.jax.quantization.layers_dynamic import dynamic_quant_mapping

    white_list = [layer_class.__name__ for layer_class in dynamic_quant_mapping.keys()]
filter_result = []
white_list = self.white_list
if white_list is None:
    white_list = []
elif white_list == DEFAULT_WHITE_LIST:
    from neural_compressor.jax.quantization.layers_static import static_quant_mapping

    white_list = [layer_class.__name__ for layer_class in static_quant_mapping.keys()]
filter_result = []
include (Optional[List[str]]): List of layer class names or path patterns to include.
    When set, only matching layers are quantized. Supports fnmatch patterns.
exclude (Optional[List[str]]): List of layer class names or path patterns to exclude.
    Matching layers are skipped. Supports fnmatch patterns.
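This docstring mentions fnmatch patterns, while the `_matches` helper quoted earlier in the conversation calls `re.search`. The two behave differently on the same pattern, as the short standard-library comparison below shows. The layer path is made up for the example.

```python
import fnmatch
import re

layer_id = "encoder/block_0/mlp/dense_1"  # made-up layer path

# fnmatch patterns must match the whole string, with * as a wildcard.
glob_with_wildcards = fnmatch.fnmatchcase(layer_id, "*block_0*")
glob_without_wildcards = fnmatch.fnmatchcase(layer_id, "block_0")

# re.search matches anywhere in the string, so no wildcards are needed.
regex_substring = re.search("block_0", layer_id) is not None

# A glob-style pattern is not necessarily a valid regular expression:
# a leading * has nothing to repeat and raises re.error.
try:
    re.search("*block_0*", layer_id)
    glob_is_valid_regex = True
except re.error:
    glob_is_valid_regex = False
```

If fnmatch semantics are intended, `fnmatch.fnmatchcase` (or `fnmatch.translate` to obtain an equivalent regex) is one option; if regex semantics are intended, the docstring wording would need to change instead.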
    lambda layer: layer.add_observers(),
]
iterate_over_layers(qmodel, operations, filter_function=lambda c: c in static_quant_mapping)
calib_function(qmodel)
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
No description provided.