-
Notifications
You must be signed in to change notification settings - Fork 0
Refactor VLM detection in studio #70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f86bbcf
d1b23e1
71ffe2b
4ba6a25
be47403
9e55909
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -496,13 +496,53 @@ def load_model_config( | |||||
| "internvl_chat", | ||||||
| "cogvlm2", | ||||||
| "minicpmv", | ||||||
| "gemma4", | ||||||
| } | ||||||
|
|
||||||
| # Pre-computed .venv_t5 paths and backend dir for subprocess version switching. | ||||||
| # Vision check uses 5.5.0 (newest, recognizes all architectures). | ||||||
| _VENV_T5_DIR = str(Path.home() / ".unsloth" / "studio" / ".venv_t5_550") | ||||||
| _BACKEND_DIR = str(Path(__file__).resolve().parent.parent.parent) | ||||||
|
|
||||||
|
|
||||||
| def _is_vlm(config) -> bool: | ||||||
| architectures = getattr(config, "architectures", None) or [] | ||||||
| model_type = getattr(config, "model_type", None) | ||||||
| return ( | ||||||
| any(x.endswith(_VLM_ARCH_SUFFIXES) for x in architectures) | ||||||
| or hasattr(config, "vision_config") | ||||||
| or hasattr(config, "img_processor") | ||||||
| or hasattr(config, "image_token_index") | ||||||
| or ( | ||||||
| model_type is not None | ||||||
| and any(model_type.startswith(vlm_type) for vlm_type in _VLM_MODEL_TYPES) | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||
| ) | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def _raw_config_has_vision_config( | ||||||
| model_name: str, hf_token: Optional[str] = None | ||||||
| ) -> Optional[bool]: | ||||||
| try: | ||||||
| if is_local_path(model_name): | ||||||
| config_path = Path(normalize_path(model_name)).expanduser() / "config.json" | ||||||
| else: | ||||||
| from huggingface_hub import hf_hub_download | ||||||
|
|
||||||
| config_path = Path( | ||||||
| hf_hub_download( | ||||||
| repo_id = model_name, | ||||||
| filename = "config.json", | ||||||
| token = hf_token, | ||||||
| ) | ||||||
| ) | ||||||
| config = json.loads(config_path.read_text()) | ||||||
| return "vision_config" in config and bool(config["vision_config"]) | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The use of
Suggested change
|
||||||
| except Exception as exc: | ||||||
| logger.warning("Could not read config.json for '%s': %s", model_name, exc) | ||||||
| return None | ||||||
|
|
||||||
|
|
||||||
| # Inline script executed in a subprocess with transformers 5.x activated. | ||||||
| # Receives model_name and token via argv, prints JSON result to stdout. | ||||||
| _VISION_CHECK_SCRIPT = r""" | ||||||
|
|
@@ -521,30 +561,16 @@ def load_model_config( | |||||
|
|
||||||
| try: | ||||||
| from transformers import AutoConfig | ||||||
| from utils.models.model_config import _is_vlm | ||||||
|
|
||||||
| kwargs = {"trust_remote_code": True} | ||||||
| if token: | ||||||
| kwargs["token"] = token | ||||||
| config = AutoConfig.from_pretrained(model_name, **kwargs) | ||||||
|
|
||||||
| is_vlm = False | ||||||
| if hasattr(config, "architectures"): | ||||||
| is_vlm = any( | ||||||
| x.endswith(("ForConditionalGeneration", "ForVisionText2Text")) | ||||||
| for x in config.architectures | ||||||
| ) | ||||||
| if not is_vlm and hasattr(config, "vision_config"): | ||||||
| is_vlm = True | ||||||
| if not is_vlm and hasattr(config, "img_processor"): | ||||||
| is_vlm = True | ||||||
| if not is_vlm and hasattr(config, "image_token_index"): | ||||||
| is_vlm = True | ||||||
| if not is_vlm and hasattr(config, "model_type"): | ||||||
| vlm_types = {"phi3_v","llava","llava_next","llava_onevision", | ||||||
| "internvl_chat","cogvlm2","minicpmv"} | ||||||
| if config.model_type in vlm_types: | ||||||
| is_vlm = True | ||||||
|
|
||||||
| model_type = getattr(config, "model_type", "unknown") | ||||||
| is_vlm = _is_vlm(config) | ||||||
|
|
||||||
| model_type = getattr(config, "model_type", None) | ||||||
| archs = getattr(config, "architectures", []) | ||||||
| print(json.dumps({"is_vision": is_vlm, "model_type": model_type, | ||||||
| "architectures": archs})) | ||||||
|
|
@@ -719,7 +745,10 @@ def _is_vision_model_uncached( | |||||
| "Model '%s' needs transformers 5.x -- checking vision via subprocess", | ||||||
| model_name, | ||||||
| ) | ||||||
| return _is_vision_model_subprocess(model_name, hf_token = hf_token) | ||||||
| result = _is_vision_model_subprocess(model_name, hf_token = hf_token) | ||||||
| if result is not None: | ||||||
| return result | ||||||
| return _raw_config_has_vision_config(model_name, hf_token = hf_token) | ||||||
|
|
||||||
| try: | ||||||
| config = load_model_config(model_name, use_auth = True, token = hf_token) | ||||||
|
|
@@ -731,38 +760,10 @@ def _is_vision_model_uncached( | |||||
| if model_type in _audio_only_model_types: | ||||||
| return False | ||||||
|
|
||||||
| # Check 1: Architecture class name patterns | ||||||
| if hasattr(config, "architectures"): | ||||||
| is_vlm = any(x.endswith(_VLM_ARCH_SUFFIXES) for x in config.architectures) | ||||||
| if is_vlm: | ||||||
| logger.info( | ||||||
| f"Model {model_name} detected as VLM: architecture {config.architectures}" | ||||||
| ) | ||||||
| return True | ||||||
|
|
||||||
| # Check 2: Has vision_config (most VLMs: LLaVA, Gemma-3, Qwen2-VL, etc.) | ||||||
| if hasattr(config, "vision_config"): | ||||||
| logger.info(f"Model {model_name} detected as VLM: has vision_config") | ||||||
| if _is_vlm(config): | ||||||
| logger.info(f"Model {model_name} detected as VLM") | ||||||
| return True | ||||||
|
|
||||||
| # Check 3: Has img_processor (Phi-3.5 Vision uses this instead of vision_config) | ||||||
| if hasattr(config, "img_processor"): | ||||||
| logger.info(f"Model {model_name} detected as VLM: has img_processor") | ||||||
| return True | ||||||
|
|
||||||
| # Check 4: Has image_token_index (common in VLMs for image placeholder tokens) | ||||||
| if hasattr(config, "image_token_index"): | ||||||
| logger.info(f"Model {model_name} detected as VLM: has image_token_index") | ||||||
| return True | ||||||
|
|
||||||
| # Check 5: Known VLM model_type values that may not match above checks | ||||||
| if hasattr(config, "model_type"): | ||||||
| if config.model_type in _VLM_MODEL_TYPES: | ||||||
| logger.info( | ||||||
| f"Model {model_name} detected as VLM: model_type={config.model_type}" | ||||||
| ) | ||||||
| return True | ||||||
|
|
||||||
| return False | ||||||
|
|
||||||
| except Exception as e: | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using
hasattron anAutoConfigobject can be misleading because it returnsTrueeven if the attribute value isNone. While this preserves the previous behavior, it's generally safer to check if the attribute exists and is notNoneto avoid false positives for models that might have these attributes initialized toNone.