From f86bbcf66744943258b6d396240089bccb555b0c Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 1 May 2026 09:52:21 +0000 Subject: [PATCH 1/6] Gemam4 as vision model for studio --- .../tests/test_transformers_version.py | 7 ++ studio/backend/tests/test_vision_cache.py | 60 +++++++++++ studio/backend/utils/models/model_config.py | 99 +++++++++++++++++-- studio/backend/utils/transformers_version.py | 5 +- 4 files changed, 164 insertions(+), 7 deletions(-) diff --git a/studio/backend/tests/test_transformers_version.py b/studio/backend/tests/test_transformers_version.py index c031c2fea3..081c96899e 100644 --- a/studio/backend/tests/test_transformers_version.py +++ b/studio/backend/tests/test_transformers_version.py @@ -221,6 +221,13 @@ def test_gemma4_model_type_only(self, tmp_path: Path): assert _check_config_needs_550(str(tmp_path)) is True + def test_gemma4audio_model_type(self, tmp_path: Path): + """Gemma 4 family model_type values should return True.""" + cfg = {"model_type": "gemma4audio"} + (tmp_path / "config.json").write_text(json.dumps(cfg)) + + assert _check_config_needs_550(str(tmp_path)) is True + def test_llama_architecture(self, tmp_path: Path): """config.json with LlamaForCausalLM should return False.""" cfg = {"architectures": ["LlamaForCausalLM"], "model_type": "llama"} diff --git a/studio/backend/tests/test_vision_cache.py b/studio/backend/tests/test_vision_cache.py index 9e7bbdd1fb..ab5b23dd84 100644 --- a/studio/backend/tests/test_vision_cache.py +++ b/studio/backend/tests/test_vision_cache.py @@ -35,6 +35,7 @@ from utils.models.model_config import ( is_vision_model, _is_vision_model_uncached, + _is_vision_model_subprocess, _vision_detection_cache, ) @@ -117,6 +118,65 @@ def test_subprocess_called_once_with_cache(self, mock_needs_t5, mock_subprocess) mock_subprocess.assert_called_once() assert _vision_detection_cache[("unsloth/Qwen3.5-2B", None)] is True + @patch("utils.models.model_config._load_raw_model_config") + @patch("utils.models.model_config._is_vision_model_subprocess", return_value = None) + @patch("utils.transformers_version.needs_transformers_5", return_value = True) + def test_subprocess_none_falls_back_to_gemma4_raw_config( + self, mock_needs_t5, mock_subprocess, mock_load_raw_config + ): + mock_load_raw_config.return_value = { + "architectures": ["Gemma4AudioForCausalLM"], + "model_type": "gemma4audio", + } + + assert is_vision_model("unsloth/gemma-4-E4B-it") is True + assert is_vision_model("unsloth/gemma-4-E4B-it") is True + + mock_subprocess.assert_called_once() + mock_load_raw_config.assert_called_once_with( + "unsloth/gemma-4-E4B-it", hf_token = None + ) + assert _vision_detection_cache[("unsloth/gemma-4-E4B-it", None)] is True + + @patch("utils.models.model_config._load_raw_model_config", return_value = None) + @patch("utils.models.model_config._is_vision_model_subprocess", return_value = None) + @patch("utils.transformers_version.needs_transformers_5", return_value = True) + def test_raw_config_transient_failure_not_cached( + self, mock_needs_t5, mock_subprocess, mock_load_raw_config + ): + assert is_vision_model("unsloth/gemma-4-E4B-it") is False + assert is_vision_model("unsloth/gemma-4-E4B-it") is False + + assert mock_subprocess.call_count == 2 + assert mock_load_raw_config.call_count == 2 + assert ("unsloth/gemma-4-E4B-it", None) not in _vision_detection_cache + + @patch("utils.models.model_config._load_raw_model_config") + @patch("utils.models.model_config._is_vision_model_subprocess", return_value = None) + @patch("utils.transformers_version.needs_transformers_5", return_value = True) + def test_non_vision_raw_config_false_cached( + self, mock_needs_t5, mock_subprocess, mock_load_raw_config + ): + mock_load_raw_config.return_value = { + "architectures": ["LlamaForCausalLM"], + "model_type": "llama", + } + + assert is_vision_model("org/text-only-needs-t5") is False + assert is_vision_model("org/text-only-needs-t5") is False + + mock_subprocess.assert_called_once() + mock_load_raw_config.assert_called_once() + assert _vision_detection_cache[("org/text-only-needs-t5", None)] is False + + @patch("utils.transformers_version._ensure_venv_t5_550_exists", return_value = False) + @patch("utils.models.model_config.subprocess.run") + def test_missing_t5_550_env_returns_none_without_subprocess( + self, mock_run, mock_ensure_t5 + ): + assert _is_vision_model_subprocess("unsloth/gemma-4-E4B-it") is None + mock_run.assert_not_called() + # --------------------------------------------------------------------------- # Exception handling — cache the False fallback diff --git a/studio/backend/utils/models/model_config.py b/studio/backend/utils/models/model_config.py index a2b0c90e59..0d2a3bebbc 100644 --- a/studio/backend/utils/models/model_config.py +++ b/studio/backend/utils/models/model_config.py @@ -495,13 +495,77 @@ def load_model_config( "internvl_chat", "cogvlm2", "minicpmv", + "gemma4", } +_AUDIO_ONLY_MODEL_TYPES = {"csm", "whisper"} # Pre-computed .venv_t5 paths and backend dir for subprocess version switching. # Vision check uses 5.5.0 (newest, recognizes all architectures). _VENV_T5_DIR = str(Path.home() / ".unsloth" / "studio" / ".venv_t5_550") _BACKEND_DIR = str(Path(__file__).resolve().parent.parent.parent) + +def _is_vlm_model_type(model_type: Optional[str]) -> bool: + return model_type is not None and any( + model_type.startswith(vlm_type) for vlm_type in _VLM_MODEL_TYPES + ) + + +def _load_raw_model_config( + model_name: str, hf_token: Optional[str] = None +) -> Optional[Dict[str, Any]]: + local_config = Path(normalize_path(model_name)).expanduser() / "config.json" + if is_local_path(model_name) and local_config.is_file(): + try: + return json.loads(local_config.read_text()) + except Exception as exc: + logger.warning( + "Could not read raw config.json for '%s': %s", model_name, exc + ) + return None + + try: + from huggingface_hub import hf_hub_download + + config_path = hf_hub_download( + repo_id = model_name, + filename = "config.json", + token = hf_token, + ) + return json.loads(Path(config_path).read_text()) + except Exception as exc: + logger.warning( + "Could not fetch raw config.json for '%s': %s", model_name, exc + ) + return None + + +def _is_vision_model_raw_config( + model_name: str, hf_token: Optional[str] = None +) -> Optional[bool]: + config = _load_raw_model_config(model_name, hf_token = hf_token) + if config is None: + return None + + model_type = config.get("model_type") + if model_type in _AUDIO_ONLY_MODEL_TYPES: + is_vlm = False + else: + architectures = config.get("architectures") or [] + is_vlm = any(x.endswith(_VLM_ARCH_SUFFIXES) for x in architectures) + if not is_vlm: + is_vlm = _is_vlm_model_type(model_type) + + logger.info( + "Vision check (raw config) for '%s': model_type=%s, architectures=%s, " + "is_vision=%s", + model_name, + model_type, + config.get("architectures", []), + is_vlm, + ) + return is_vlm + # Inline script executed in a subprocess with transformers 5.x activated. # Receives model_name and token via argv, prints JSON result to stdout. _VISION_CHECK_SCRIPT = r""" @@ -539,8 +603,10 @@ def load_model_config( is_vlm = True if not is_vlm and hasattr(config, "model_type"): vlm_types = {"phi3_v","llava","llava_next","llava_onevision", - "internvl_chat","cogvlm2","minicpmv"} - if config.model_type in vlm_types: + "internvl_chat","cogvlm2","minicpmv","gemma4"} + if config.model_type and any( + config.model_type.startswith(vlm_type) for vlm_type in vlm_types + ): is_vlm = True model_type = getattr(config, "model_type", "unknown") @@ -570,6 +636,17 @@ def _is_vision_model_subprocess( token_arg = hf_token or "" try: + from utils.transformers_version import _ensure_venv_t5_550_exists + + if not _ensure_venv_t5_550_exists(): + logger.warning( + "Vision check subprocess cannot use transformers 5.5.0 for '%s': " + ".venv_t5_550 missing or incomplete at %s", + model_name, + _VENV_T5_DIR, + ) + return None + result = subprocess.run( [ sys.executable, @@ -717,16 +794,26 @@ def _is_vision_model_uncached( "Model '%s' needs transformers 5.x -- checking vision via subprocess", model_name, ) - return _is_vision_model_subprocess(model_name, hf_token = hf_token) + subprocess_result = _is_vision_model_subprocess( + model_name, hf_token = hf_token + ) + if subprocess_result is not None: + return subprocess_result + + logger.info( + "Vision subprocess for '%s' did not return a definitive result -- " + "falling back to raw config.json", + model_name, + ) + return _is_vision_model_raw_config(model_name, hf_token = hf_token) try: config = load_model_config(model_name, use_auth = True, token = hf_token) # Exclude audio-only models that share ForConditionalGeneration suffix # (e.g. CsmForConditionalGeneration, WhisperForConditionalGeneration) - _audio_only_model_types = {"csm", "whisper"} model_type = getattr(config, "model_type", None) - if model_type in _audio_only_model_types: + if model_type in _AUDIO_ONLY_MODEL_TYPES: return False # Check 1: Architecture class name patterns @@ -755,7 +842,7 @@ def _is_vision_model_uncached( # Check 5: Known VLM model_type values that may not match above checks if hasattr(config, "model_type"): - if config.model_type in _VLM_MODEL_TYPES: + if _is_vlm_model_type(config.model_type): logger.info( f"Model {model_name} detected as VLM: model_type={config.model_type}" ) diff --git a/studio/backend/utils/transformers_version.py b/studio/backend/utils/transformers_version.py index f36bdcd6e8..4531851555 100644 --- a/studio/backend/utils/transformers_version.py +++ b/studio/backend/utils/transformers_version.py @@ -279,7 +279,10 @@ def _check_cfg(cfg: dict) -> bool: archs = cfg.get("architectures", []) if any(a in _TRANSFORMERS_550_ARCHITECTURES for a in archs): return True - if cfg.get("model_type") in _TRANSFORMERS_550_MODEL_TYPES: + model_type = cfg.get("model_type") + if model_type is not None and any( + model_type.startswith(t) for t in _TRANSFORMERS_550_MODEL_TYPES + ): return True return False From d1b23e141fe6d60b24936df720f9c483dc25e33d Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 1 May 2026 10:01:23 +0000 Subject: [PATCH 2/6] cleanup --- studio/backend/tests/test_vision_cache.py | 118 +++----------------- studio/backend/utils/models/model_config.py | 99 +--------------- 2 files changed, 22 insertions(+), 195 deletions(-) diff --git a/studio/backend/tests/test_vision_cache.py b/studio/backend/tests/test_vision_cache.py index ab5b23dd84..d6f0fd3d94 100644 --- a/studio/backend/tests/test_vision_cache.py +++ b/studio/backend/tests/test_vision_cache.py @@ -10,7 +10,6 @@ * Repeated calls for the same model hit the cache (no redundant work). * Different models each trigger their own detection. * Both True and False results are cached. -* The subprocess path (transformers 5.x models) is also cached. * Exceptions that fall back to False are cached. """ @@ -35,7 +34,6 @@ from utils.models.model_config import ( is_vision_model, _is_vision_model_uncached, - _is_vision_model_subprocess, _vision_detection_cache, ) @@ -96,88 +94,6 @@ def test_false_result_cached(self, mock_uncached): assert _vision_detection_cache[("org/text-only", None)] is False -# --------------------------------------------------------------------------- -# Subprocess path (transformers 5.x) caching -# --------------------------------------------------------------------------- - - -class TestVisionCacheSubprocessPath: - """Models needing transformers 5.x go through _is_vision_model_subprocess. - The cache should prevent the subprocess from being spawned more than once - per model per process.""" - - @patch("utils.models.model_config._is_vision_model_subprocess", return_value = True) - @patch("utils.transformers_version.needs_transformers_5", return_value = True) - def test_subprocess_called_once_with_cache(self, mock_needs_t5, mock_subprocess): - """Subprocess should only fire on the first call; second is cached.""" - # First call: goes through uncached → subprocess - assert is_vision_model("unsloth/Qwen3.5-2B") is True - # Second call: cache hit, no subprocess - assert is_vision_model("unsloth/Qwen3.5-2B") is True - - mock_subprocess.assert_called_once() - assert _vision_detection_cache[("unsloth/Qwen3.5-2B", None)] is True - - @patch("utils.models.model_config._load_raw_model_config") - @patch("utils.models.model_config._is_vision_model_subprocess", return_value = None) - @patch("utils.transformers_version.needs_transformers_5", return_value = True) - def test_subprocess_none_falls_back_to_gemma4_raw_config( - self, mock_needs_t5, mock_subprocess, mock_load_raw_config - ): - mock_load_raw_config.return_value = { - "architectures": ["Gemma4AudioForCausalLM"], - "model_type": "gemma4audio", - } - - assert is_vision_model("unsloth/gemma-4-E4B-it") is True - assert is_vision_model("unsloth/gemma-4-E4B-it") is True - - mock_subprocess.assert_called_once() - mock_load_raw_config.assert_called_once_with( - "unsloth/gemma-4-E4B-it", hf_token = None - ) - assert _vision_detection_cache[("unsloth/gemma-4-E4B-it", None)] is True - - @patch("utils.models.model_config._load_raw_model_config", return_value = None) - @patch("utils.models.model_config._is_vision_model_subprocess", return_value = None) - @patch("utils.transformers_version.needs_transformers_5", return_value = True) - def test_raw_config_transient_failure_not_cached( - self, mock_needs_t5, mock_subprocess, mock_load_raw_config - ): - assert is_vision_model("unsloth/gemma-4-E4B-it") is False - assert is_vision_model("unsloth/gemma-4-E4B-it") is False - - assert mock_subprocess.call_count == 2 - assert mock_load_raw_config.call_count == 2 - assert ("unsloth/gemma-4-E4B-it", None) not in _vision_detection_cache - - @patch("utils.models.model_config._load_raw_model_config") - @patch("utils.models.model_config._is_vision_model_subprocess", return_value = None) - @patch("utils.transformers_version.needs_transformers_5", return_value = True) - def test_non_vision_raw_config_false_cached( - self, mock_needs_t5, mock_subprocess, mock_load_raw_config - ): - mock_load_raw_config.return_value = { - "architectures": ["LlamaForCausalLM"], - "model_type": "llama", - } - - assert is_vision_model("org/text-only-needs-t5") is False - assert is_vision_model("org/text-only-needs-t5") is False - - mock_subprocess.assert_called_once() - mock_load_raw_config.assert_called_once() - assert _vision_detection_cache[("org/text-only-needs-t5", None)] is False - - @patch("utils.transformers_version._ensure_venv_t5_550_exists", return_value = False) - @patch("utils.models.model_config.subprocess.run") - def test_missing_t5_550_env_returns_none_without_subprocess( - self, mock_run, mock_ensure_t5 - ): - assert _is_vision_model_subprocess("unsloth/gemma-4-E4B-it") is None - mock_run.assert_not_called() - - # --------------------------------------------------------------------------- # Exception handling — cache the False fallback # --------------------------------------------------------------------------- @@ -193,8 +109,7 @@ class TestVisionCacheOnException: "utils.models.model_config.load_model_config", side_effect = ValueError("bad config"), ) - @patch("utils.transformers_version.needs_transformers_5", return_value = False) - def test_permanent_exception_result_cached(self, mock_needs_t5, mock_load_config): + def test_permanent_exception_result_cached(self, mock_load_config): """A permanent failure (ValueError / RepositoryNotFoundError / GatedRepoError / JSONDecodeError) should be caught, return False, and that False should be cached so subsequent calls don't retry. @@ -213,8 +128,7 @@ def test_permanent_exception_result_cached(self, mock_needs_t5, mock_load_config "utils.models.model_config.load_model_config", side_effect = OSError("network down"), ) - @patch("utils.transformers_version.needs_transformers_5", return_value = False) - def test_transient_exception_not_cached(self, mock_needs_t5, mock_load_config): + def test_transient_exception_not_cached(self, mock_load_config): """A transient failure (OSError, timeouts) should return None from _is_vision_model_uncached, surface as False to the caller, and NOT be cached, so the next call retries detection. This matches @@ -236,12 +150,10 @@ def test_transient_exception_not_cached(self, mock_needs_t5, mock_load_config): class TestVisionCacheDirectPath: - """For models that do NOT need transformers 5.x, the detection goes through - load_model_config directly. The cache must work the same way.""" + """Direct detection through load_model_config should cache results.""" - @patch("utils.transformers_version.needs_transformers_5", return_value = False) @patch("utils.models.model_config.load_model_config") - def test_direct_vlm_detection_cached(self, mock_load_config, mock_needs_t5): + def test_direct_vlm_detection_cached(self, mock_load_config): """A standard VLM detected via architecture suffix should be cached.""" cfg = MagicMock(spec = []) # strict: only explicitly set attrs exist cfg.model_type = "gemma3" @@ -253,9 +165,8 @@ def test_direct_vlm_detection_cached(self, mock_load_config, mock_needs_t5): # load_model_config should only be called once mock_load_config.assert_called_once() - @patch("utils.transformers_version.needs_transformers_5", return_value = False) @patch("utils.models.model_config.load_model_config") - def test_direct_non_vlm_detection_cached(self, mock_load_config, mock_needs_t5): + def test_direct_non_vlm_detection_cached(self, mock_load_config): """A standard text model (no VLM indicators) should cache False.""" cfg = MagicMock(spec = []) # spec=[] means no attributes at all cfg.model_type = "llama" @@ -267,11 +178,8 @@ def test_direct_non_vlm_detection_cached(self, mock_load_config, mock_needs_t5): assert is_vision_model("meta-llama/Llama-3-8B") is False mock_load_config.assert_called_once() - @patch("utils.transformers_version.needs_transformers_5", return_value = False) @patch("utils.models.model_config.load_model_config") - def test_vision_config_attr_detected_and_cached( - self, mock_load_config, mock_needs_t5 - ): + def test_vision_config_attr_detected_and_cached(self, mock_load_config): """Models with vision_config (LLaVA, Qwen2-VL, etc.) should be cached as True.""" cfg = MagicMock(spec = []) # strict: only explicitly set attrs exist cfg.model_type = "qwen2_vl" @@ -283,9 +191,19 @@ def test_vision_config_attr_detected_and_cached( assert is_vision_model("Qwen/Qwen2-VL-7B") is True mock_load_config.assert_called_once() - @patch("utils.transformers_version.needs_transformers_5", return_value = False) @patch("utils.models.model_config.load_model_config") - def test_audio_model_excluded_and_cached(self, mock_load_config, mock_needs_t5): + def test_model_type_prefix_detected_and_cached(self, mock_load_config): + cfg = MagicMock(spec = []) + cfg.model_type = "gemma4audio" + cfg.architectures = ["Gemma4AudioForCausalLM"] + mock_load_config.return_value = cfg + + assert is_vision_model("google/gemma-4-audio") is True + assert is_vision_model("google/gemma-4-audio") is True + mock_load_config.assert_called_once() + + @patch("utils.models.model_config.load_model_config") + def test_audio_model_excluded_and_cached(self, mock_load_config): """Audio-only models (csm, whisper) with ForConditionalGeneration should be excluded from VLM detection and cached as False.""" cfg = MagicMock(spec = []) # strict: only explicitly set attrs exist diff --git a/studio/backend/utils/models/model_config.py b/studio/backend/utils/models/model_config.py index 0d2a3bebbc..a396a0172c 100644 --- a/studio/backend/utils/models/model_config.py +++ b/studio/backend/utils/models/model_config.py @@ -497,7 +497,6 @@ def load_model_config( "minicpmv", "gemma4", } -_AUDIO_ONLY_MODEL_TYPES = {"csm", "whisper"} # Pre-computed .venv_t5 paths and backend dir for subprocess version switching. # Vision check uses 5.5.0 (newest, recognizes all architectures). @@ -511,61 +510,6 @@ def _is_vlm_model_type(model_type: Optional[str]) -> bool: ) -def _load_raw_model_config( - model_name: str, hf_token: Optional[str] = None -) -> Optional[Dict[str, Any]]: - local_config = Path(normalize_path(model_name)).expanduser() / "config.json" - if is_local_path(model_name) and local_config.is_file(): - try: - return json.loads(local_config.read_text()) - except Exception as exc: - logger.warning( - "Could not read raw config.json for '%s': %s", model_name, exc - ) - return None - - try: - from huggingface_hub import hf_hub_download - - config_path = hf_hub_download( - repo_id = model_name, - filename = "config.json", - token = hf_token, - ) - return json.loads(Path(config_path).read_text()) - except Exception as exc: - logger.warning( - "Could not fetch raw config.json for '%s': %s", model_name, exc - ) - return None - - -def _is_vision_model_raw_config( - model_name: str, hf_token: Optional[str] = None -) -> Optional[bool]: - config = _load_raw_model_config(model_name, hf_token = hf_token) - if config is None: - return None - - model_type = config.get("model_type") - if model_type in _AUDIO_ONLY_MODEL_TYPES: - is_vlm = False - else: - architectures = config.get("architectures") or [] - is_vlm = any(x.endswith(_VLM_ARCH_SUFFIXES) for x in architectures) - if not is_vlm: - is_vlm = _is_vlm_model_type(model_type) - - logger.info( - "Vision check (raw config) for '%s': model_type=%s, architectures=%s, " - "is_vision=%s", - model_name, - model_type, - config.get("architectures", []), - is_vlm, - ) - return is_vlm - # Inline script executed in a subprocess with transformers 5.x activated. # Receives model_name and token via argv, prints JSON result to stdout. _VISION_CHECK_SCRIPT = r""" @@ -603,10 +547,8 @@ def _is_vision_model_raw_config( is_vlm = True if not is_vlm and hasattr(config, "model_type"): vlm_types = {"phi3_v","llava","llava_next","llava_onevision", - "internvl_chat","cogvlm2","minicpmv","gemma4"} - if config.model_type and any( - config.model_type.startswith(vlm_type) for vlm_type in vlm_types - ): + "internvl_chat","cogvlm2","minicpmv"} + if config.model_type in vlm_types: is_vlm = True model_type = getattr(config, "model_type", "unknown") @@ -636,17 +578,6 @@ def _is_vision_model_subprocess( token_arg = hf_token or "" try: - from utils.transformers_version import _ensure_venv_t5_550_exists - - if not _ensure_venv_t5_550_exists(): - logger.warning( - "Vision check subprocess cannot use transformers 5.5.0 for '%s': " - ".venv_t5_550 missing or incomplete at %s", - model_name, - _VENV_T5_DIR, - ) - return None - result = subprocess.run( [ sys.executable, @@ -784,36 +715,14 @@ def _is_vision_model_uncached( Do not call directly; use is_vision_model() instead. """ - # Models that need transformers 5.x must be checked in a subprocess - # because AutoConfig in the main process (transformers 4.57.x) doesn't - # recognize their architectures. - from utils.transformers_version import needs_transformers_5 - - if needs_transformers_5(model_name): - logger.info( - "Model '%s' needs transformers 5.x -- checking vision via subprocess", - model_name, - ) - subprocess_result = _is_vision_model_subprocess( - model_name, hf_token = hf_token - ) - if subprocess_result is not None: - return subprocess_result - - logger.info( - "Vision subprocess for '%s' did not return a definitive result -- " - "falling back to raw config.json", - model_name, - ) - return _is_vision_model_raw_config(model_name, hf_token = hf_token) - try: config = load_model_config(model_name, use_auth = True, token = hf_token) # Exclude audio-only models that share ForConditionalGeneration suffix # (e.g. CsmForConditionalGeneration, WhisperForConditionalGeneration) + _audio_only_model_types = {"csm", "whisper"} model_type = getattr(config, "model_type", None) - if model_type in _AUDIO_ONLY_MODEL_TYPES: + if model_type in _audio_only_model_types: return False # Check 1: Architecture class name patterns From 71ffe2bc656cce29e759a9e534149e7ff3740834 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 1 May 2026 10:17:26 +0000 Subject: [PATCH 3/6] Simplify and add fallback onto config --- .../tests/test_transformers_version.py | 7 -- studio/backend/tests/test_vision_cache.py | 66 ++++++++++++-- studio/backend/utils/models/model_config.py | 91 +++++++++++++------ studio/backend/utils/transformers_version.py | 5 +- 4 files changed, 121 insertions(+), 48 deletions(-) diff --git a/studio/backend/tests/test_transformers_version.py b/studio/backend/tests/test_transformers_version.py index 081c96899e..c031c2fea3 100644 --- a/studio/backend/tests/test_transformers_version.py +++ b/studio/backend/tests/test_transformers_version.py @@ -221,13 +221,6 @@ def test_gemma4_model_type_only(self, tmp_path: Path): assert _check_config_needs_550(str(tmp_path)) is True - def test_gemma4audio_model_type(self, tmp_path: Path): - """Gemma 4 family model_type values should return True.""" - cfg = {"model_type": "gemma4audio"} - (tmp_path / "config.json").write_text(json.dumps(cfg)) - - assert _check_config_needs_550(str(tmp_path)) is True - def test_llama_architecture(self, tmp_path: Path): """config.json with LlamaForCausalLM should return False.""" cfg = {"architectures": ["LlamaForCausalLM"], "model_type": "llama"} diff --git a/studio/backend/tests/test_vision_cache.py b/studio/backend/tests/test_vision_cache.py index d6f0fd3d94..4a71eafa91 100644 --- a/studio/backend/tests/test_vision_cache.py +++ b/studio/backend/tests/test_vision_cache.py @@ -10,6 +10,7 @@ * Repeated calls for the same model hit the cache (no redundant work). * Different models each trigger their own detection. * Both True and False results are cached. +* The subprocess path (transformers 5.x models) is also cached. * Exceptions that fall back to False are cached. """ @@ -94,6 +95,43 @@ def test_false_result_cached(self, mock_uncached): assert _vision_detection_cache[("org/text-only", None)] is False +# --------------------------------------------------------------------------- +# Subprocess path (transformers 5.x) caching +# --------------------------------------------------------------------------- + + +class TestVisionCacheSubprocessPath: + """Models needing transformers 5.x go through _is_vision_model_subprocess. + The cache should prevent the subprocess from being spawned more than once + per model per process.""" + + @patch("utils.models.model_config._is_vision_model_subprocess", return_value = True) + @patch("utils.transformers_version.needs_transformers_5", return_value = True) + def test_subprocess_called_once_with_cache(self, mock_needs_t5, mock_subprocess): + """Subprocess should only fire on the first call; second is cached.""" + # First call: goes through uncached → subprocess + assert is_vision_model("unsloth/Qwen3.5-2B") is True + # Second call: cache hit, no subprocess + assert is_vision_model("unsloth/Qwen3.5-2B") is True + + mock_subprocess.assert_called_once() + assert _vision_detection_cache[("unsloth/Qwen3.5-2B", None)] is True + + @patch("utils.models.model_config._raw_config_has_vision_config", return_value = True) + @patch("utils.models.model_config._is_vision_model_subprocess", return_value = None) + @patch("utils.transformers_version.needs_transformers_5", return_value = True) + def test_subprocess_none_falls_back_to_raw_vision_config( + self, mock_needs_t5, mock_subprocess, mock_raw_config + ): + assert is_vision_model("unsloth/gemma-4-E4B-it") is True + assert is_vision_model("unsloth/gemma-4-E4B-it") is True + + mock_subprocess.assert_called_once() + mock_raw_config.assert_called_once_with( + "unsloth/gemma-4-E4B-it", hf_token = None + ) + + # --------------------------------------------------------------------------- # Exception handling — cache the False fallback # --------------------------------------------------------------------------- @@ -109,7 +147,8 @@ class TestVisionCacheOnException: "utils.models.model_config.load_model_config", side_effect = ValueError("bad config"), ) - def test_permanent_exception_result_cached(self, mock_load_config): + @patch("utils.transformers_version.needs_transformers_5", return_value = False) + def test_permanent_exception_result_cached(self, mock_needs_t5, mock_load_config): """A permanent failure (ValueError / RepositoryNotFoundError / GatedRepoError / JSONDecodeError) should be caught, return False, and that False should be cached so subsequent calls don't retry. @@ -128,7 +167,8 @@ def test_permanent_exception_result_cached(self, mock_load_config): "utils.models.model_config.load_model_config", side_effect = OSError("network down"), ) - def test_transient_exception_not_cached(self, mock_load_config): + @patch("utils.transformers_version.needs_transformers_5", return_value = False) + def test_transient_exception_not_cached(self, mock_needs_t5, mock_load_config): """A transient failure (OSError, timeouts) should return None from _is_vision_model_uncached, surface as False to the caller, and NOT be cached, so the next call retries detection. This matches @@ -150,10 +190,12 @@ def test_transient_exception_not_cached(self, mock_load_config): class TestVisionCacheDirectPath: - """Direct detection through load_model_config should cache results.""" + """For models that do NOT need transformers 5.x, the detection goes through + load_model_config directly. The cache must work the same way.""" + @patch("utils.transformers_version.needs_transformers_5", return_value = False) @patch("utils.models.model_config.load_model_config") - def test_direct_vlm_detection_cached(self, mock_load_config): + def test_direct_vlm_detection_cached(self, mock_load_config, mock_needs_t5): """A standard VLM detected via architecture suffix should be cached.""" cfg = MagicMock(spec = []) # strict: only explicitly set attrs exist cfg.model_type = "gemma3" @@ -165,8 +207,9 @@ def test_direct_vlm_detection_cached(self, mock_load_config): # load_model_config should only be called once mock_load_config.assert_called_once() + @patch("utils.transformers_version.needs_transformers_5", return_value = False) @patch("utils.models.model_config.load_model_config") - def test_direct_non_vlm_detection_cached(self, mock_load_config): + def test_direct_non_vlm_detection_cached(self, mock_load_config, mock_needs_t5): """A standard text model (no VLM indicators) should cache False.""" cfg = MagicMock(spec = []) # spec=[] means no attributes at all cfg.model_type = "llama" @@ -178,8 +221,11 @@ def test_direct_non_vlm_detection_cached(self, mock_load_config): assert is_vision_model("meta-llama/Llama-3-8B") is False mock_load_config.assert_called_once() + @patch("utils.transformers_version.needs_transformers_5", return_value = False) @patch("utils.models.model_config.load_model_config") - def test_vision_config_attr_detected_and_cached(self, mock_load_config): + def test_vision_config_attr_detected_and_cached( + self, mock_load_config, mock_needs_t5 + ): """Models with vision_config (LLaVA, Qwen2-VL, etc.) should be cached as True.""" cfg = MagicMock(spec = []) # strict: only explicitly set attrs exist cfg.model_type = "qwen2_vl" @@ -191,8 +237,11 @@ def test_vision_config_attr_detected_and_cached(self, mock_load_config): assert is_vision_model("Qwen/Qwen2-VL-7B") is True mock_load_config.assert_called_once() + @patch("utils.transformers_version.needs_transformers_5", return_value = False) @patch("utils.models.model_config.load_model_config") - def test_model_type_prefix_detected_and_cached(self, mock_load_config): + def test_model_type_prefix_detected_and_cached( + self, mock_load_config, mock_needs_t5 + ): cfg = MagicMock(spec = []) cfg.model_type = "gemma4audio" cfg.architectures = ["Gemma4AudioForCausalLM"] @@ -202,8 +251,9 @@ def test_model_type_prefix_detected_and_cached(self, mock_load_config): assert is_vision_model("google/gemma-4-audio") is True mock_load_config.assert_called_once() + @patch("utils.transformers_version.needs_transformers_5", return_value = False) @patch("utils.models.model_config.load_model_config") - def test_audio_model_excluded_and_cached(self, mock_load_config): + def test_audio_model_excluded_and_cached(self, mock_load_config, mock_needs_t5): """Audio-only models (csm, whisper) with ForConditionalGeneration should be excluded from VLM detection and cached as False.""" cfg = MagicMock(spec = []) # strict: only explicitly set attrs exist diff --git a/studio/backend/utils/models/model_config.py b/studio/backend/utils/models/model_config.py index a396a0172c..9370f42295 100644 --- a/studio/backend/utils/models/model_config.py +++ b/studio/backend/utils/models/model_config.py @@ -504,12 +504,39 @@ def load_model_config( _BACKEND_DIR = str(Path(__file__).resolve().parent.parent.parent) -def _is_vlm_model_type(model_type: Optional[str]) -> bool: +def _is_vlm(config) -> bool: + architectures = getattr(config, "architectures", None) or [] + if any(x.endswith(_VLM_ARCH_SUFFIXES) for x in architectures): + return True + model_type = getattr(config, "model_type", None) return model_type is not None and any( model_type.startswith(vlm_type) for vlm_type in _VLM_MODEL_TYPES ) +def _raw_config_has_vision_config( + model_name: str, hf_token: Optional[str] = None +) -> Optional[bool]: + try: + if is_local_path(model_name): + config_path = Path(normalize_path(model_name)).expanduser() / "config.json" + else: + from huggingface_hub import hf_hub_download + + config_path = Path( + hf_hub_download( + repo_id = model_name, + filename = "config.json", + token = hf_token, + ) + ) + config = json.loads(config_path.read_text()) + return "vision_config" in config and bool(config["vision_config"]) + except Exception as exc: + logger.warning("Could not read config.json for '%s': %s", model_name, exc) + return None + + # Inline script executed in a subprocess with transformers 5.x activated. # Receives model_name and token via argv, prints JSON result to stdout. _VISION_CHECK_SCRIPT = r""" @@ -526,6 +553,20 @@ def _is_vlm_model_type(model_type: Optional[str]) -> bool: if backend_dir not in sys.path: sys.path.insert(0, backend_dir) +def _is_vlm(config): + vlm_types = {"phi3_v","llava","llava_next","llava_onevision", + "internvl_chat","cogvlm2","minicpmv","gemma4"} + architectures = getattr(config, "architectures", []) or [] + if any( + x.endswith(("ForConditionalGeneration", "ForVisionText2Text")) + for x in architectures + ): + return True + model_type = getattr(config, "model_type", None) + return model_type is not None and any( + model_type.startswith(vlm_type) for vlm_type in vlm_types + ) + try: from transformers import AutoConfig kwargs = {"trust_remote_code": True} @@ -533,25 +574,15 @@ def _is_vlm_model_type(model_type: Optional[str]) -> bool: kwargs["token"] = token config = AutoConfig.from_pretrained(model_name, **kwargs) - is_vlm = False - if hasattr(config, "architectures"): - is_vlm = any( - x.endswith(("ForConditionalGeneration", "ForVisionText2Text")) - for x in config.architectures - ) + is_vlm = _is_vlm(config) if not is_vlm and hasattr(config, "vision_config"): is_vlm = True if not is_vlm and hasattr(config, "img_processor"): is_vlm = True if not is_vlm and hasattr(config, "image_token_index"): is_vlm = True - if not is_vlm and hasattr(config, "model_type"): - vlm_types = {"phi3_v","llava","llava_next","llava_onevision", - "internvl_chat","cogvlm2","minicpmv"} - if config.model_type in vlm_types: - is_vlm = True - model_type = getattr(config, "model_type", "unknown") + model_type = getattr(config, "model_type", None) archs = getattr(config, "architectures", []) print(json.dumps({"is_vision": is_vlm, "model_type": model_type, "architectures": archs})) @@ -715,6 +746,21 @@ def _is_vision_model_uncached( Do not call directly; use is_vision_model() instead. """ + # Models that need transformers 5.x must be checked in a subprocess + # because AutoConfig in the main process (transformers 4.57.x) doesn't + # recognize their architectures. + from utils.transformers_version import needs_transformers_5 + + if needs_transformers_5(model_name): + logger.info( + "Model '%s' needs transformers 5.x -- checking vision via subprocess", + model_name, + ) + result = _is_vision_model_subprocess(model_name, hf_token = hf_token) + if result is not None: + return result + return _raw_config_has_vision_config(model_name, hf_token = hf_token) + try: config = load_model_config(model_name, use_auth = True, token = hf_token) @@ -725,14 +771,9 @@ def _is_vision_model_uncached( if model_type in _audio_only_model_types: return False - # Check 1: Architecture class name patterns - if hasattr(config, "architectures"): - is_vlm = any(x.endswith(_VLM_ARCH_SUFFIXES) for x in config.architectures) - if is_vlm: - logger.info( - f"Model {model_name} detected as VLM: architecture {config.architectures}" - ) - return True + if _is_vlm(config): + logger.info(f"Model {model_name} detected as VLM") + return True # Check 2: Has vision_config (most VLMs: LLaVA, Gemma-3, Qwen2-VL, etc.) if hasattr(config, "vision_config"): @@ -749,14 +790,6 @@ def _is_vision_model_uncached( logger.info(f"Model {model_name} detected as VLM: has image_token_index") return True - # Check 5: Known VLM model_type values that may not match above checks - if hasattr(config, "model_type"): - if _is_vlm_model_type(config.model_type): - logger.info( - f"Model {model_name} detected as VLM: model_type={config.model_type}" - ) - return True - return False except Exception as e: diff --git a/studio/backend/utils/transformers_version.py b/studio/backend/utils/transformers_version.py index 4531851555..f36bdcd6e8 100644 --- a/studio/backend/utils/transformers_version.py +++ b/studio/backend/utils/transformers_version.py @@ -279,10 +279,7 @@ def _check_cfg(cfg: dict) -> bool: archs = cfg.get("architectures", []) if any(a in _TRANSFORMERS_550_ARCHITECTURES for a in archs): return True - model_type = cfg.get("model_type") - if model_type is not None and any( - model_type.startswith(t) for t in _TRANSFORMERS_550_MODEL_TYPES - ): + if cfg.get("model_type") in _TRANSFORMERS_550_MODEL_TYPES: return True return False From 4ba6a252af62d9755403462a6dba2994f414622d Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 1 May 2026 10:29:41 +0000 Subject: [PATCH 4/6] consolidate VLM checks --- studio/backend/utils/models/model_config.py | 50 +++++---------------- 1 file changed, 11 insertions(+), 39 deletions(-) diff --git a/studio/backend/utils/models/model_config.py b/studio/backend/utils/models/model_config.py index 9370f42295..2257af40ee 100644 --- a/studio/backend/utils/models/model_config.py +++ b/studio/backend/utils/models/model_config.py @@ -506,11 +506,16 @@ def load_model_config( def _is_vlm(config) -> bool: architectures = getattr(config, "architectures", None) or [] - if any(x.endswith(_VLM_ARCH_SUFFIXES) for x in architectures): - return True model_type = getattr(config, "model_type", None) - return model_type is not None and any( - model_type.startswith(vlm_type) for vlm_type in _VLM_MODEL_TYPES + return ( + any(x.endswith(_VLM_ARCH_SUFFIXES) for x in architectures) + or hasattr(config, "vision_config") + or hasattr(config, "img_processor") + or hasattr(config, "image_token_index") + or ( + model_type is not None + and any(model_type.startswith(vlm_type) for vlm_type in _VLM_MODEL_TYPES) + ) ) @@ -553,34 +558,16 @@ def _raw_config_has_vision_config( if backend_dir not in sys.path: sys.path.insert(0, backend_dir) -def _is_vlm(config): - vlm_types = {"phi3_v","llava","llava_next","llava_onevision", - "internvl_chat","cogvlm2","minicpmv","gemma4"} - architectures = getattr(config, "architectures", []) or [] - if any( - x.endswith(("ForConditionalGeneration", "ForVisionText2Text")) - for x in architectures - ): - return True - model_type = getattr(config, "model_type", None) - return model_type is not None and any( - model_type.startswith(vlm_type) for vlm_type in vlm_types - ) - try: from transformers import AutoConfig + from utils.models.model_config import _is_vlm + kwargs = {"trust_remote_code": True} if token: kwargs["token"] = token config = AutoConfig.from_pretrained(model_name, **kwargs) is_vlm = _is_vlm(config) - if not is_vlm and hasattr(config, "vision_config"): - is_vlm = True - if not is_vlm and hasattr(config, "img_processor"): - is_vlm = True - if not is_vlm and hasattr(config, "image_token_index"): - is_vlm = True model_type = getattr(config, "model_type", None) archs = getattr(config, "architectures", []) @@ -775,21 +762,6 @@ def _is_vision_model_uncached( logger.info(f"Model {model_name} detected as VLM") return True - # Check 2: Has vision_config (most VLMs: LLaVA, Gemma-3, Qwen2-VL, etc.) - if hasattr(config, "vision_config"): - logger.info(f"Model {model_name} detected as VLM: has vision_config") - return True - - # Check 3: Has img_processor (Phi-3.5 Vision uses this instead of vision_config) - if hasattr(config, "img_processor"): - logger.info(f"Model {model_name} detected as VLM: has img_processor") - return True - - # Check 4: Has image_token_index (common in VLMs for image placeholder tokens) - if hasattr(config, "image_token_index"): - logger.info(f"Model {model_name} detected as VLM: has image_token_index") - return True - return False except Exception as e: From be47403c6ae10f8d96b6f17c4c077944463d7e63 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 May 2026 10:30:17 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- studio/backend/tests/test_vision_cache.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/studio/backend/tests/test_vision_cache.py b/studio/backend/tests/test_vision_cache.py index 4a71eafa91..97579b3b4d 100644 --- a/studio/backend/tests/test_vision_cache.py +++ b/studio/backend/tests/test_vision_cache.py @@ -127,9 +127,7 @@ def test_subprocess_none_falls_back_to_raw_vision_config( assert is_vision_model("unsloth/gemma-4-E4B-it") is True mock_subprocess.assert_called_once() - mock_raw_config.assert_called_once_with( - "unsloth/gemma-4-E4B-it", hf_token = None - ) + mock_raw_config.assert_called_once_with("unsloth/gemma-4-E4B-it", hf_token = None) # --------------------------------------------------------------------------- From 45a6612b81e825663d6d1bbbf962598b8db146a5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 May 2026 12:18:33 +0000 Subject: [PATCH 6/6] Split: keep only 1 file(s) --- studio/backend/utils/models/model_config.py | 105 ++++++++++---------- 1 file changed, 53 insertions(+), 52 deletions(-) diff --git a/studio/backend/utils/models/model_config.py b/studio/backend/utils/models/model_config.py index 88e0535da3..dc8dd08315 100644 --- a/studio/backend/utils/models/model_config.py +++ b/studio/backend/utils/models/model_config.py @@ -496,52 +496,14 @@ def load_model_config( "internvl_chat", "cogvlm2", "minicpmv", - "gemma4", } # Pre-computed .venv_t5 paths and backend dir for subprocess version switching. # Vision check uses 5.5.0 (newest, recognizes all architectures). -_VENV_T5_DIR = str(Path.home() / ".unsloth" / "studio" / ".venv_t5_550") -_BACKEND_DIR = str(Path(__file__).resolve().parent.parent.parent) - - -def _is_vlm(config) -> bool: - architectures = getattr(config, "architectures", None) or [] - model_type = getattr(config, "model_type", None) - return ( - any(x.endswith(_VLM_ARCH_SUFFIXES) for x in architectures) - or hasattr(config, "vision_config") - or hasattr(config, "img_processor") - or hasattr(config, "image_token_index") - or ( - model_type is not None - and any(model_type.startswith(vlm_type) for vlm_type in _VLM_MODEL_TYPES) - ) - ) - - -def _raw_config_has_vision_config( - model_name: str, hf_token: Optional[str] = None -) -> Optional[bool]: - try: - if is_local_path(model_name): - config_path = Path(normalize_path(model_name)).expanduser() / "config.json" - else: - from huggingface_hub import hf_hub_download - - config_path = Path( - hf_hub_download( - repo_id = model_name, - filename = "config.json", - token = hf_token, - ) - ) - config = json.loads(config_path.read_text()) - return "vision_config" in config and bool(config["vision_config"]) - except Exception as exc: - logger.warning("Could not read config.json for '%s': %s", model_name, exc) - return None +from utils.paths.storage_roots import studio_root as _studio_root # noqa: E402 +_VENV_T5_DIR = str(_studio_root() / ".venv_t5_550") +_BACKEND_DIR = str(Path(__file__).resolve().parent.parent.parent) # Inline script executed in a subprocess with transformers 5.x activated. # Receives model_name and token via argv, prints JSON result to stdout. @@ -561,16 +523,30 @@ def _raw_config_has_vision_config( try: from transformers import AutoConfig - from utils.models.model_config import _is_vlm - kwargs = {"trust_remote_code": True} if token: kwargs["token"] = token config = AutoConfig.from_pretrained(model_name, **kwargs) - is_vlm = _is_vlm(config) - - model_type = getattr(config, "model_type", None) + is_vlm = False + if hasattr(config, "architectures"): + is_vlm = any( + x.endswith(("ForConditionalGeneration", "ForVisionText2Text")) + for x in config.architectures + ) + if not is_vlm and hasattr(config, "vision_config"): + is_vlm = True + if not is_vlm and hasattr(config, "img_processor"): + is_vlm = True + if not is_vlm and hasattr(config, "image_token_index"): + is_vlm = True + if not is_vlm and hasattr(config, "model_type"): + vlm_types = {"phi3_v","llava","llava_next","llava_onevision", + "internvl_chat","cogvlm2","minicpmv"} + if config.model_type in vlm_types: + is_vlm = True + + model_type = getattr(config, "model_type", "unknown") archs = getattr(config, "architectures", []) print(json.dumps({"is_vision": is_vlm, "model_type": model_type, "architectures": archs})) @@ -745,10 +721,7 @@ def _is_vision_model_uncached( "Model '%s' needs transformers 5.x -- checking vision via subprocess", model_name, ) - result = _is_vision_model_subprocess(model_name, hf_token = hf_token) - if result is not None: - return result - return _raw_config_has_vision_config(model_name, hf_token = hf_token) + return _is_vision_model_subprocess(model_name, hf_token = hf_token) try: config = load_model_config(model_name, use_auth = True, token = hf_token) @@ -760,10 +733,38 @@ def _is_vision_model_uncached( if model_type in _audio_only_model_types: return False - if _is_vlm(config): - logger.info(f"Model {model_name} detected as VLM") + # Check 1: Architecture class name patterns + if hasattr(config, "architectures"): + is_vlm = any(x.endswith(_VLM_ARCH_SUFFIXES) for x in config.architectures) + if is_vlm: + logger.info( + f"Model {model_name} detected as VLM: architecture {config.architectures}" + ) + return True + + # Check 2: Has vision_config (most VLMs: LLaVA, Gemma-3, Qwen2-VL, etc.) + if hasattr(config, "vision_config"): + logger.info(f"Model {model_name} detected as VLM: has vision_config") return True + # Check 3: Has img_processor (Phi-3.5 Vision uses this instead of vision_config) + if hasattr(config, "img_processor"): + logger.info(f"Model {model_name} detected as VLM: has img_processor") + return True + + # Check 4: Has image_token_index (common in VLMs for image placeholder tokens) + if hasattr(config, "image_token_index"): + logger.info(f"Model {model_name} detected as VLM: has image_token_index") + return True + + # Check 5: Known VLM model_type values that may not match above checks + if hasattr(config, "model_type"): + if config.model_type in _VLM_MODEL_TYPES: + logger.info( + f"Model {model_name} detected as VLM: model_type={config.model_type}" + ) + return True + return False except Exception as e: