Skip to content

[JAX] Add support for per-layer config custimization and composable configs#2481

Open
bkowalskiINTEL wants to merge 23 commits into
mainfrom
dev/bkowalsk/jax_composable_configs
Open

[JAX] Add support for per-layer config custimization and composable configs#2481
bkowalskiINTEL wants to merge 23 commits into
mainfrom
dev/bkowalsk/jax_composable_configs

Conversation

@bkowalskiINTEL

Copy link
Copy Markdown
Contributor

No description provided.

bkowalskiINTEL and others added 9 commits May 27, 2026 08:08
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Comment thread examples/jax/keras/gemma/quantization.py Outdated

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds JAX quantization support for (1) per-layer filtering via include/exclude patterns and (2) composing multiple quantization configs, while refactoring quantize flow to centralize model wrapping and improve (de)serialization behavior.

Changes:

  • Add include / exclude layer filters to StaticQuantConfig and DynamicQuantConfig, and apply them when generating model info.
  • Add ComposableConfig support in quantization config JSON serialization/deserialization and in quantize_model() mapping construction.
  • Refactor JAX static/dynamic algorithms to operate on per-layer config mappings and move wrapper application into quantize_model().

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
test/jax/test_config_on_vit.py Adds a ViT config-filter demonstration (currently executes at import time).
neural_compressor/jax/quantization/saving.py Adds ComposableConfig (de)serialization and updates deserialization preparation to handle multiple sub-configs and include/exclude.
neural_compressor/jax/quantization/quantize.py Adds ComposableConfig mapping merge logic, enforces static-first algorithm application, and centralizes model wrapping.
neural_compressor/jax/quantization/config.py Introduces include/exclude filtering and switches model info to use layer paths (with filtering).
neural_compressor/jax/algorithms/static.py Updates static quantization to prepare only layers selected by configs_mapping and to support per-layer params.
neural_compressor/jax/algorithms/dynamic.py Updates dynamic quantization to prepare only layers selected by configs_mapping and to support per-layer params.
examples/jax/keras/gemma/quantization.py Updates Gemma quantization example (currently contains an early exit()).

Comment thread test/jax/test_config_on_vit.py Outdated
Comment thread neural_compressor/jax/quantization/quantize.py
Comment on lines +63 to +66
def _matches(pattern: str) -> bool:
if pattern == class_name:
return True
return re.search(pattern, layer_id) is not None
Comment thread neural_compressor/jax/quantization/config.py Outdated
Comment thread neural_compressor/jax/quantization/config.py Outdated
Comment thread neural_compressor/jax/quantization/saving.py
Comment on lines +101 to +106
# Build configs_mapping - handle ComposableConfig by calling sub-configs individually
if isinstance(quant_config, ComposableConfig):
configs_mapping = _build_configs_mapping_composable(model, quant_config)
else:
model_info = quant_config.get_model_info(model)
configs_mapping = quant_config.to_config_mapping(model_info=model_info)
Comment thread examples/jax/keras/gemma/quantization.py

@anko-intel anko-intel left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some comments from offline reviewe


Returns:
keras.Model: The quantized model wrapped for inference.
keras.Model: The quantized model.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previous comment seems to be more accurate

causal_lm_make_replace_generate_function(model)

# Execute algorithms - static first to ensure calibration runs on original FP32 model
algo_order = sorted(algos_mapping.keys(), key=lambda name: (0 if name == STATIC_QUANT else 1))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use priority or something like that. I am not sure, but I think I saw such list of algos

Comment thread neural_compressor/jax/algorithms/static.py Outdated
iterate_over_layers(qmodel, operations, filter_function=lambda c: c in static_quant_mapping)
# Phase 1: Prepare layers and add observers
for layer in qmodel._flatten_layers():
if layer.__class__ not in static_quant_mapping:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can filter if class could be quantied in earlier stage

Comment thread neural_compressor/jax/algorithms/static.py
Comment thread neural_compressor/jax/algorithms/dynamic.py Outdated
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Comment thread neural_compressor/jax/algorithms/static.py Outdated
Comment thread neural_compressor/jax/algorithms/dynamic.py Outdated
Comment thread neural_compressor/jax/quantization/quantize.py Outdated
@bkowalskiINTEL bkowalskiINTEL force-pushed the dev/bkowalsk/jax_composable_configs branch from 2c557c7 to c45031b Compare June 3, 2026 14:03
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Co-authored-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Fixes missing import

Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 9 comments.

Comment on lines +73 to +78
if quant_type == "composable":
sub_configs = [quant_config_from_json_object(cfg) for cfg in json_obj["configs"]]
result = sub_configs[0]
for cfg in sub_configs[1:]:
result = result + cfg
return result
Comment thread neural_compressor/jax/quantization/saving.py
Comment on lines +63 to +66
def _matches(pattern: str) -> bool:
if pattern == class_name:
return True
return re.search(pattern, layer_id) is not None
Comment on lines +216 to 223
white_list = self.white_list
if white_list is None:
white_list = []
elif white_list == DEFAULT_WHITE_LIST:
from neural_compressor.jax.quantization.layers_dynamic import dynamic_quant_mapping

white_list = [layer_class.__name__ for layer_class in dynamic_quant_mapping.keys()]
filter_result = []
Comment on lines +423 to 430
white_list = self.white_list
if white_list is None:
white_list = []
elif white_list == DEFAULT_WHITE_LIST:
from neural_compressor.jax.quantization.layers_static import static_quant_mapping

white_list = [layer_class.__name__ for layer_class in static_quant_mapping.keys()]
filter_result = []
Comment on lines +128 to +131
include (Optional[List[str]]): List of layer class names or path patterns to include.
When set, only matching layers are quantized. Supports fnmatch patterns.
exclude (Optional[List[str]]): List of layer class names or path patterns to exclude.
Matching layers are skipped. Supports fnmatch patterns.
Comment thread neural_compressor/jax/quantization/config.py Outdated
lambda layer: layer.add_observers(),
]
iterate_over_layers(qmodel, operations, filter_function=lambda c: c in static_quant_mapping)
calib_function(qmodel)
Comment on lines +97 to +101
# Build configs_mapping - handle ComposableConfig by calling sub-configs individually
if isinstance(quant_config, ComposableConfig):
configs_mapping = _build_configs_mapping_composable(model, quant_config)
else:
model_info = quant_config.get_model_info(model)
@bkowalskiINTEL bkowalskiINTEL marked this pull request as ready for review June 3, 2026 15:02
bkowalskiINTEL and others added 4 commits June 3, 2026 17:03
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Signed-off-by: Bartosz Kowalski <bartosz.kowalski@intel.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants