Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions studio/backend/tests/test_vision_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,18 @@ def test_subprocess_called_once_with_cache(self, mock_needs_t5, mock_subprocess)
mock_subprocess.assert_called_once()
assert _vision_detection_cache[("unsloth/Qwen3.5-2B", None)] is True

@patch("utils.models.model_config._raw_config_has_vision_config", return_value = True)
@patch("utils.models.model_config._is_vision_model_subprocess", return_value = None)
@patch("utils.transformers_version.needs_transformers_5", return_value = True)
def test_subprocess_none_falls_back_to_raw_vision_config(
self, mock_needs_t5, mock_subprocess, mock_raw_config
):
assert is_vision_model("unsloth/gemma-4-E4B-it") is True
assert is_vision_model("unsloth/gemma-4-E4B-it") is True

mock_subprocess.assert_called_once()
mock_raw_config.assert_called_once_with("unsloth/gemma-4-E4B-it", hf_token = None)


# ---------------------------------------------------------------------------
# Exception handling — cache the False fallback
Expand Down Expand Up @@ -223,6 +235,20 @@ def test_vision_config_attr_detected_and_cached(
assert is_vision_model("Qwen/Qwen2-VL-7B") is True
mock_load_config.assert_called_once()

@patch("utils.transformers_version.needs_transformers_5", return_value = False)
@patch("utils.models.model_config.load_model_config")
def test_model_type_prefix_detected_and_cached(
self, mock_load_config, mock_needs_t5
):
cfg = MagicMock(spec = [])
cfg.model_type = "gemma4audio"
cfg.architectures = ["Gemma4AudioForCausalLM"]
mock_load_config.return_value = cfg

assert is_vision_model("google/gemma-4-audio") is True
assert is_vision_model("google/gemma-4-audio") is True
mock_load_config.assert_called_once()

@patch("utils.transformers_version.needs_transformers_5", return_value = False)
@patch("utils.models.model_config.load_model_config")
def test_audio_model_excluded_and_cached(self, mock_load_config, mock_needs_t5):
Expand Down
101 changes: 51 additions & 50 deletions studio/backend/utils/models/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,13 +496,53 @@ def load_model_config(
"internvl_chat",
"cogvlm2",
"minicpmv",
"gemma4",
}

# Pre-computed .venv_t5 paths and backend dir for subprocess version switching.
# Vision check uses 5.5.0 (newest, recognizes all architectures).
_VENV_T5_DIR = str(Path.home() / ".unsloth" / "studio" / ".venv_t5_550")
_BACKEND_DIR = str(Path(__file__).resolve().parent.parent.parent)


def _is_vlm(config) -> bool:
architectures = getattr(config, "architectures", None) or []
model_type = getattr(config, "model_type", None)
return (
any(x.endswith(_VLM_ARCH_SUFFIXES) for x in architectures)
or hasattr(config, "vision_config")
or hasattr(config, "img_processor")
or hasattr(config, "image_token_index")
Comment on lines +513 to +515

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low

Using hasattr on an AutoConfig object can be misleading because it returns True even if the attribute value is None. While this preserves the previous behavior, it's generally safer to check if the attribute exists and is not None to avoid false positives for models that might have these attributes initialized to None.

Suggested change
or hasattr(config, "vision_config")
or hasattr(config, "img_processor")
or hasattr(config, "image_token_index")
or getattr(config, "vision_config", None) is not None
or getattr(config, "img_processor", None) is not None
or getattr(config, "image_token_index", None) is not None

or (
model_type is not None
and any(model_type.startswith(vlm_type) for vlm_type in _VLM_MODEL_TYPES)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low

The startswith method in Python can accept a tuple of strings directly. Using model_type.startswith(tuple(_VLM_MODEL_TYPES)) would be more efficient and idiomatic than using a generator expression with any().

Suggested change
and any(model_type.startswith(vlm_type) for vlm_type in _VLM_MODEL_TYPES)
and model_type.startswith(tuple(_VLM_MODEL_TYPES))

)
)


def _raw_config_has_vision_config(
model_name: str, hf_token: Optional[str] = None
) -> Optional[bool]:
try:
if is_local_path(model_name):
config_path = Path(normalize_path(model_name)).expanduser() / "config.json"
else:
from huggingface_hub import hf_hub_download

config_path = Path(
hf_hub_download(
repo_id = model_name,
filename = "config.json",
token = hf_token,
)
)
config = json.loads(config_path.read_text())
return "vision_config" in config and bool(config["vision_config"])

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The use of bool(config["vision_config"]) might incorrectly return False if a model uses a default vision configuration represented as an empty dictionary {} in the config.json. A safer check is to verify that the key exists and its value is not None.

Suggested change
return "vision_config" in config and bool(config["vision_config"])
return config.get("vision_config") is not None

except Exception as exc:
logger.warning("Could not read config.json for '%s': %s", model_name, exc)
return None


# Inline script executed in a subprocess with transformers 5.x activated.
# Receives model_name and token via argv, prints JSON result to stdout.
_VISION_CHECK_SCRIPT = r"""
Expand All @@ -521,30 +561,16 @@ def load_model_config(

try:
from transformers import AutoConfig
from utils.models.model_config import _is_vlm

kwargs = {"trust_remote_code": True}
if token:
kwargs["token"] = token
config = AutoConfig.from_pretrained(model_name, **kwargs)

is_vlm = False
if hasattr(config, "architectures"):
is_vlm = any(
x.endswith(("ForConditionalGeneration", "ForVisionText2Text"))
for x in config.architectures
)
if not is_vlm and hasattr(config, "vision_config"):
is_vlm = True
if not is_vlm and hasattr(config, "img_processor"):
is_vlm = True
if not is_vlm and hasattr(config, "image_token_index"):
is_vlm = True
if not is_vlm and hasattr(config, "model_type"):
vlm_types = {"phi3_v","llava","llava_next","llava_onevision",
"internvl_chat","cogvlm2","minicpmv"}
if config.model_type in vlm_types:
is_vlm = True

model_type = getattr(config, "model_type", "unknown")
is_vlm = _is_vlm(config)

model_type = getattr(config, "model_type", None)
archs = getattr(config, "architectures", [])
print(json.dumps({"is_vision": is_vlm, "model_type": model_type,
"architectures": archs}))
Expand Down Expand Up @@ -719,7 +745,10 @@ def _is_vision_model_uncached(
"Model '%s' needs transformers 5.x -- checking vision via subprocess",
model_name,
)
return _is_vision_model_subprocess(model_name, hf_token = hf_token)
result = _is_vision_model_subprocess(model_name, hf_token = hf_token)
if result is not None:
return result
return _raw_config_has_vision_config(model_name, hf_token = hf_token)

try:
config = load_model_config(model_name, use_auth = True, token = hf_token)
Expand All @@ -731,38 +760,10 @@ def _is_vision_model_uncached(
if model_type in _audio_only_model_types:
return False

# Check 1: Architecture class name patterns
if hasattr(config, "architectures"):
is_vlm = any(x.endswith(_VLM_ARCH_SUFFIXES) for x in config.architectures)
if is_vlm:
logger.info(
f"Model {model_name} detected as VLM: architecture {config.architectures}"
)
return True

# Check 2: Has vision_config (most VLMs: LLaVA, Gemma-3, Qwen2-VL, etc.)
if hasattr(config, "vision_config"):
logger.info(f"Model {model_name} detected as VLM: has vision_config")
if _is_vlm(config):
logger.info(f"Model {model_name} detected as VLM")
return True

# Check 3: Has img_processor (Phi-3.5 Vision uses this instead of vision_config)
if hasattr(config, "img_processor"):
logger.info(f"Model {model_name} detected as VLM: has img_processor")
return True

# Check 4: Has image_token_index (common in VLMs for image placeholder tokens)
if hasattr(config, "image_token_index"):
logger.info(f"Model {model_name} detected as VLM: has image_token_index")
return True

# Check 5: Known VLM model_type values that may not match above checks
if hasattr(config, "model_type"):
if config.model_type in _VLM_MODEL_TYPES:
logger.info(
f"Model {model_name} detected as VLM: model_type={config.model_type}"
)
return True

return False

except Exception as e:
Expand Down
Loading