diff --git a/src/maxtext/configs/inference/vllm.yml b/src/maxtext/configs/inference/vllm.yml index 3f9e4f5290..7de3228e64 100644 --- a/src/maxtext/configs/inference/vllm.yml +++ b/src/maxtext/configs/inference/vllm.yml @@ -87,6 +87,9 @@ logical_axis_rules: [ ['mlp', ['attn_dp', 'model']], ['embed', []], ['norm', []], + ['layers', []], + ['dense_layers', []], + ['moe_layers', []], # ========================================== # Inference(Prefill, Decode, Cache) # ========================================== diff --git a/src/maxtext/inference/vllm_decode.py b/src/maxtext/inference/vllm_decode.py index c2b1e5e5d2..5a82757e3d 100644 --- a/src/maxtext/inference/vllm_decode.py +++ b/src/maxtext/inference/vllm_decode.py @@ -30,6 +30,7 @@ """ import os +import types from typing import Any, Sequence from absl import app @@ -40,6 +41,7 @@ from maxtext.utils import model_creation_utils from maxtext.utils import max_logging +from maxtext.utils import lora_utils from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR from maxtext.common.common_types import Config from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter @@ -87,6 +89,8 @@ def decode_with_vllm(config: Config) -> None: "debug_sharding": config.debug_sharding, "prefuse_moe_weights": config.prefuse_moe_weights, "scan_layers": config.scan_layers, + "enable_nnx": config.enable_nnx, + "pure_nnx_decoder": config.pure_nnx_decoder, }, "sharding": { "sharding_strategy": { @@ -178,7 +182,18 @@ def decode_with_tunix( mesh: The JAX mesh for parallelism. """ # Wrap the model for Tunix - tunix_model = TunixMaxTextAdapter(base_model=model) + use_no_op_mappings = False + if hasattr(config, "vllm_hf_overrides") and config.vllm_hf_overrides: + overrides = config.vllm_hf_overrides + if isinstance(overrides, str) and "MaxTextForCausalLM" in overrides: + use_no_op_mappings = True + elif isinstance(overrides, dict) and "MaxTextForCausalLM" in overrides.get("architectures", []): + use_no_op_mappings = True + + tunix_model = TunixMaxTextAdapter( + base_model=model, + use_no_op_mappings=use_no_op_mappings, + ) # Load the tokenizer tokenizer = transformers.AutoTokenizer.from_pretrained( @@ -209,13 +224,48 @@ def decode_with_tunix( f"max_target_length ({config.max_target_length}) must be greater than max_prompt_length ({max_prompt_length})" ) + # MaxText uses -1 to mean "disabled"; vLLM requires top_p in (0, 1]. + top_p = config.decode_sampling_nucleus_p if config.decode_sampling_nucleus_p > 0 else 1.0 + top_k = config.decode_sampling_top_k if config.decode_sampling_top_k > 0 else -1 + + rollout_vllm_additional_config = { + "maxtext_config": { + "model_name": config.model_name, + "weight_dtype": "bfloat16", + "allow_split_physical_axes": True, + "debug_sharding": config.debug_sharding, + "prefuse_moe_weights": config.prefuse_moe_weights, + "scan_layers": config.scan_layers, + "enable_nnx": config.enable_nnx, + "pure_nnx_decoder": config.pure_nnx_decoder, + } + } + + if config.lora.enable_lora: + rollout_vllm_additional_config["maxtext_config"]["lora"] = { + "enable_lora": config.lora.enable_lora, + "lora_restore_path": config.lora.lora_restore_path, + "lora_rank": config.lora.lora_rank, + "lora_alpha": config.lora.lora_alpha, + "lora_module_path": config.lora.lora_module_path, + } + # Create vLLM rollout for inference rollout_config = base_rollout.RolloutConfig( max_tokens_to_generate=max_tokens_to_generate, max_prompt_length=max_prompt_length, temperature=config.decode_sampling_temperature, - top_p=config.decode_sampling_nucleus_p, - top_k=config.decode_sampling_top_k, + top_p=top_p, + top_k=top_k, + rollout_vllm_model_version=config.tokenizer_path, + rollout_vllm_hbm_utilization=config.hbm_utilization_vllm, + rollout_vllm_init_with_random_weights=True, + rollout_vllm_tpu_backend_type="jax", + tensor_parallel_size=config.ici_tensor_parallelism if config.ici_tensor_parallelism > 0 else 1, + data_parallel_size=jax.device_count() + // (config.ici_tensor_parallelism if config.ici_tensor_parallelism > 0 else 1), + rollout_vllm_additional_config=rollout_vllm_additional_config, + rollout_vllm_kwargs={"hf_overrides": config.vllm_hf_overrides}, ) vllm_rollout = VllmRollout( model=tunix_model, @@ -225,12 +275,7 @@ def decode_with_tunix( # other special formatting, which is not part of max_prompt_length. cache_config_or_size=max_prompt_length + max_tokens_to_generate + 256, mesh=mesh, - model_version=config.tokenizer_path, - hbm_utilization=0.8, - # Initialize vllm model with random weights to speed up bootstrap time. - # Actual model weights will be loaded later. - init_with_random_weights=True, - tpu_backend_type="jax", + rollout_config=rollout_config, ) # Generate text @@ -251,6 +296,10 @@ def main(argv: Sequence[str]) -> None: if FLAGS.use_tunix: maxtext_model, mesh = model_creation_utils.from_pretrained(config) + if config.lora.enable_lora: + maxtext_model = lora_utils.apply_lora_to_model(maxtext_model, mesh, config) + if config.lora.lora_restore_path: + lora_utils.restore_lora_from_path(types.SimpleNamespace(model=maxtext_model), config) decode_with_tunix(config, model=maxtext_model, mesh=mesh) else: decode_with_vllm(config) diff --git a/src/maxtext/integration/tunix/utils.py b/src/maxtext/integration/tunix/utils.py index 8d608956bc..d934696973 100644 --- a/src/maxtext/integration/tunix/utils.py +++ b/src/maxtext/integration/tunix/utils.py @@ -153,10 +153,47 @@ def to_hf_hook_fns(self): return {} def lora_to_hf_mappings(self): - if self.use_standalone_mappings: - return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].lora_to_hf_mappings() + # Dynamically generate LoRA mappings from base model weights mappings + base_mappings = self.to_hf_mapping() + if not base_mappings: + return None + + lora_mapping = {} + for maxtext_key, (hf_key, sharding_spec) in base_mappings.items(): + segments = set(maxtext_key.split(".")) + is_input_proj = any(p in segments for p in ["wi_0", "wi_1", "query", "key", "value", "wq_a", "wq_b", "wkv_a", "wkv_b"]) + is_output_proj = any(p in segments for p in ["wo", "out"]) + + if not (is_input_proj or is_output_proj): + continue + + # Derive MaxText LoRA keys + maxtext_lora_a = maxtext_key + "_lora_a" + maxtext_lora_b = maxtext_key + "_lora_b" + + # Derive HF/vLLM LoRA keys + if hf_key.endswith(".kernel"): + hf_lora_a = hf_key.replace(".kernel", ".kernel_lora_a") + hf_lora_b = hf_key.replace(".kernel", ".kernel_lora_b") + elif hf_key.endswith(".weight"): + hf_lora_a = hf_key.replace(".weight", ".weight_lora_a") + hf_lora_b = hf_key.replace(".weight", ".weight_lora_b") + else: + hf_lora_a = hf_key + "_lora_a" + hf_lora_b = hf_key + "_lora_b" + + # Derive sharding specifications for Qwix LoRA parameters + if is_input_proj: + sharding_a = (None, "layer", None) # Input -> Rank (unsharded) + sharding_b = sharding_spec # Rank -> Output (same as base) + else: + sharding_a = sharding_spec # Input -> Rank (same as base) + sharding_b = (None, "layer", None) # Rank -> Output (unsharded) + + lora_mapping[maxtext_lora_a] = (hf_lora_a, sharding_a) + lora_mapping[maxtext_lora_b] = (hf_lora_b, sharding_b) - return None + return lora_mapping def _generalize_maxtext_key(self, maxtext_key): """Generalizes the MaxText key to a common vLLM format.""" diff --git a/src/maxtext/integration/tunix/weight_mapping/__init__.py b/src/maxtext/integration/tunix/weight_mapping/__init__.py index 6fd7f35028..39ab12ff8f 100644 --- a/src/maxtext/integration/tunix/weight_mapping/__init__.py +++ b/src/maxtext/integration/tunix/weight_mapping/__init__.py @@ -19,6 +19,7 @@ model name. This allows for easy extension to support new models. """ from maxtext.integration.tunix.weight_mapping.deepseek3 import DEEPSEEK_VLLM_MAPPING +from maxtext.integration.tunix.weight_mapping.gemma3 import GEMMA3_VLLM_MAPPING from maxtext.integration.tunix.weight_mapping.gpt_oss import GPT_OSS_VLLM_MAPPING from maxtext.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING from maxtext.integration.tunix.weight_mapping.qwen2 import QWEN2_VLLM_MAPPING @@ -35,6 +36,8 @@ def __getattr__(self, name): return QWEN2_VLLM_MAPPING elif name.startswith("qwen3"): return QWEN3_VLLM_MAPPING + elif name.startswith("gemma3"): + return GEMMA3_VLLM_MAPPING elif name.startswith("deepseek3"): return DEEPSEEK_VLLM_MAPPING elif name.startswith("gpt-oss"): diff --git a/src/maxtext/integration/tunix/weight_mapping/gemma3.py b/src/maxtext/integration/tunix/weight_mapping/gemma3.py new file mode 100644 index 0000000000..2ac7cb2379 --- /dev/null +++ b/src/maxtext/integration/tunix/weight_mapping/gemma3.py @@ -0,0 +1,122 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines the weight mapping from MaxText's Gemma3 model to a vLLM-compatible format.""" + +from dataclasses import dataclass + +import numpy as np + + +@dataclass +class GEMMA3_VLLM_MAPPING: + """Mapping MaxText Gemma3 weights to vLLM's Gemma3 weights.""" + + @staticmethod + def to_hf_hook_fns(): + """Returns a dictionary of hook functions to be applied to MaxText weights.""" + + def scale_embedding(arr): + hidden_size = arr.shape[1] + normalizer = np.dtype(arr.dtype).type(hidden_size**0.5) + return arr / normalizer + + return { + "base.token_embedder.embedding": scale_embedding, + } + + @staticmethod + def to_hf_transpose_keys(): + """Returns a list of keys for weights that need to be transposed.""" + return {} + + @staticmethod + def lora_to_hf_mappings(): + """Provides the mapping for LoRA (Low-Rank Adaptation) weights.""" + return None + + @staticmethod + def to_hf_mapping(): + """Mapping from MaxText model to HuggingFace vLLM model. + + Returns: + A dictionary mapping MaxText parameter names to HuggingFace parameter names and sharding. + """ + return { + # Token embeddings - shard vocab dimension + "base.token_embedder.embedding": ( + "model.language_model.embed_tokens.kernel", + ("model", None), + ), + # Final layer norm - no sharding needed + "base.decoder.decoder_norm.scale": ( + "model.language_model.norm.scale", + (None,), + ), + # Layer norms - no sharding needed + "base.decoder.layers.pre_self_attention_norm.scale": ( + "model.language_model.layers.*.input_layernorm.scale", + (None, "layer"), + ), + "base.decoder.layers.post_self_attention_norm.scale": ( + "model.language_model.layers.*.post_attention_layernorm.scale", + (None, "layer"), + ), + "base.decoder.layers.self_attention.query_norm.scale": ( + "model.language_model.layers.*.self_attn.q_norm.scale", + (None, "layer"), + ), + "base.decoder.layers.self_attention.key_norm.scale": ( + "model.language_model.layers.*.self_attn.k_norm.scale", + (None, "layer"), + ), + "base.decoder.layers.pre_ffw_norm.scale": ( + "model.language_model.layers.*.pre_feedforward_layernorm.scale", + (None, "layer"), + ), + "base.decoder.layers.post_ffw_norm.scale": ( + "model.language_model.layers.*.post_feedforward_layernorm.scale", + (None, "layer"), + ), + # MLP components - shard hidden dimensions + "base.decoder.layers.mlp.wi_0.kernel": ( + "model.language_model.layers.*.mlp.gate_proj.kernel", + (None, "layer", "model"), + ), + "base.decoder.layers.mlp.wi_1.kernel": ( + "model.language_model.layers.*.mlp.up_proj.kernel", + (None, "layer", "model"), + ), + "base.decoder.layers.mlp.wo.kernel": ( + "model.language_model.layers.*.mlp.down_proj.kernel", + ("model", "layer", None), + ), + # Attention components - shard head dimensions + "base.decoder.layers.self_attention.query.kernel": ( + "model.language_model.layers.*.self_attn.q_proj.kernel", + (None, "layer", "model", None), + ), + "base.decoder.layers.self_attention.key.kernel": ( + "model.language_model.layers.*.self_attn.k_proj.kernel", + (None, "layer", "model", None), + ), + "base.decoder.layers.self_attention.value.kernel": ( + "model.language_model.layers.*.self_attn.v_proj.kernel", + (None, "layer", "model", None), + ), + "base.decoder.layers.self_attention.out.kernel": ( + "model.language_model.layers.*.self_attn.o_proj.kernel", + ("model", "layer", None, None), + ), + } diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py index 217806f887..0bbe027e08 100644 --- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py +++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py @@ -15,6 +15,7 @@ """vLLM adapter for MaxText models.""" import os +import types import jax from flax import nnx @@ -28,6 +29,7 @@ from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE from maxtext.utils import max_logging from maxtext.utils import model_creation_utils +from maxtext.utils import lora_utils try: @@ -319,8 +321,12 @@ def load_weights(self, rng_key: jax.Array) -> None: if self.model is not None: return - with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules): + with nn.logical_axis_rules(self.maxtext_config.logical_axis_rules): model = model_creation_utils.from_pretrained( self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key ) + if self.maxtext_config.lora.enable_lora: + model = lora_utils.apply_lora_to_model(model, self.mesh, self.maxtext_config) + if self.maxtext_config.lora.lora_restore_path: + lora_utils.restore_lora_from_path(types.SimpleNamespace(model=model), self.maxtext_config) self.model = nnx.data(model) diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index de48b60830..54c3b764a2 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -1215,6 +1215,8 @@ def __call__( bidirectional_mask, previous_chunk, slot, + kv_caches=kv_caches, + attention_metadata=attention_metadata, ) elif self.is_gemma4: y = self._apply_gemma4_scanned_blocks( @@ -1226,6 +1228,8 @@ def __call__( bidirectional_mask, previous_chunk, slot, + kv_caches=kv_caches, + attention_metadata=attention_metadata, ) else: scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) @@ -1328,6 +1332,8 @@ def _apply_gemma3_scanned_blocks( bidirectional_mask, previous_chunk, slot, + kv_caches=None, + attention_metadata=None, ): """Applies Gemma3 scanned decoder blocks, handling main scan and remainders.""" @@ -1339,10 +1345,30 @@ def _apply_gemma3_scanned_blocks( layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode) layer_kwargs = {"bidirectional_mask": bidirectional_mask} + if attention_metadata is not None: + layer_kwargs["attention_metadata"] = attention_metadata # Apply the main scan over the full blocks if scan_length > 0: - y, self.layers, _ = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs) + if kv_caches is not None: + grouped_kv_caches = [] + for i in range(scan_length): + start_idx = i * attention_pattern_length + end_idx = start_idx + attention_pattern_length + grouped_kv_caches.append(tuple(kv_caches[start_idx:end_idx])) + + y, self.layers, _ = self._apply_layers_sequentially( + self.layers, y, *layer_args, length=scan_length, kv_caches_stacked=grouped_kv_caches, **layer_kwargs + ) + + for i in range(scan_length): + start_idx = i * attention_pattern_length + for offset, updated_item in enumerate(grouped_kv_caches[i]): + kv_caches[start_idx + offset] = updated_item + else: + y, self.layers, _ = self._apply_layers_sequentially( + self.layers, y, *layer_args, length=scan_length, **layer_kwargs + ) # Apply any remaining layers that did not fit into a full scanned block num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length @@ -1350,17 +1376,38 @@ def _apply_gemma3_scanned_blocks( policy = self.get_remat_policy() prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg) - def pure_gemma_fn(graphdef, state_in, y_in): + remainder_kv = None + if kv_caches is not None: + start_idx = scan_length * attention_pattern_length + remainder_kv = tuple(kv_caches[start_idx : start_idx + num_remaining_layers]) + + def pure_gemma_fn(graphdef, state_in, y_in, kv_in): merged_layer = nnx.merge(graphdef, state_in) - out_y, _ = merged_layer(y_in, *layer_args, previous_chunk=previous_chunk, slot=slot, **layer_kwargs) - return out_y, nnx.state(merged_layer) + call_kwargs = dict(layer_kwargs) + if kv_in is not None: + call_kwargs["kv_cache"] = kv_in + call_kwargs["previous_chunk"] = previous_chunk + call_kwargs["slot"] = slot + out_res = merged_layer(y_in, *layer_args, **call_kwargs) + if isinstance(out_res, tuple): + out_y = out_res[0] + out_kv = out_res[1] if len(out_res) > 1 else None + else: + out_y = out_res + out_kv = None + return out_y, out_kv, nnx.state(merged_layer) checkpointed_gemma_fn = jax.checkpoint(pure_gemma_fn, policy=policy, prevent_cse=prevent_cse) graphdef, state = nnx.split(self.layers_remainder) - y, new_state = checkpointed_gemma_fn(graphdef, state, y) + y, updated_remainder_kv, new_state = checkpointed_gemma_fn(graphdef, state, y, remainder_kv) nnx.update(self.layers_remainder, new_state) + if kv_caches is not None and updated_remainder_kv is not None: + start_idx = scan_length * attention_pattern_length + for offset, updated_item in enumerate(updated_remainder_kv): + kv_caches[start_idx + offset] = updated_item + return y def _apply_gemma4_scanned_blocks( @@ -1373,6 +1420,8 @@ def _apply_gemma4_scanned_blocks( bidirectional_mask, previous_chunk, slot, + kv_caches=None, + attention_metadata=None, ): """Applies Gemma4 scanned decoder blocks, handling main scan and remainders.""" @@ -1384,12 +1433,30 @@ def _apply_gemma4_scanned_blocks( layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode) layer_kwargs = {"bidirectional_mask": bidirectional_mask} + if attention_metadata is not None: + layer_kwargs["attention_metadata"] = attention_metadata # Apply the main scan over the full blocks if scan_length > 0: - y, self.scanned_blocks, _ = self._apply_layers_sequentially( - self.scanned_blocks, y, *layer_args, length=scan_length, **layer_kwargs - ) + if kv_caches is not None: + grouped_kv_caches = [] + for i in range(scan_length): + start_idx = i * attention_pattern_length + end_idx = start_idx + attention_pattern_length + grouped_kv_caches.append(tuple(kv_caches[start_idx:end_idx])) + + y, self.scanned_blocks, _ = self._apply_layers_sequentially( + self.scanned_blocks, y, *layer_args, length=scan_length, kv_caches_stacked=grouped_kv_caches, **layer_kwargs + ) + + for i in range(scan_length): + start_idx = i * attention_pattern_length + for offset, updated_item in enumerate(grouped_kv_caches[i]): + kv_caches[start_idx + offset] = updated_item + else: + y, self.scanned_blocks, _ = self._apply_layers_sequentially( + self.scanned_blocks, y, *layer_args, length=scan_length, **layer_kwargs + ) # Apply any remaining layers that did not fit into a full scanned block num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length @@ -1397,17 +1464,38 @@ def _apply_gemma4_scanned_blocks( policy = self.get_remat_policy() prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg) - def pure_gemma_fn(graphdef, state_in, y_in): + remainder_kv = None + if kv_caches is not None: + start_idx = scan_length * attention_pattern_length + remainder_kv = tuple(kv_caches[start_idx : start_idx + num_remaining_layers]) + + def pure_gemma_fn(graphdef, state_in, y_in, kv_in): merged_layer = nnx.merge(graphdef, state_in) - out_y, _ = merged_layer(y_in, *layer_args, previous_chunk=previous_chunk, slot=slot, **layer_kwargs) - return out_y, nnx.state(merged_layer) + call_kwargs = dict(layer_kwargs) + if kv_in is not None: + call_kwargs["kv_cache"] = kv_in + call_kwargs["previous_chunk"] = previous_chunk + call_kwargs["slot"] = slot + out_res = merged_layer(y_in, *layer_args, **call_kwargs) + if isinstance(out_res, tuple): + out_y = out_res[0] + out_kv = out_res[1] if len(out_res) > 1 else None + else: + out_y = out_res + out_kv = None + return out_y, out_kv, nnx.state(merged_layer) checkpointed_gemma_fn = jax.checkpoint(pure_gemma_fn, policy=policy, prevent_cse=prevent_cse) graphdef, state = nnx.split(self.layers_remainder) - y, new_state = checkpointed_gemma_fn(graphdef, state, y) + y, updated_remainder_kv, new_state = checkpointed_gemma_fn(graphdef, state, y, remainder_kv) nnx.update(self.layers_remainder, new_state) + if kv_caches is not None and updated_remainder_kv is not None: + start_idx = scan_length * attention_pattern_length + for offset, updated_item in enumerate(updated_remainder_kv): + kv_caches[start_idx + offset] = updated_item + return y diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index d29edd6e8e..9aabfef740 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -129,9 +129,7 @@ def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict: if isinstance(v, variablelib.Variable): col_name = variablelib.variable_name_from_type(v.type) v = to_linen_var(v) - else: - raise ValueError(f"Cannot infer collection name from value: {v}") - linen_structured[(col_name, *kp)] = v + linen_structured[(col_name, *kp)] = v variables = nnx.traversals.unflatten_mapping(linen_structured) return variables diff --git a/src/maxtext/models/gemma3.py b/src/maxtext/models/gemma3.py index 92c46b96f7..0e43859ac1 100644 --- a/src/maxtext/models/gemma3.py +++ b/src/maxtext/models/gemma3.py @@ -259,6 +259,8 @@ def update_cache(cache, val): stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache) return (layer_output, stacked_kv_cache, layer_idx + 1), None + elif kv_cache is not None: + return layer_output, kv_cache elif cfg.scan_layers: return layer_output, None else: @@ -324,6 +326,8 @@ def __call__( page_state=None, previous_chunk=None, bidirectional_mask=None, + kv_cache=None, + attention_metadata=None, ): cfg = self.config @@ -331,8 +335,10 @@ def __call__( inputs = checkpoint_name(inputs, "decoder_layer_input") y = inputs + updated_kvs = [] for layer_id in range(self.num_of_layers): - y = getattr(self, f"layers_{layer_id}")( + current_kv = kv_cache[layer_id] if kv_cache is not None else None + y_and_kv = getattr(self, f"layers_{layer_id}")( y, decoder_segment_ids, decoder_positions, @@ -341,10 +347,21 @@ def __call__( previous_chunk=previous_chunk, slot=slot, bidirectional_mask=bidirectional_mask, + kv_cache=current_kv, + attention_metadata=attention_metadata, ) - if cfg.scan_layers: - y = y[0] - if cfg.scan_layers: + if isinstance(y_and_kv, tuple): + y = y_and_kv[0] + new_kv = y_and_kv[1] if len(y_and_kv) > 1 else None + else: + y = y_and_kv + new_kv = None + if kv_cache is not None: + updated_kvs.append(new_kv) + + if kv_cache is not None: + return y, tuple(updated_kvs) + elif cfg.scan_layers: return y, None else: return y diff --git a/src/maxtext/models/gemma4.py b/src/maxtext/models/gemma4.py index 626d2ff54c..5da97b1613 100644 --- a/src/maxtext/models/gemma4.py +++ b/src/maxtext/models/gemma4.py @@ -398,6 +398,8 @@ def update_cache(cache, val): stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache) return (layer_output, stacked_kv_cache, layer_idx + 1), None + elif kv_cache is not None: + return layer_output, kv_cache elif cfg.scan_layers: return layer_output, None else: @@ -464,13 +466,18 @@ def __call__( page_state=None, previous_chunk=None, bidirectional_mask=None, + kv_cache=None, + attention_metadata=None, ): + cfg = self.config inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) inputs = checkpoint_name(inputs, "decoder_layer_input") y = inputs + updated_kvs = [] for layer_id in range(self.num_of_layers): - y, _ = getattr(self, f"layers_{layer_id}")( + current_kv = kv_cache[layer_id] if kv_cache is not None else None + y_and_kv = getattr(self, f"layers_{layer_id}")( y, decoder_segment_ids, decoder_positions, @@ -479,9 +486,24 @@ def __call__( previous_chunk=previous_chunk, slot=slot, bidirectional_mask=bidirectional_mask, + kv_cache=current_kv, + attention_metadata=attention_metadata, ) - - return y, None + if isinstance(y_and_kv, tuple): + y = y_and_kv[0] + new_kv = y_and_kv[1] if len(y_and_kv) > 1 else None + else: + y = y_and_kv + new_kv = None + if kv_cache is not None: + updated_kvs.append(new_kv) + + if kv_cache is not None: + return y, tuple(updated_kvs) + elif cfg.scan_layers: + return y, None + else: + return y Gemma4ScannableBlockToLinen = nnx_wrappers.to_linen_class( diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index 75c7989d9f..8c9c71a7df 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -252,8 +252,8 @@ def setup_trainer_state(mt_config, goodput_recorder=None): # Provide rules context so 'norm' is translated to mesh axes during maybe_restore with nn_partitioning.axis_rules(mt_config.logical_axis_rules): trainer = MaxTextPeftTrainer(model, optimizer, tunix_config) - if mt_config.lora.lora_restore_path: - trainer = lora_utils.restore_lora_from_path(trainer, mt_config) + if mt_config.lora.lora_restore_path and getattr(trainer, "train_steps", 0) == 0: + lora_utils.restore_lora_from_path(trainer.model, mt_config) trainer.with_training_hooks(training_hooks) trainer.with_data_hooks(data_hooks) trainer = use_maxtext_loss_function(trainer, mt_config) diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index ba7d540dae..d8935e6488 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -532,18 +532,14 @@ def apply_lora_to_model( return lora_model -def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> Any: +def restore_lora_from_path(model: Any, mt_config: pyconfig.HyperParameters) -> Any: """Restores LoRA parameter weights from an external Orbax checkpoint for a fresh run.""" - lora_restore_path = mt_config.lora.lora_restore_path + if hasattr(model, "model"): + model = model.model - train_steps = getattr(trainer, "train_steps", 0) - if train_steps > 0: - max_logging.log( - f"PeftTrainer restored current run at step {train_steps}; " f"ignoring lora_restore_path '{lora_restore_path}'." - ) - return trainer + lora_restore_path = mt_config.lora.lora_restore_path - if not is_lora_enabled(trainer.model): + if not is_lora_enabled(model): lora_module_path = _get_lora_module_path(mt_config) if not mt_config.lora.enable_lora: raise ValueError( @@ -551,7 +547,7 @@ def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> f"Set lora.enable_lora=True and verify lora_module_path ('{lora_module_path}') matches model modules." ) - abstract_lora_params = nnx.state(trainer.model, nnx.LoRAParam) + abstract_lora_params = nnx.state(model, nnx.LoRAParam) target_for_restore = jax.tree.map( lambda v: {"value": v.value}, @@ -607,6 +603,6 @@ def _map_to_state(path, variable): is_leaf=lambda n: isinstance(n, nnx.Variable), ) - nnx.update(trainer.model, abstract_lora_params) + nnx.update(model, abstract_lora_params) max_logging.log(f"LoRA restore complete from '{lora_restore_path}'.") - return trainer + return model diff --git a/tests/end_to_end/tpu/gemma3/4b/test_gemma3_lora.sh b/tests/end_to_end/tpu/gemma3/4b/test_gemma3_lora.sh new file mode 100644 index 0000000000..dd82921a57 --- /dev/null +++ b/tests/end_to_end/tpu/gemma3/4b/test_gemma3_lora.sh @@ -0,0 +1,87 @@ +#!/bin/bash + +# Validates the Gemma3-4B LoRA pipeline using a pre-converted MaxText checkpoint. + +# The flow of this script is as follows: +# 1. Run inference on the pre-converted checkpoint. +# 2. Run LoRA starting from the pre-converted checkpoint. +# 3. Run inference on the checkpoint produced by the LoRA run. +# 4. Convert the checkpoint produced by the LoRA run back to HuggingFace format. + +# Usage: +# export HF_TOKEN= +# export RUN_ID=$(date +%Y-%m-%d-%H-%M) +# bash test_gemma3_to_mt.sh $RUN_ID +# bash test_gemma3_lora.sh $RUN_ID + + +set -ex + +source /home/jackyf_google_com/maxtext/.venv/bin/activate +export PYTHONPATH=src:$PYTHONPATH +export JAX_PLATFORMS=tpu,cpu + + +run_id=${1:-$(date +%Y-%m-%d-%H-%M)} +MODEL_NAME='gemma3-4b' + +# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored +BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME} +UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items +SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items + +# Step 1: Install torch +python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu + +# Step 2: Run inference on the original checkpoint converted from Hugging Face +python3 -m maxtext.inference.vllm_decode \ + model_name=${MODEL_NAME} \ + load_parameters_path=${UNSCANNED_CKPT_PATH} \ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ + hbm_utilization_vllm=0.6 \ + prompt="Suggest some famous landmarks in London." \ + use_chat_template=True scan_layers=false + +# Step 3: Run LoRA on the converted checkpoint +python3 -m maxtext.trainers.post_train.sft.train_sft \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/lora \ + load_parameters_path=${SCANNED_CKPT_PATH} \ + per_device_batch_size=1 run_name=${run_id} \ + steps=5 scan_layers=true \ + model_name=${MODEL_NAME} \ + learning_rate=3e-6 \ + lora.enable_lora=True \ + lora.lora_rank=16 \ + lora.lora_alpha=32.0 \ + enable_nnx=True \ + pure_nnx_decoder=True \ + enable_single_controller=True \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False + + +# Step 4: Run inference on the checkpoint generated from the previous run +python3 -m maxtext.inference.vllm_decode \ + --use_tunix=True \ + model_name=${MODEL_NAME} \ + load_parameters_path=${SCANNED_CKPT_PATH} \ + lora.enable_lora=True \ + lora.lora_restore_path=${BASE_OUTPUT_DIRECTORY}/lora/${run_id}/checkpoints/5/model_params \ + lora.lora_rank=16 \ + lora.lora_alpha=32.0 \ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ + hbm_utilization_vllm=0.6 \ + prompt="Suggest some famous landmarks in London." \ + use_chat_template=True \ + enable_nnx=True \ + pure_nnx_decoder=True \ + scan_layers=true + +# Step 5: Convert the checkpoint from MaxText format to Hugging Face format +python3 -m maxtext.checkpoint_conversion.to_huggingface \ + model_name=${MODEL_NAME} \ + load_parameters_path=${SCANNED_CKPT_PATH} \ + lora.lora_restore_path=${BASE_OUTPUT_DIRECTORY}/lora/${run_id}/checkpoints/5/model_params \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \ + scan_layers=true \ + enable_nnx=True \ + pure_nnx_decoder=True diff --git a/tests/integration/setup_train_loop_nnx_test.py b/tests/integration/setup_train_loop_nnx_test.py index fb5cd6f0b6..498d8a9f72 100644 --- a/tests/integration/setup_train_loop_nnx_test.py +++ b/tests/integration/setup_train_loop_nnx_test.py @@ -26,6 +26,7 @@ import sys import unittest + from flax import nnx import jax import jax.numpy as jnp @@ -55,17 +56,13 @@ def _tiny_nnx_pyconfig(**overrides): "max_target_length": 128, "vocab_size": 256, "steps": 1, - "tokenizer_path": os.path.join( - MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2" - ), + "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"), "enable_goodput_recording": False, "enable_checkpoint_cloud_logger": False, "monitor_goodput": False, } init_kwargs.update(overrides) - return pyconfig.initialize( - [sys.argv[0], get_test_config_path()], **init_kwargs - ) + return pyconfig.initialize([sys.argv[0], get_test_config_path()], **init_kwargs) @pytest.mark.integration_test @@ -120,14 +117,10 @@ def test_pure_nnx_setup_param_only_split_matches_model(self): whose structure matches state_mesh_shardings.model after the same split. """ config = _tiny_nnx_pyconfig() - *_, state_mesh_shardings, model, _, _, _, _, _, _, train_state = ( - setup_train_loop(config, recorder=None) - ) + *_, state_mesh_shardings, model, _, _, _, _, _, _, train_state = setup_train_loop(config, recorder=None) _, params, _ = nnx.split(train_state.model, nnx.Param, ...) - _, params_shardings, _ = nnx.split( - state_mesh_shardings.model, nnx.Param, ... - ) + _, params_shardings, _ = nnx.split(state_mesh_shardings.model, nnx.Param, ...) # Same key-set after nnx.split — this is what setup_train_loop relies on at # train_utils.py:281-282 to pair state_params with state_mesh_shardings_params. @@ -141,7 +134,6 @@ def test_pure_nnx_setup_param_only_split_matches_model(self): def test_pure_nnx_dpo_setup_materializes_reference_model(self): """With use_dpo=True the NNX init_state_fn materializes a frozen reference - model alongside the policy (train_utils.py:233-237). Both come from _create_model_partial() with the same init_weights_seed, so absent a step-0 checkpoint the reference starts bit-identical to the policy. @@ -149,8 +141,7 @@ def test_pure_nnx_dpo_setup_materializes_reference_model(self): Positive replacement for the removed test_pure_nnx_dpo_raises_not_implemented: NNX DPO is supported now, so setup_train_loop builds the reference instead - of - raising. + of raising. """ config = _tiny_nnx_pyconfig(use_dpo=True, packing=False) *_, train_state = setup_train_loop(config, recorder=None) @@ -163,9 +154,7 @@ def test_pure_nnx_dpo_setup_materializes_reference_model(self): # Same param tree, identical values at init (same seed, no step-0 override). policy_leaves = jax.tree.leaves(nnx.state(train_state.model, nnx.Param)) - reference_leaves = jax.tree.leaves( - nnx.state(train_state.reference_model, nnx.Param) - ) + reference_leaves = jax.tree.leaves(nnx.state(train_state.reference_model, nnx.Param)) self.assertGreater(len(policy_leaves), 0) self.assertEqual(len(policy_leaves), len(reference_leaves)) for p, r in zip(policy_leaves, reference_leaves): diff --git a/tests/post_training/unit/lora_utils_test.py b/tests/post_training/unit/lora_utils_test.py index 4e616ff2af..e5c67b179e 100644 --- a/tests/post_training/unit/lora_utils_test.py +++ b/tests/post_training/unit/lora_utils_test.py @@ -216,15 +216,11 @@ def test_restore_lora_from_path(self): model, _ = model_creation_utils.from_pretrained(cfg, mesh=None, model_mode=model_creation_utils.MODEL_MODE_TRAIN) model = lora_utils.apply_lora_to_model(model, None, cfg) - trainer = mock.MagicMock() - trainer.model = model - trainer.train_steps = 0 - restored_state = nnx.state(model, nnx.LoRAParam) with mock.patch("orbax.checkpoint.PyTreeCheckpointer.restore", return_value=restored_state) as mock_restore: with mock.patch("flax.nnx.update") as mock_update: - lora_utils.restore_lora_from_path(trainer, cfg) + lora_utils.restore_lora_from_path(model, cfg) mock_restore.assert_called_once() args, kwargs = mock_restore.call_args self.assertEqual(args[0], "some/path") diff --git a/tests/utils/forward_pass_logit_checker.py b/tests/utils/forward_pass_logit_checker.py index 92adb4e921..bdbee338ee 100644 --- a/tests/utils/forward_pass_logit_checker.py +++ b/tests/utils/forward_pass_logit_checker.py @@ -261,8 +261,7 @@ def main(config, test_args): # pylint: disable=W0621 if config.lora.enable_lora: model = lora_utils.apply_lora_to_model(model, mesh, config) if config.lora.lora_restore_path: - mock_trainer = type("MockTrainer", (), {"model": model, "train_steps": 0}) - lora_utils.restore_lora_from_path(mock_trainer, config) + lora_utils.restore_lora_from_path(model, config) state = None else: model = models.transformer_as_linen(config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) @@ -432,7 +431,7 @@ def main(config, test_args): # pylint: disable=W0621 max_logging.log(f"Loading HF model with dtype: {torch_dtype} (derived from config.dtype: {config.dtype})") hf_model = AutoModelForCausalLM.from_pretrained( - test_args.hf_model_path, dtype=torch_dtype, token=hf_token, trust_remote_code=test_args.trust_remote_code + test_args.hf_model_path, torch_dtype=torch_dtype, token=hf_token, trust_remote_code=test_args.trust_remote_code ) hf_lora_path = config.hf_lora_adapter_path if hf_lora_path: @@ -469,8 +468,7 @@ def main(config, test_args): # pylint: disable=W0621 if config.lora.enable_lora: maxtext_model = lora_utils.apply_lora_to_model(maxtext_model, mesh, config) if config.lora.lora_restore_path: - mock_trainer = type("MockTrainer", (), {"model": maxtext_model, "train_steps": 0}) - lora_utils.restore_lora_from_path(mock_trainer, config) + lora_utils.restore_lora_from_path(maxtext_model, config) maxtext_state = None else: maxtext_model = models.transformer_as_linen(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)