Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions src/maxtext/checkpoint_conversion/utils/hf_model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1642,6 +1642,62 @@ def __init__(self, **kwargs):
}
qwen3_vl_4b_config = PTConfig(**qwen3_vl_4b_dict)

# Hugging Face-style config values for the Qwen3-VL 2B model
# (architecture "Qwen3VLForConditionalGeneration"). The dict is wrapped in a
# PTConfig right after the literal and exposed as `qwen3_vl_2b_config`, which
# HF_MODEL_CONFIGS registers under the MaxText model name "qwen3-vl-2b".
qwen3_vl_2b_dict = {
"architectures": ["Qwen3VLForConditionalGeneration"],
# Token id of the image placeholder (per key name).
"image_token_id": 151655,
"model_type": "qwen3_vl",
# Text decoder settings (nested model_type "qwen3_vl_text").
"text_config": {
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"dtype": "bfloat16",
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 6144,
"max_position_embeddings": 262144,
"model_type": "qwen3_vl_text",
# 16 attention heads x head_dim 128 = 2048, which equals hidden_size.
"num_attention_heads": 16,
"num_hidden_layers": 28,
# Fewer key/value heads (8) than attention heads (16).
"num_key_value_heads": 8,
"pad_token_id": None,
"rms_norm_eps": 1e-06,
"rope_parameters": {
"mrope_interleaved": True,
# The three entries sum to 64, i.e. head_dim / 2.
# NOTE(review): presumably one entry per multimodal position axis - confirm.
"mrope_section": [24, 20, 20],
"rope_theta": 5000000,
"rope_type": "default",
},
# Set here and again at the top level of this dict; both are True.
"tie_word_embeddings": True,
"use_cache": True,
"vocab_size": 151936,
},
"tie_word_embeddings": True,
"transformers_version": "4.57.0.dev0",
# Token id of the video placeholder (per key name).
"video_token_id": 151656,
# Vision encoder settings (nested model_type "qwen3_vl_vision").
"vision_config": {
# All three indexes are below `depth` (24).
# NOTE(review): presumably zero-based vision layer indexes - confirm.
"deepstack_visual_indexes": [5, 11, 17],
"depth": 24,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1024,
"in_channels": 3,
"initializer_range": 0.02,
"intermediate_size": 4096,
"model_type": "qwen3_vl_vision",
"num_heads": 16,
# 2304 = 48 * 48.
"num_position_embeddings": 2304,
# Same value as text_config["hidden_size"] (2048).
"out_hidden_size": 2048,
"patch_size": 16,
"spatial_merge_size": 2,
"temporal_patch_size": 2,
},
"vision_end_token_id": 151653,
"vision_start_token_id": 151652,
}
# Expand the dict into keyword arguments for PTConfig (defined earlier in this file).
qwen3_vl_2b_config = PTConfig(**qwen3_vl_2b_dict)


# {maxtext model name: hf model config}
HF_MODEL_CONFIGS = {
Expand Down Expand Up @@ -1669,6 +1725,7 @@ def __init__(self, **kwargs):
"qwen3-14b": qwen3_14b_config,
"qwen3-14b-base": qwen3_14b_config,
"qwen3-32b": qwen3_32b_config,
"qwen3-vl-2b": qwen3_vl_2b_config,
"qwen3-vl-4b": qwen3_vl_4b_config,
"llama3.1-8b": llama31_8b_config,
"llama3.1-8b-Instruct": llama31_8b_config,
Expand Down
1 change: 1 addition & 0 deletions src/maxtext/checkpoint_conversion/utils/hf_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,7 @@ def QWEN3_VL_HF_WEIGHTS_TO_SHAPE(config):
"qwen3-8b": QWEN_HF_WEIGHTS_TO_SHAPE,
"qwen3-14b": QWEN_HF_WEIGHTS_TO_SHAPE,
"qwen3-32b": QWEN_HF_WEIGHTS_TO_SHAPE,
"qwen3-vl-2b": QWEN3_VL_HF_WEIGHTS_TO_SHAPE,
"qwen3-vl-4b": QWEN3_VL_HF_WEIGHTS_TO_SHAPE,
"llama3.1-8b": LLAMA31_HF_WEIGHTS_TO_SHAPE,
"llama3.1-8b-Instruct": LLAMA31_HF_WEIGHTS_TO_SHAPE,
Expand Down
2 changes: 2 additions & 0 deletions src/maxtext/checkpoint_conversion/utils/param_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3883,6 +3883,7 @@ def reshape_vision_attn_out(input_tensor, target_shape):
"qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-14b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-vl-2b": QWEN3_VL_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-vl-4b": QWEN3_VL_MAXTEXT_TO_HF_PARAM_MAPPING,
"llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
"llama3.1-8b-Instruct": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
Expand Down Expand Up @@ -3934,6 +3935,7 @@ def reshape_vision_attn_out(input_tensor, target_shape):
"qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-14b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-vl-2b": QWEN3_VL_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-vl-4b": QWEN3_VL_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"llama3.1-8b-Instruct": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
Expand Down
56 changes: 56 additions & 0 deletions src/maxtext/configs/models/qwen3-vl-2b.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Model config for Qwen/Qwen3-VL-2B-Instruct

# Core Architectural Parameters
# The numeric values below match the "text_config" entries of qwen3_vl_2b_dict
# in checkpoint_conversion/utils/hf_model_configs.py (hidden_size 2048,
# intermediate_size 6144, 16 attention heads, 8 key/value heads, 28 layers,
# head_dim 128, vocab_size 151936, rms_norm_eps 1e-06).
decoder_block: "qwen3"
base_emb_dim: 2048
base_mlp_dim: 6144
base_num_query_heads: 16
base_num_kv_heads: 8
base_num_decoder_layers: 28
head_dim: 128
mlp_activations: ["silu", "linear"]
vocab_size: 151936
normalization_layer_epsilon: 1.0e-6
use_qk_norm: true
# NOTE(review): presumably the counterpart of "tie_word_embeddings": True in
# the HF config - confirm.
logits_via_embedding: true
normalize_embedding_logits: false

# RoPE Settings
# Same value as "rope_theta" in the HF text_config rope_parameters.
rope_max_timescale: 5000000

# General Model Settings
enable_dropout: false

# Vision Encoder Configuration
# Based on HuggingFace AutoConfig for Qwen/Qwen3-VL-2B-Instruct
# The sizes below match the "vision_config" entries of qwen3_vl_2b_dict
# (hidden_size 1024, intermediate_size 4096, num_heads 16, depth 24,
# in_channels 3, patch_size 16, temporal_patch_size 2, spatial_merge_size 2,
# out_hidden_size 2048, num_position_embeddings 2304).
use_multimodal: true
# NOTE(review): 768 / patch_size_for_vit (16) = 48 and 48 * 48 = 2304, which
# equals num_position_embeddings_for_vit - presumably intentional; confirm.
image_size_for_vit: 768
hidden_size_for_vit: 1024
intermediate_size_for_vit: 4096
num_attention_heads_for_vit: 16
num_hidden_layers_for_vit: 24
num_channels_for_vit: 3
patch_size_for_vit: 16
temporal_patch_size_for_vit: 2
spatial_merge_size_for_vit: 2
# Equals base_emb_dim (2048).
out_hidden_size_for_vit: 2048
num_position_embeddings_for_vit: 2304
deepstack_visual_indexes_for_vit: [5, 11, 17]

# MRoPE Settings
use_mrope: true
# The three entries sum to 64, i.e. head_dim / 2.
mrope_section: [24, 20, 20]
2 changes: 2 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ class ProfilerType(str, Enum):
"qwen3-30b-a3b",
"qwen3-30b-a3b-base",
"qwen3-480b-a35b",
"qwen3-vl-2b",
"qwen3-vl-4b",
"qwen3-next-80b-a3b",
"qwen3-omni-30b-a3b",
Expand Down Expand Up @@ -3156,6 +3157,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
"llama4-17b-16e",
"llama4-17b-128e",
"qwen3-omni-30b-a3b",
"qwen3-vl-2b",
"qwen3-vl-4b",
"qwen3.5-35b-a3b",
"qwen3.5-397b-a17b",
Expand Down
3 changes: 2 additions & 1 deletion src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,7 @@ def _apply_embedding(
"llama4-17b-16e",
"llama4-17b-128e",
"qwen3-omni-30b-a3b",
"qwen3-vl-2b",
"qwen3-vl-4b",
"qwen3.5-35b-a3b",
"qwen3.5-397b-a17b",
Expand All @@ -743,7 +744,7 @@ def _apply_embedding(
raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}")

if video_embeddings is not None and cfg.use_multimodal:
if cfg.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
if cfg.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
y = mm_utils.merge_mm_embeddings(
text_embeddings=y,
multimodal_embeddings=video_embeddings,
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/layers/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _setup_vision_encoder_layers(self):
)
setattr(self, projector_name, qwen3_5_vision.Qwen3_5MoeVisionProjector(config=self.config, rngs=self.rngs))
return encoder_name, projector_name
elif self.config.model_name in ["qwen3-vl-4b"]:
elif self.config.model_name in ["qwen3-vl-4b", "qwen3-vl-2b"]:
from maxtext.models import qwen3_vl_vision # pylint: disable=import-outside-toplevel

encoder_name = "Qwen3VLVisionEncoder_0"
Expand Down
14 changes: 7 additions & 7 deletions src/maxtext/multimodal/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def preprocess_mm_data(config):

images = [mm_utils.load_image_from_path(p) for p in config.image_path.split(",")]
processor_outputs = preprocess_mm_data_llama4(images)
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni # pylint: disable=import-outside-toplevel

processor_outputs = preprocess_mm_data_qwen3_omni(config)
Expand All @@ -68,7 +68,7 @@ def preprocess_image_for_training(image, config):
from maxtext.multimodal.processor_llama4 import preprocess_mm_data_llama4 # pylint: disable=import-outside-toplevel

return preprocess_mm_data_llama4(image)
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni_for_training # pylint: disable=import-outside-toplevel

return preprocess_mm_data_qwen3_omni_for_training(image, config)
Expand All @@ -90,7 +90,7 @@ def get_image_offsets(config, processor_output: mm_utils.PreprocessorOutput | No
from maxtext.multimodal.processor_llama4 import get_image_offsets_llama4 # pylint: disable=import-outside-toplevel

return get_image_offsets_llama4(processor_output)
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
from maxtext.multimodal.processor_qwen3_omni import get_mm_offsets_qwen3_omni # pylint: disable=import-outside-toplevel

return get_mm_offsets_qwen3_omni(config, processor_output)
Expand All @@ -112,7 +112,7 @@ def reformat_prompt(prompt, image_placeholder, model_name, num_images, video_pla
from maxtext.multimodal.processor_llama4 import reformat_prompt_llama4 # pylint: disable=import-outside-toplevel

return reformat_prompt_llama4(prompt, image_placeholder, num_images)
elif model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
elif model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
from maxtext.multimodal.processor_qwen3_omni import reformat_prompt_qwen3_omni # pylint: disable=import-outside-toplevel

return reformat_prompt_qwen3_omni(
Expand All @@ -137,7 +137,7 @@ def reformat_response(response, model_name):
elif model_name in ["gemma4-26b", "gemma4-31b", "gemma4-e2b", "gemma4-e4b"]:
formatted_response = f"{response}<turn|>"
return formatted_response
elif model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
elif model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
formatted_response = f"{response}<|im_end|>"
return formatted_response
else:
Expand All @@ -158,7 +158,7 @@ def prepare_text_for_image_fusion(tokens, config, processor_output=None):
from maxtext.multimodal.processor_llama4 import add_extra_tokens_for_images_llama4 # pylint: disable=import-outside-toplevel

return add_extra_tokens_for_images_llama4(tokens, processor_output)
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
from maxtext.multimodal.processor_qwen3_omni import add_extra_tokens_for_qwen3_omni # pylint: disable=import-outside-toplevel

return add_extra_tokens_for_qwen3_omni(tokens, config, processor_output)
Expand Down Expand Up @@ -222,7 +222,7 @@ def get_bidirectional_mask_vision(config, decoder_input_tokens, is_video: bool =
from maxtext.multimodal.processor_llama4 import LLAMA4_PATCH_TOKEN # pylint: disable=import-outside-toplevel

bidirectional_mask_vision = decoder_input_tokens == LLAMA4_PATCH_TOKEN
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
from maxtext.multimodal.processor_qwen3_omni import QwenTokens # pylint: disable=import-outside-toplevel

tokens = QwenTokens(config)
Expand Down
1 change: 1 addition & 0 deletions src/maxtext/utils/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
"qwen3-8b": "Qwen/Qwen3-8B",
"qwen3-14b": "Qwen/Qwen3-14B",
"qwen3-32b": "Qwen/Qwen3-32B",
"qwen3-vl-2b": "Qwen/Qwen3-VL-2B-Instruct",
"qwen3-vl-4b": "Qwen/Qwen3-VL-4B-Instruct",
"llama3.1-8b": "meta-llama/Llama-3.1-8B",
"llama3.1-8b-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
Expand Down
Loading