Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/maxtext/configs/inference/vllm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ logical_axis_rules: [
['mlp', ['attn_dp', 'model']],
['embed', []],
['norm', []],
['layers', []],
['dense_layers', []],
['moe_layers', []],
# ==========================================
# Inference(Prefill, Decode, Cache)
# ==========================================
Expand Down
67 changes: 58 additions & 9 deletions src/maxtext/inference/vllm_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"""

import os
import types
from typing import Any, Sequence

from absl import app
Expand All @@ -40,6 +41,7 @@

from maxtext.utils import model_creation_utils
from maxtext.utils import max_logging
from maxtext.utils import lora_utils
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
from maxtext.common.common_types import Config
from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter
Expand Down Expand Up @@ -87,6 +89,8 @@ def decode_with_vllm(config: Config) -> None:
"debug_sharding": config.debug_sharding,
"prefuse_moe_weights": config.prefuse_moe_weights,
"scan_layers": config.scan_layers,
"enable_nnx": config.enable_nnx,
"pure_nnx_decoder": config.pure_nnx_decoder,
},
"sharding": {
"sharding_strategy": {
Expand Down Expand Up @@ -178,7 +182,18 @@ def decode_with_tunix(
mesh: The JAX mesh for parallelism.
"""
# Wrap the model for Tunix
tunix_model = TunixMaxTextAdapter(base_model=model)
use_no_op_mappings = False
if hasattr(config, "vllm_hf_overrides") and config.vllm_hf_overrides:
overrides = config.vllm_hf_overrides
if isinstance(overrides, str) and "MaxTextForCausalLM" in overrides:
use_no_op_mappings = True
elif isinstance(overrides, dict) and "MaxTextForCausalLM" in overrides.get("architectures", []):
use_no_op_mappings = True

tunix_model = TunixMaxTextAdapter(
base_model=model,
use_no_op_mappings=use_no_op_mappings,
)

# Load the tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
Expand Down Expand Up @@ -209,13 +224,48 @@ def decode_with_tunix(
f"max_target_length ({config.max_target_length}) must be greater than max_prompt_length ({max_prompt_length})"
)

# MaxText uses -1 to mean "disabled"; vLLM requires top_p in (0, 1].
top_p = config.decode_sampling_nucleus_p if config.decode_sampling_nucleus_p > 0 else 1.0
top_k = config.decode_sampling_top_k if config.decode_sampling_top_k > 0 else -1

rollout_vllm_additional_config = {
"maxtext_config": {
"model_name": config.model_name,
"weight_dtype": "bfloat16",
"allow_split_physical_axes": True,
"debug_sharding": config.debug_sharding,
"prefuse_moe_weights": config.prefuse_moe_weights,
"scan_layers": config.scan_layers,
"enable_nnx": config.enable_nnx,
"pure_nnx_decoder": config.pure_nnx_decoder,
}
}

if config.lora.enable_lora:
rollout_vllm_additional_config["maxtext_config"]["lora"] = {
"enable_lora": config.lora.enable_lora,
"lora_restore_path": config.lora.lora_restore_path,
"lora_rank": config.lora.lora_rank,
"lora_alpha": config.lora.lora_alpha,
"lora_module_path": config.lora.lora_module_path,
}

# Create vLLM rollout for inference
rollout_config = base_rollout.RolloutConfig(
max_tokens_to_generate=max_tokens_to_generate,
max_prompt_length=max_prompt_length,
temperature=config.decode_sampling_temperature,
top_p=config.decode_sampling_nucleus_p,
top_k=config.decode_sampling_top_k,
top_p=top_p,
top_k=top_k,
rollout_vllm_model_version=config.tokenizer_path,
rollout_vllm_hbm_utilization=config.hbm_utilization_vllm,
rollout_vllm_init_with_random_weights=True,
rollout_vllm_tpu_backend_type="jax",
tensor_parallel_size=config.ici_tensor_parallelism if config.ici_tensor_parallelism > 0 else 1,
data_parallel_size=jax.device_count()
// (config.ici_tensor_parallelism if config.ici_tensor_parallelism > 0 else 1),
rollout_vllm_additional_config=rollout_vllm_additional_config,
rollout_vllm_kwargs={"hf_overrides": config.vllm_hf_overrides},
)
vllm_rollout = VllmRollout(
model=tunix_model,
Expand All @@ -225,12 +275,7 @@ def decode_with_tunix(
# other special formatting, which is not part of max_prompt_length.
cache_config_or_size=max_prompt_length + max_tokens_to_generate + 256,
mesh=mesh,
model_version=config.tokenizer_path,
hbm_utilization=0.8,
# Initialize vllm model with random weights to speed up bootstrap time.
# Actual model weights will be loaded later.
init_with_random_weights=True,
tpu_backend_type="jax",
rollout_config=rollout_config,
)

# Generate text
Expand All @@ -251,6 +296,10 @@ def main(argv: Sequence[str]) -> None:

if FLAGS.use_tunix:
maxtext_model, mesh = model_creation_utils.from_pretrained(config)
if config.lora.enable_lora:
maxtext_model = lora_utils.apply_lora_to_model(maxtext_model, mesh, config)
if config.lora.lora_restore_path:
lora_utils.restore_lora_from_path(types.SimpleNamespace(model=maxtext_model), config)
decode_with_tunix(config, model=maxtext_model, mesh=mesh)
else:
decode_with_vllm(config)
Expand Down
43 changes: 40 additions & 3 deletions src/maxtext/integration/tunix/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,47 @@ def to_hf_hook_fns(self):
return {}

def lora_to_hf_mappings(self):
    """Return the MaxText -> HF/vLLM key mapping for LoRA adapter weights.

    When standalone mappings are in use, the model-specific mapping class is
    asked for its LoRA mapping. Otherwise the mapping is derived from the base
    weight mapping returned by ``self.to_hf_mapping()``: every entry whose
    MaxText key contains a known projection segment yields two entries, one
    for the ``_lora_a`` factor and one for the ``_lora_b`` factor.

    Returns:
      A dict ``{maxtext_lora_key: (hf_lora_key, sharding_spec)}``, or ``None``
      when the base mapping is empty or missing. The dict is empty when the
      base mapping contains no projection weights.
    """
    if self.use_standalone_mappings:
        return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].lora_to_hf_mappings()

    # Dynamically generate LoRA mappings from the base model weight mappings.
    base_mappings = self.to_hf_mapping()
    if not base_mappings:
        return None

    # Projection names are matched against whole dot-separated segments, so
    # e.g. "query_norm" does not count as the "query" projection.
    input_proj_names = {
        "wi_0", "wi_1", "query", "key", "value",
        "wq_a", "wq_b", "wkv_a", "wkv_b",
    }
    output_proj_names = {"wo", "out"}
    # Spec used for the LoRA factor on the rank side of a projection.
    rank_side_spec = (None, "layer", None)

    lora_mapping = {}
    for maxtext_key, (hf_key, sharding_spec) in base_mappings.items():
        segments = set(maxtext_key.split("."))
        is_input_proj = not segments.isdisjoint(input_proj_names)
        is_output_proj = not segments.isdisjoint(output_proj_names)
        if not (is_input_proj or is_output_proj):
            continue

        # Only the end of the key is extended. Using str.replace here would
        # also rewrite any earlier ".kernel"/".weight" inside the key.
        lora_a_suffix, lora_b_suffix = "_lora_a", "_lora_b"

        # Sharding specifications for the LoRA parameters (the original
        # comment attributes these to Qwix). Input projections take
        # precedence if a key matches both groups.
        if is_input_proj:
            sharding_a = rank_side_spec  # Input -> Rank (unsharded)
            sharding_b = sharding_spec  # Rank -> Output (same as base)
        else:
            sharding_a = sharding_spec  # Input -> Rank (same as base)
            sharding_b = rank_side_spec  # Rank -> Output (unsharded)

        lora_mapping[maxtext_key + lora_a_suffix] = (hf_key + lora_a_suffix, sharding_a)
        lora_mapping[maxtext_key + lora_b_suffix] = (hf_key + lora_b_suffix, sharding_b)

    return lora_mapping

def _generalize_maxtext_key(self, maxtext_key):
"""Generalizes the MaxText key to a common vLLM format."""
Expand Down
3 changes: 3 additions & 0 deletions src/maxtext/integration/tunix/weight_mapping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
model name. This allows for easy extension to support new models.
"""
from maxtext.integration.tunix.weight_mapping.deepseek3 import DEEPSEEK_VLLM_MAPPING
from maxtext.integration.tunix.weight_mapping.gemma3 import GEMMA3_VLLM_MAPPING
from maxtext.integration.tunix.weight_mapping.gpt_oss import GPT_OSS_VLLM_MAPPING
from maxtext.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
from maxtext.integration.tunix.weight_mapping.qwen2 import QWEN2_VLLM_MAPPING
Expand All @@ -35,6 +36,8 @@ def __getattr__(self, name):
return QWEN2_VLLM_MAPPING
elif name.startswith("qwen3"):
return QWEN3_VLLM_MAPPING
elif name.startswith("gemma3"):
return GEMMA3_VLLM_MAPPING
elif name.startswith("deepseek3"):
return DEEPSEEK_VLLM_MAPPING
elif name.startswith("gpt-oss"):
Expand Down
122 changes: 122 additions & 0 deletions src/maxtext/integration/tunix/weight_mapping/gemma3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines the weight mapping from MaxText's Gemma3 model to a vLLM-compatible format."""

from dataclasses import dataclass

import numpy as np


@dataclass
class GEMMA3_VLLM_MAPPING:
    """Mapping MaxText Gemma3 weights to vLLM's Gemma3 weights.

    The class holds no state; every accessor is a static method that returns a
    freshly built container describing how MaxText parameters are renamed,
    sharded, and post-processed.
    """

    @staticmethod
    def to_hf_hook_fns():
        """Returns a dictionary of hook functions to be applied to MaxText weights.

        Returns:
          A dict mapping a MaxText parameter name to a callable that takes the
          weight array and returns the transformed array.
        """

        def scale_embedding(arr):
            # Divide the embedding table by sqrt(hidden_size), where hidden_size
            # is the second axis. The normalizer is cast to the array's own
            # dtype first so the division does not promote the result.
            hidden_size = arr.shape[1]
            normalizer = np.dtype(arr.dtype).type(hidden_size**0.5)
            return arr / normalizer

        return {
            "base.token_embedder.embedding": scale_embedding,
        }

    @staticmethod
    def to_hf_transpose_keys():
        """Returns the keys of weights that need to be transposed.

        Returns:
          An empty dict: no weights are marked for transposition here.
        """
        return {}

    @staticmethod
    def lora_to_hf_mappings():
        """Provides the mapping for LoRA (Low-Rank Adaptation) weights.

        Returns:
          ``None``; this class defines no LoRA mapping of its own.
        """
        return None

    @staticmethod
    def to_hf_mapping():
        """Mapping from MaxText model to HuggingFace vLLM model.

        Per-layer entries use ``*`` in the target name as the layer wildcard
        and carry a ``"layer"`` entry in their sharding tuple.

        Returns:
          A dictionary mapping MaxText parameter names to a tuple of
          (HuggingFace parameter name, sharding spec).
        """
        return {
            # Token embeddings - first axis sharded over "model"
            "base.token_embedder.embedding": (
                "model.language_model.embed_tokens.kernel",
                ("model", None),
            ),
            # Final layer norm - no sharding needed
            "base.decoder.decoder_norm.scale": (
                "model.language_model.norm.scale",
                (None,),
            ),
            # Per-layer norms - not sharded over "model"
            "base.decoder.layers.pre_self_attention_norm.scale": (
                "model.language_model.layers.*.input_layernorm.scale",
                (None, "layer"),
            ),
            "base.decoder.layers.post_self_attention_norm.scale": (
                "model.language_model.layers.*.post_attention_layernorm.scale",
                (None, "layer"),
            ),
            "base.decoder.layers.self_attention.query_norm.scale": (
                "model.language_model.layers.*.self_attn.q_norm.scale",
                (None, "layer"),
            ),
            "base.decoder.layers.self_attention.key_norm.scale": (
                "model.language_model.layers.*.self_attn.k_norm.scale",
                (None, "layer"),
            ),
            "base.decoder.layers.pre_ffw_norm.scale": (
                "model.language_model.layers.*.pre_feedforward_layernorm.scale",
                (None, "layer"),
            ),
            "base.decoder.layers.post_ffw_norm.scale": (
                "model.language_model.layers.*.post_feedforward_layernorm.scale",
                (None, "layer"),
            ),
            # MLP components - shard hidden dimensions
            "base.decoder.layers.mlp.wi_0.kernel": (
                "model.language_model.layers.*.mlp.gate_proj.kernel",
                (None, "layer", "model"),
            ),
            "base.decoder.layers.mlp.wi_1.kernel": (
                "model.language_model.layers.*.mlp.up_proj.kernel",
                (None, "layer", "model"),
            ),
            "base.decoder.layers.mlp.wo.kernel": (
                "model.language_model.layers.*.mlp.down_proj.kernel",
                ("model", "layer", None),
            ),
            # Attention components - shard head dimensions
            "base.decoder.layers.self_attention.query.kernel": (
                "model.language_model.layers.*.self_attn.q_proj.kernel",
                (None, "layer", "model", None),
            ),
            "base.decoder.layers.self_attention.key.kernel": (
                "model.language_model.layers.*.self_attn.k_proj.kernel",
                (None, "layer", "model", None),
            ),
            "base.decoder.layers.self_attention.value.kernel": (
                "model.language_model.layers.*.self_attn.v_proj.kernel",
                (None, "layer", "model", None),
            ),
            "base.decoder.layers.self_attention.out.kernel": (
                "model.language_model.layers.*.self_attn.o_proj.kernel",
                ("model", "layer", None, None),
            ),
        }
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""vLLM adapter for MaxText models."""

import os
import types
import jax

from flax import nnx
Expand All @@ -28,6 +29,7 @@
from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE
from maxtext.utils import max_logging
from maxtext.utils import model_creation_utils
from maxtext.utils import lora_utils


try:
Expand Down Expand Up @@ -319,8 +321,12 @@ def load_weights(self, rng_key: jax.Array) -> None:
if self.model is not None:
return

with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
with nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
model = model_creation_utils.from_pretrained(
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
)
if self.maxtext_config.lora.enable_lora:
model = lora_utils.apply_lora_to_model(model, self.mesh, self.maxtext_config)
if self.maxtext_config.lora.lora_restore_path:
lora_utils.restore_lora_from_path(types.SimpleNamespace(model=model), self.maxtext_config)
self.model = nnx.data(model)
Loading
Loading