From c28a4a69b54d35c76b6df72558ddaa3384f35061 Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Sun, 24 May 2026 13:53:17 +0000 Subject: [PATCH 01/12] Fix three notebook regressions caught by Blackwell Docker validation 1. Inductor subprocess GPU invisibility under cgroup-pinned containers. unsloth/_gpu_init.py forces TORCHINDUCTOR_COMPILE_THREADS=1 + UNSLOTH_FORCE_SINGLE_COMPILE_WORKER=1 when only NVIDIA_VISIBLE_DEVICES is set. patch_torch_compile must keep that env var instead of popping it during the non-debug branch; otherwise the compile worker pool respawns and the cgroup-pinned GPU is invisible. Plus the inductor options dict in temporary_patches.common is built from determine_compile_threads which previously ignored the env var, so honour the explicit single-worker forcing there too. 2. Gemma3Processor.__call__ ragged-batch crash on TRL GRPO paged path. _gemma3_call_impl receives `text=[...]` with no `padding=` kwarg from TRL's paged generation collate; upstream Gemma3ProcessorKwargs default is `padding=False` so BatchFeature blows up stacking variable-length input_ids. Force `padding="longest"` only when the caller did not pin padding AND there is more than one text row. Single-image inference path is byte-identical (text rows == 1 branch skipped). Repros: nb/Gemma3_(4B)-Vision-GRPO.ipynb, nb/Qwen3_VL_(8B)-Vision-GRPO.ipynb. 3. (deps in #1) Single-worker compile threads is also the right default for docker --gpus N on Blackwell since the Triton 3.6 driver in the subproc can't enumerate it; this also unbreaks Mistral CPT and gpt-oss-20B fine-tuning notebooks running inside the container. All three changes are gated by environment fingerprints or per-call kwarg detection and are forwards/backwards compatible with transformers 4.57.6 + 5.x and TRL 0.22.2 / 0.27.1 / 1.x. --- unsloth_zoo/patching_utils.py | 6 +++++- unsloth_zoo/temporary_patches/common.py | 6 ++++++ unsloth_zoo/temporary_patches/gemma.py | 13 +++++++++++++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 0b3482447..e342244dd 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -110,7 +110,11 @@ def patch_torch_compile(debug = False, O3 = False, ignore_errors = True): else: DEBUGGING = "" os.environ.pop("TORCHDYNAMO_VERBOSE", None) - os.environ.pop("TORCHINDUCTOR_COMPILE_THREADS", None) + # Preserve the single-worker forcing put in place by unsloth/_gpu_init + # to keep cgroup-pinned containers from spawning Inductor subprocess + # workers that can't see the GPU. + if os.environ.get("UNSLOTH_FORCE_SINGLE_COMPILE_WORKER", "0") != "1": + os.environ.pop("TORCHINDUCTOR_COMPILE_THREADS", None) os.environ.pop("TORCHINDUCTOR_FORCE_DISABLE_CACHES", None) os.environ.pop("TORCH_LOGS", None) torch._logging.set_logs(all = logging.CRITICAL) diff --git a/unsloth_zoo/temporary_patches/common.py b/unsloth_zoo/temporary_patches/common.py index a334bb625..1370d0df2 100644 --- a/unsloth_zoo/temporary_patches/common.py +++ b/unsloth_zoo/temporary_patches/common.py @@ -43,6 +43,12 @@ def determine_compile_threads(): # See https://github.com/pytorch/pytorch/blob/ab2294d8289a7757a2fc321cdefac88e2b378edf/torch/_inductor/config.py#L771 # Windows thread count = 1. See https://github.com/unslothai/unsloth-zoo/pull/187 if sys.platform == "win32": return 1 + # Honour the explicit single-worker forcing set by unsloth/_gpu_init for + # cgroup-pinned containers where the Inductor compile worker pool cannot + # see the GPU. Otherwise determine_compile_threads ignores the env var + # and the options dict still passes the multi-worker default. + if os.environ.get("TORCHINDUCTOR_COMPILE_THREADS") == "1": + return 1 cpu_count = os.cpu_count() return min(32, max(4, cpu_count)) pass diff --git a/unsloth_zoo/temporary_patches/gemma.py b/unsloth_zoo/temporary_patches/gemma.py index 9afb30def..65726a4c0 100644 --- a/unsloth_zoo/temporary_patches/gemma.py +++ b/unsloth_zoo/temporary_patches/gemma.py @@ -121,6 +121,19 @@ def _gemma3_call_impl( tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) + # TRL GRPO paged + reward paths call Gemma3Processor(text=[...]) with no + # padding= kwarg; upstream Gemma3ProcessorKwargs default is padding=False + # so ragged completions blow up BatchFeature tensor stacking. Force + # longest-padding only when caller did not pin padding AND we have >1 + # text row (single-image inference is byte-identical). + _user_padding = kwargs.get("padding", None) + if _user_padding is None: + _user_padding = kwargs.get("text_kwargs", {}).get("padding", None) + _text_rows = ( + len(text) if isinstance(text, (list, tuple)) and not isinstance(text, str) else 1 + ) + if _user_padding is None and _text_rows > 1: + output_kwargs["text_kwargs"]["padding"] = "longest" batched_images = None if images is not None: From 0f9153c29e8621276ac80f7ecd99b74b8d47f95c Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Sun, 24 May 2026 13:56:11 +0000 Subject: [PATCH 02/12] patch_vllm_graph_capture: bail out cleanly when vllm is not installed NameError on `vllm.__version__` at line 749 because `vllm` is only locally imported inside the v1 try/except above. Surfaced when patch_vllm runs in an image that ships without the vllm extra (arm64 wheels, CPU-only, SyntheticDataKit on no-vllm image). Add an early `try: import vllm` at the top of the function so the patch silently returns when vllm is missing. Repro: nb/Meta_Synthetic_Data_Llama3_2_(3B).ipynb on the no-vllm image. --- unsloth_zoo/vllm_utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 4d77c88a5..016aeee49 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -708,6 +708,14 @@ def patch_vllm_graph_capture(): import time from functools import wraps + # vLLM may not be installed (e.g. arm64 wheels, CPU-only, or images + # built without the vllm extra). Bail out cleanly instead of raising + # NameError on the `vllm.__version__` check below. + try: + import vllm + except ImportError: + return + @contextmanager def suppress_gc_collect(): original_gc_collect = gc.collect From a5770babce4027dfad777aa9a7c892ece5214b02 Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Sun, 24 May 2026 14:38:20 +0000 Subject: [PATCH 03/12] load_vllm: raise actionable ImportError when vllm is not installed The module-level `vllm_version` is only defined inside the `if importlib.util.find_spec("vllm") is not None` block near the top of vllm_utils.py. Functions that use it unconditionally (load_vllm, the standby/headroom paths at lines 1992-2274) raise NameError on the no-vllm image rather than a useful message. Short-circuit at the start of load_vllm with a clear ImportError so callers (SyntheticDataKit.from_pretrained, FastLanguageModel.from_pretrained with fast_inference=True) get an actionable hint instead of a stack trace. Repro: nb/Meta_Synthetic_Data_Llama3_2_(3B).ipynb on the no-vllm image. --- unsloth_zoo/vllm_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 016aeee49..f8b3362d6 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1779,6 +1779,17 @@ def load_vllm( assert(type(use_bitsandbytes) is bool) assert(conservativeness >= 0.0 and conservativeness <= 1.0) + # `vllm_version` only exists when `vllm` was importable at module-load + # time (see the find_spec guard near top of file). Raise an actionable + # error here instead of NameError, e.g. on the unsloth-blackwell + # no-vllm image or on arm64 where no vllm wheel is published. + if "vllm_version" not in globals(): + raise ImportError( + "vLLM is required for `load_vllm`/SyntheticDataKit/`fast_inference=True` " + "but it is not installed. Install it with `pip install vllm` (CUDA only; " + "no wheel exists for arm64/aarch64 as of this writing)." + ) + unsloth_vllm_standby = unsloth_vllm_standby or (os.getenv("UNSLOTH_VLLM_STANDBY", "0") != "0") # This would give the flexibility to override the util we set for standby mode. In some extreme cases, this can be helpful. standby_util_override = os.getenv("UNSLOTH_VLLM_STANDBY_UTIL_OVERRIDE", "0") != "0" From 62b70175995142563e7e5bba5b41fe6a14090559 Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Sun, 24 May 2026 14:39:26 +0000 Subject: [PATCH 04/12] llama_cpp.install_package: handle EOFError on input() in non-TTY contexts Docker containers run with stdin redirected to /dev/null (or no -i flag) raise EOFError when `unsloth_zoo/llama_cpp.py:281::install_package` tries to read the install confirmation prompt. The exception propagates up through save_pretrained_gguf and surfaces as RuntimeError: Unsloth: GGUF conversion failed: EOF when reading a line This blocks every notebook that calls `model.save_pretrained_gguf(...)` from a docker run. Catch EOFError and default to accept (same as pressing ENTER); opt out with UNSLOTH_AUTO_INSTALL=0 for callers that want the explicit refusal behaviour back. Repro: nb/Llama3_(8B)-Ollama.ipynb on the unsloth-blackwell:no-vllm image. --- unsloth_zoo/llama_cpp.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/llama_cpp.py b/unsloth_zoo/llama_cpp.py index 98c265d63..e6830d46e 100644 --- a/unsloth_zoo/llama_cpp.py +++ b/unsloth_zoo/llama_cpp.py @@ -278,7 +278,18 @@ def install_package(package, sudo = False, print_output = False, print_outputs = print(f"Unsloth: Installing packages: {package}") if not (IS_COLAB_ENVIRONMENT or IS_KAGGLE_ENVIRONMENT): - acceptance = input(f"Missing system packages. We need to execute `{install_cmd}` - do you accept? Press ENTER. Type NO if not.") + # Non-interactive contexts (Docker w/o TTY, headless CI) raise + # EOFError on input(). Treat that like an implicit ENTER ie accept + # the install. Opt out via UNSLOTH_AUTO_INSTALL=0. + try: + acceptance = input(f"Missing system packages. We need to execute `{install_cmd}` - do you accept? Press ENTER. Type NO if not.") + except EOFError: + if os.environ.get("UNSLOTH_AUTO_INSTALL", "1") != "1": + raise RuntimeError( + f"Unsloth: Execution of `{install_cmd}` was cancelled (no TTY and UNSLOTH_AUTO_INSTALL=0)!\n"\ + "Please install llama.cpp manually via https://docs.unsloth.ai/basics/troubleshooting-and-faqs#how-do-i-manually-save-to-gguf" + ) + acceptance = "" if "no" in str(acceptance).lower(): raise RuntimeError( f"Unsloth: Execution of `{install_cmd}` was cancelled!\n"\ From 1d96f7cc0d24b6d03e8e18c1eb63a9681637e5b7 Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Sun, 24 May 2026 16:21:02 +0000 Subject: [PATCH 05/12] patch_deepseek_v2_moe_capitalisation_alias: bridge transformers 4 -> 5 rename transformers 5.0 renamed `DeepseekV2MoE` -> `DeepseekV2Moe` (camelCase consistency pass). Remote-code model files that still import the old spelling (e.g. deepseek-ai/DeepSeek-OCR's modeling_deepseekocr.py:22) break on transformers 5.x with `ImportError: cannot import name 'DeepseekV2MoE' ... Did you mean: 'DeepseekV2Moe'?`. Add a class-level alias inside transformers' deepseek_v2 namespace so the legacy name keeps resolving. No-op on transformers 4.x where `DeepseekV2MoE` already exists natively. Repro: nb/Deepseek_OCR_(3B).ipynb on transformers 5.5.0. --- unsloth_zoo/temporary_patches/misc.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/unsloth_zoo/temporary_patches/misc.py b/unsloth_zoo/temporary_patches/misc.py index 59f3d0f04..6d0263487 100644 --- a/unsloth_zoo/temporary_patches/misc.py +++ b/unsloth_zoo/temporary_patches/misc.py @@ -1506,3 +1506,24 @@ def _min_pixels(self): pass pass TEMPORARY_PATCHES.append(patch_qwen2vl_image_processor_pixel_attrs) + + +def patch_deepseek_v2_moe_capitalisation_alias(): + """ + transformers 5.0 renamed `DeepseekV2MoE` -> `DeepseekV2Moe` (camelCase + consistency pass). Remote-code models like deepseek-ai/DeepSeek-OCR + ship a modeling file that still imports the old name, so loading them + on transformers 5.x raises `ImportError: cannot import name + 'DeepseekV2MoE'`. Add a backward-compat alias so the old name keeps + resolving regardless of which transformers version is installed. + Forward-compatible: when transformers 4.x is installed and ships + `DeepseekV2MoE` natively, the alias check is a no-op. + """ + try: + from transformers.models.deepseek_v2 import modeling_deepseek_v2 as _m + except ImportError: + return + if not hasattr(_m, "DeepseekV2MoE") and hasattr(_m, "DeepseekV2Moe"): + _m.DeepseekV2MoE = _m.DeepseekV2Moe +pass +TEMPORARY_PATCHES.append(patch_deepseek_v2_moe_capitalisation_alias) From 6dec60648d4c8d38293da9e499703e6a7197d5ed Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 25 May 2026 07:41:45 +0000 Subject: [PATCH 06/12] notebook_deps + vllm decompose_size_nodes monkey patches Two new zoo-level patches that fix six of the GRPO/vision notebook regressions caught by the Blackwell docker validation: 1. unsloth_zoo/vllm_utils.py::patch_vllm_decompose_size_nodes Port of upstream vllm-project/vllm#42543 (still open, SHA 51fd86e). `_decompose_size_nodes` only rewrites top-level `user.args` references to size nodes, missing the ones nested inside `slice(start, stop, step)`, tuples, lists, and kwargs. The trailing `erase_node` then trips torch.fx with RuntimeError: Tried to erase Node size_N but it still had K users on the 5 GRPO notebooks (Qwen3-4B-GRPO, Advanced Llama3.2 GRPO LoRA, Qwen3 8B FP8 GRPO, Llama FP8 GRPO, Qwen2.5-7B VL GRPO). Qwen3-VL- (8B) Vision-GRPO is unaffected because its graph only emits getitem (size, idx) consumers, which the upstream `if` branch already handles. The wrapper ports PR #42543's recursive rewrite, adds a kwargs sweep PR #42543 misses, and falls back to a warn-and-skip safety net (`if node.users: continue`) so the working notebooks stay a no-op. Version-guarded (0.11.0 <= vllm < 0.99), idempotent (`_unsloth_patched` sentinel), opt-out via `UNSLOTH_DISABLE_VLLM_DECOMPOSE_SIZE_PATCH=1`. MLX-safe: the patch sits inside `patch_vllm()` which is only called when `import vllm` succeeds. 2. unsloth_zoo/temporary_patches/notebook_deps.py Three thin auto-install hooks that catch the upstream-blessed import failure points and pip-install one of an allow-listed set of optional notebook deps before re-raising: - `transformers.utils.import_utils.requires_backends` -> fixes timm (TimmWrapper inside Gemma3N + Qwen3-VL). - `transformers.dynamic_module_utils.check_imports` -> fixes addict + matplotlib for trust_remote_code modeling files (Deepseek-OCR). - Pre-emptive `_ensure_notebook_chain()` -> ensures `traitlets` is importable before the IPython chain pulls it in (Gemma3 Vision + Qwen3-VL Vision). Honours the existing `UNSLOTH_AUTO_INSTALL=0` opt-out (matches `llama_cpp.py::install_package`) and the standard offline flags `UNSLOTH_OFFLINE` / `HF_HUB_OFFLINE` / `TRANSFORMERS_OFFLINE`. Heavy / GPU-arch-coupled deps (torch, vllm, flash-attn, bnb, triton, xformers) are explicitly excluded from the allow-list so we never paper over a real CUDA/driver mismatch. Cross-OS: prefers `uv pip install` only when a venv is active, otherwise `sys.executable -m pip install`. Probes site-packages writability and adds `--user` when needed. Windows has no `os.geteuid` so the probe just stays in the venv path. Both module-level transformers imports are inside the patch functions, so MLX-only macOS environments without transformers import the module cleanly (no-op). Repro: - vllm: nb/Qwen3_(4B)-GRPO.ipynb on unsloth-blackwell:test against vllm 0.11.x / nightly cu128. - deps: nb/Gemma3_(4B)-Vision.ipynb, nb/Gemma3N_(4B)-Vision.ipynb, nb/Qwen3_VL_(8B)-Vision.ipynb, nb/Deepseek_OCR_(3B).ipynb on a fresh Colab kernel. --- unsloth_zoo/temporary_patches/__init__.py | 1 + .../temporary_patches/notebook_deps.py | 220 ++++++++++++++++++ unsloth_zoo/vllm_utils.py | 119 ++++++++++ 3 files changed, 340 insertions(+) create mode 100644 unsloth_zoo/temporary_patches/notebook_deps.py diff --git a/unsloth_zoo/temporary_patches/__init__.py b/unsloth_zoo/temporary_patches/__init__.py index 2e8fced03..ab5494c63 100644 --- a/unsloth_zoo/temporary_patches/__init__.py +++ b/unsloth_zoo/temporary_patches/__init__.py @@ -16,6 +16,7 @@ from .common import * +from .notebook_deps import * from .gemma import * from .misc import * from .gemma3n import * diff --git a/unsloth_zoo/temporary_patches/notebook_deps.py b/unsloth_zoo/temporary_patches/notebook_deps.py new file mode 100644 index 000000000..472f1b313 --- /dev/null +++ b/unsloth_zoo/temporary_patches/notebook_deps.py @@ -0,0 +1,220 @@ +# Auto-install missing notebook-only Python deps on first use. +# +# Four notebooks failed in the Blackwell docker validation because the slim +# venv shipped without timm / traitlets / addict / matplotlib, and the +# raising frame is buried inside HF code (`transformers.utils.import_utils. +# requires_backends` for TimmWrapper, `transformers.dynamic_module_utils. +# check_imports` for the Deepseek-OCR trust_remote_code modeling file, and +# a bare ModuleNotFoundError for traitlets from the IPython chain). Wrap +# all three call sites with a thin retry that pip-installs the offending +# package (allow-list only) and re-tries the original import. Honours the +# existing `UNSLOTH_AUTO_INSTALL=0` opt-out (used by `llama_cpp.py`) and +# the standard offline flags so air-gapped envs keep emitting the +# upstream ImportError verbatim. + +import importlib +import importlib.metadata +import importlib.util +import os +import shutil +import site +import subprocess +import sys + +from ..log import logger + +# pypi-name -> import-name (None means same). +_ALLOW_LIST = { + "timm": None, # vision backbones (TimmWrapperModel) + "addict": None, # Deepseek-OCR config dicts + "matplotlib": None, # Deepseek-OCR + a few HF image utils + "traitlets": None, # Jupyter/IPython widget chain + "soundfile": None, # audio processors + "librosa": None, # audio processors + "scipy": None, # several processors + "pyctcdecode": None, # ASR + "tiktoken": None, # tokenizer remote-code paths + "blobfile": None, # tiktoken backing store + "pillow_heif": "pillow_heif", # HEIF images + "decord": None, # video processors + "av": "av", # pyav (video processors) + "num2words": None, # speech text norm + "jieba": None, # zh tokenizer + "sentencepiece": None, # tokenizers +} + +_AUTO_INSTALL = os.environ.get("UNSLOTH_AUTO_INSTALL", "1") == "1" +_NO_NETWORK = ( + os.environ.get("UNSLOTH_OFFLINE", "0") == "1" + or os.environ.get("HF_HUB_OFFLINE", "0") == "1" + or os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1" +) +_attempted: set = set() + + +def _in_venv() -> bool: + return ( + hasattr(sys, "real_prefix") + or (getattr(sys, "base_prefix", sys.prefix) != sys.prefix) + or bool(os.environ.get("VIRTUAL_ENV")) + or bool(os.environ.get("CONDA_PREFIX")) + ) + + +def _pip_install(pkg: str) -> bool: + if pkg in _attempted: + return False + _attempted.add(pkg) + if shutil.which("uv") and _in_venv(): + cmd = ["uv", "pip", "install", "--quiet", pkg] + else: + cmd = [ + sys.executable, "-m", "pip", "install", "--quiet", + "--disable-pip-version-check", "--no-input", pkg, + ] + # Outside a venv on Linux/Mac as non-root: probe write access to + # site-packages and fall back to --user. Windows has no geteuid; + # site-packages there is usually writable inside the venv anyway. + if not _in_venv() and hasattr(os, "geteuid") and os.geteuid() != 0: + try: + sp = site.getsitepackages()[0] + probe = os.path.join(sp, ".unsloth_write_probe") + open(probe, "w").close() + os.remove(probe) + except Exception: + cmd.append("--user") + logger.warning( + f"Unsloth: auto-installing missing notebook dep `{pkg}` via " + f"`{' '.join(cmd)}`. Set UNSLOTH_AUTO_INSTALL=0 to disable." + ) + try: + r = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + except Exception as e: + logger.warning(f"Unsloth: auto-install of `{pkg}` failed to launch: {e}") + return False + if r.returncode != 0: + tail = (r.stderr or "")[-500:] + logger.warning(f"Unsloth: auto-install of `{pkg}` failed:\n{tail}") + return False + importlib.invalidate_caches() + try: + list(importlib.metadata.distributions()) + except Exception: + pass + return True + + +def _try_install_and_import(pkg: str) -> bool: + if pkg not in _ALLOW_LIST: + return False + if not _AUTO_INSTALL or _NO_NETWORK: + return False + import_name = _ALLOW_LIST[pkg] or pkg.replace("-", "_") + if importlib.util.find_spec(import_name) is not None: + return True + if not _pip_install(pkg): + return False + return importlib.util.find_spec(import_name) is not None + + +def patch_requires_backends_autoinstall(): + """ + Wrap ``transformers.utils.import_utils.requires_backends`` so that an + allow-listed missing backend triggers a one-shot pip install and a + second attempt. Preserves the original ImportError when the install + fails or the dep isn't on the allow-list, so user-facing error bytes + stay identical to upstream when ``UNSLOTH_AUTO_INSTALL=0``. + """ + try: + from transformers.utils import import_utils as iu + except Exception: + return # transformers absent (MLX-only path) -- nothing to patch. + if getattr(iu.requires_backends, "_unsloth_patched", False): + return + _orig = iu.requires_backends + + def requires_backends(obj, backends): + try: + return _orig(obj, backends) + except ImportError: + if not _AUTO_INSTALL or _NO_NETWORK: + raise + wanted_iter = backends if isinstance(backends, (list, tuple)) else [backends] + wanted = [b for b in wanted_iter if isinstance(b, str) and b in _ALLOW_LIST] + if not wanted: + raise + installed_any = False + for b in wanted: + if _try_install_and_import(b): + installed_any = True + if not installed_any: + raise + for b in wanted: + flag = f"_{b.replace('-', '_')}_available" + if hasattr(iu, flag): + setattr(iu, flag, True) + return _orig(obj, backends) + + requires_backends._unsloth_patched = True + iu.requires_backends = requires_backends + + +def patch_check_imports_autoinstall(): + """ + trust_remote_code modeling files (e.g. Deepseek-OCR's modeling_deepseekocr.py) + declare their import requirements at the top of the file and raise via + ``dynamic_module_utils.check_imports`` (ImportError "This modeling file + requires the following packages..."). That call site never reaches + ``requires_backends``, so wrap it too. + """ + try: + from transformers import dynamic_module_utils as dmu + except Exception: + return + if getattr(dmu.check_imports, "_unsloth_patched", False): + return + _orig = dmu.check_imports + + def check_imports(filename): + try: + return _orig(filename) + except ImportError as e: + if not _AUTO_INSTALL or _NO_NETWORK: + raise + msg = str(e) + if "This modeling file requires" not in msg: + raise + # Message format: "... environment: pkg1, pkg2. Run `pip install...`" + try: + tail = msg.split("environment:", 1)[1] + pkgs_str = tail.split(".", 1)[0] + except Exception: + raise + pkgs = [p.strip() for p in pkgs_str.split(",") if p.strip() in _ALLOW_LIST] + if not pkgs: + raise + ok = all(_try_install_and_import(p) for p in pkgs) + if not ok: + raise + return _orig(filename) + + check_imports._unsloth_patched = True + dmu.check_imports = check_imports + + +def _ensure_notebook_chain(): + """ + Pre-emptive ensure for deps that raise bare ModuleNotFoundError outside + transformers (the Jupyter/IPython chain). Kept tiny: only ``traitlets`` + is touched today; expand only when a new failure mode appears. + """ + if not _AUTO_INSTALL or _NO_NETWORK: + return + for pkg in ("traitlets",): + if importlib.util.find_spec(pkg) is None: + _try_install_and_import(pkg) + + +patch_requires_backends_autoinstall() +patch_check_imports_autoinstall() +_ensure_notebook_chain() diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index f8b3362d6..d00af9db0 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -785,6 +785,124 @@ def capture_model_wrapper_v0(self, *args, **kwargs): pass +def patch_vllm_decompose_size_nodes(): + """ + Workaround for vLLM upstream bug: ``_decompose_size_nodes`` only rewrites + top-level ``user.args`` references to size nodes, missing the ones nested + inside ``slice(start, stop, step)``, tuples/lists, and kwargs. The trailing + ``graph.graph.erase_node(node)`` then trips ``torch.fx`` with + ``RuntimeError: Tried to erase Node size_N but it still had K users``. + Manifests on Qwen3-(4B)-GRPO, Advanced Llama3.2 GRPO LoRA, Qwen3 8B FP8 GRPO, + Llama FP8 GRPO, Qwen2.5-7B VL GRPO during ``FastVisionModel.from_pretrained + (fast_inference=True)``. Qwen3-VL-(8B) Vision-GRPO is unaffected because its + graph only emits ``getitem(size, idx)`` users (handled by the upstream + ``if`` branch). Canonical upstream fix: vllm-project/vllm#42543 (open at + SHA 51fd86e). Port the recursive rewrite + add a kwargs sweep PR #42543 + misses + final ``if node.users: warn-and-skip`` safety net. + + Opt out with ``UNSLOTH_DISABLE_VLLM_DECOMPOSE_SIZE_PATCH=1``. + """ + if os.environ.get("UNSLOTH_DISABLE_VLLM_DECOMPOSE_SIZE_PATCH", "0") == "1": + return + try: + import vllm + except ImportError: + return + if not (Version("0.11.0") <= Version(vllm_version) < Version("0.99")): + return + try: + import vllm.compilation.backends as _B + except Exception: + return + if not hasattr(_B, "_decompose_size_nodes"): + return + _orig = _B._decompose_size_nodes + if getattr(_orig, "_unsloth_patched", False): + return + try: + import operator + from torch import fx + except Exception: + return + + def _replace_in_slice(s, node, dims): + def _sub(b): + if isinstance(b, fx.Node) and b is node: + sym = [d for d in dims if isinstance(d, fx.Node)] + return sym[0] if len(sym) == 1 else dims[0] + return b + ns, no, nt = _sub(s.start), _sub(s.stop), _sub(s.step) + if (ns, no, nt) != (s.start, s.stop, s.step): + return slice(ns, no, nt) + return s + + def _replace_in_args(args, node, dims): + out = [] + for a in args: + if isinstance(a, fx.Node) and a is node: + out.extend(dims) + elif isinstance(a, slice): + out.append(_replace_in_slice(a, node, dims)) + elif isinstance(a, (tuple, list)): + out.append(type(a)(_replace_in_args(list(a), node, dims))) + else: + out.append(a) + return out + + def _decompose_size_nodes(graph): + size_nodes = list(graph.graph.find_nodes(op="call_method", target="size")) + for node in size_nodes: + tensor_node = node.args[0] + ev = tensor_node.meta.get("example_value") + if ev is None: + continue + dims = [] + with graph.graph.inserting_after(tensor_node): + for i in range(ev.dim()): + dv = ev.shape[i] + if isinstance(dv, torch.SymInt): + dn = graph.graph.call_function( + torch.ops.aten.sym_size.int, args = (tensor_node, i), + ) + dn.meta["example_value"] = dv + dims.append(dn) + else: + dims.append(int(dv)) + for user in list(node.users): + if ( + user.op == "call_function" + and user.target is operator.getitem + and len(user.args) == 2 + and user.args[0] is node + ): + user.replace_all_uses_with(dims[user.args[1]]) + graph.graph.erase_node(user) + else: + user.args = tuple(_replace_in_args(list(user.args), node, dims)) + if user.kwargs: + new_kwargs = {} + for k, v in user.kwargs.items(): + if isinstance(v, (tuple, list)): + new_kwargs[k] = type(v)(_replace_in_args(list(v), node, dims)) + else: + new_kwargs[k] = _replace_in_args([v], node, dims)[0] + user.kwargs = new_kwargs + if node.users: + logger.warning( + f"Unsloth: vllm _decompose_size_nodes left {node.name} " + f"with users {dict(node.users)}; skipping erase. " + f"See vllm-project/vllm#42543." + ) + continue + graph.graph.erase_node(node) + pass + + _decompose_size_nodes._unsloth_patched = True + _B._decompose_size_nodes = _decompose_size_nodes + logger.info("Unsloth: patched vllm.compilation.backends._decompose_size_nodes (vllm#42543).") +pass + + def patch_vllm(debug = True): # Temporary patch to disable multiprocessing for vLLM # Allows accessing model_executor @@ -814,6 +932,7 @@ def patch_vllm(debug = True): logger.info(f'Unsloth: Patching vLLM to enable standby.') patch_vllm_enable_sleep_mode() patch_vllm_graph_capture() + patch_vllm_decompose_size_nodes() global LORA_REQUEST_ID LORA_REQUEST_ID = 1 pass From 30d603eee385b731c342958c3e3b3ffd42c13c46 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 25 May 2026 07:51:08 +0000 Subject: [PATCH 07/12] load_vllm: force off FlashInfer when nvcc / ninja missing The Blackwell docker image (and most cu128 runtime images) ship the CUDA *runtime* libs but not nvcc, and FlashInfer requires a JIT compile of its trtllm-gen kernels at first use. The pre-flight at the top of `load_vllm` was already detecting the missing toolchain and printing a warning, but the cleanup only ran if VLLM_USE_FLASHINFER_SAMPLER == "1": del ... if VLLM_ATTENTION_BACKEND == "FLASHINFER": del ... That `del` cleanup is not enough on vLLM nightly: the v1 engine picks FlashInfer as the *default* attention backend on sm_100 / sm_120 (Blackwell). When neither env var is pre-set, the `del` branch never fires and vLLM proceeds to JIT-compile FlashInfer kernels, which crashes with RuntimeError: FlashInfer failed to JIT-compile: ninja (build tool) not found. (reproduced on `unsloth-blackwell:test` running Qwen3-(4B)-GRPO). Replace the conditional `del`s with explicit pins: VLLM_ATTENTION_BACKEND = FLASH_ATTN VLLM_USE_FLASHINFER_SAMPLER = 0 UNSLOTH_VLLM_NO_FLASHINFER = 1 so vLLM picks the FLASH_ATTN path (no JIT) and any downstream code in unsloth_zoo / load_vllm sees a consistent disable signal. No-op when nvcc and ninja are both present: that else-branch is untouched. Compatible with vllm 0.9 -> nightly; transformers 4.57.6 / 5.x / main; TRL 0.22.2 / 0.27.1 / 1.x. MLX-safe: this is inside `load_vllm`, only called when `fast_inference=True`, which never fires on the MLX path. --- unsloth_zoo/vllm_utils.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index d00af9db0..26bcc57d6 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -2100,11 +2100,18 @@ def load_vllm( f" ninja - pip install ninja\n" f" To silence this warning: set UNSLOTH_VLLM_NO_FLASHINFER=1" ) - # Clear any externally-set FlashInfer env vars so vLLM uses defaults - if os.environ.get("VLLM_USE_FLASHINFER_SAMPLER", "") == "1": - del os.environ["VLLM_USE_FLASHINFER_SAMPLER"] - if os.environ.get("VLLM_ATTENTION_BACKEND", "") == "FLASHINFER": - del os.environ["VLLM_ATTENTION_BACKEND"] + # Force vLLM off FlashInfer when nvcc/ninja are missing. + # `del`-ing the vars wasn't enough: vLLM's nightly v1 engine + # picks FlashInfer as the *default* attention backend on + # sm_100/sm_120 (Blackwell), then still JIT-compiles the + # trtllm-gen kernels and crashes inside `vllm.LLM()`. Pin the + # backend to FLASH_ATTN and the sampler off so vLLM never even + # tries to import FlashInfer's JIT path. Also propagate to + # UNSLOTH_VLLM_NO_FLASHINFER so any downstream code path in + # unsloth_zoo / load_vllm sees a consistent disable signal. + os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN" + os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "0" + os.environ["UNSLOTH_VLLM_NO_FLASHINFER"] = "1" else: # Check if FLASHINFER is supported - for eg Qwen3-VL and Qwen2-VL do not work if "VLLM_ATTENTION_BACKEND" in os.environ and os.environ["VLLM_ATTENTION_BACKEND"] == "": From 08d584404e15188d028624a17a2f1021643fa7de Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 25 May 2026 08:00:00 +0000 Subject: [PATCH 08/12] notebook_deps: add einops to the auto-install allow-list Deepseek-OCR's modeling_deepseekocr.py imports einops at the top of the file (via deepencoder.py:11 `from einops import rearrange`). It is not in transformers' BACKENDS_MAPPING and not declared in the "This modeling file requires ..." line that check_imports parses, so neither of the existing zoo hooks fires for it. Add `einops` to _ALLOW_LIST so the requires_backends / check_imports wrappers install it on first failure. --- unsloth_zoo/temporary_patches/notebook_deps.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/temporary_patches/notebook_deps.py b/unsloth_zoo/temporary_patches/notebook_deps.py index 472f1b313..d1a46e8cb 100644 --- a/unsloth_zoo/temporary_patches/notebook_deps.py +++ b/unsloth_zoo/temporary_patches/notebook_deps.py @@ -27,6 +27,7 @@ _ALLOW_LIST = { "timm": None, # vision backbones (TimmWrapperModel) "addict": None, # Deepseek-OCR config dicts + "einops": None, # Deepseek-OCR deepencoder + many other vision models "matplotlib": None, # Deepseek-OCR + a few HF image utils "traitlets": None, # Jupyter/IPython widget chain "soundfile": None, # audio processors From fd599618b41c0c22851b7541b188e972b68e1897 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 25 May 2026 08:20:52 +0000 Subject: [PATCH 09/12] notebook_deps: add easydict to the auto-install allow-list Deepseek-OCR's deepencoder.py:12 imports `from easydict import EasyDict as adict`. easydict and addict are DIFFERENT PyPI packages -- a notebook can require either, and Deepseek-OCR specifically needs easydict (despite the alias name suggesting addict). einops + easydict together unblock Deepseek-OCR's trust_remote_code modeling file at first load. --- unsloth_zoo/temporary_patches/notebook_deps.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/temporary_patches/notebook_deps.py b/unsloth_zoo/temporary_patches/notebook_deps.py index d1a46e8cb..97e891d94 100644 --- a/unsloth_zoo/temporary_patches/notebook_deps.py +++ b/unsloth_zoo/temporary_patches/notebook_deps.py @@ -28,6 +28,7 @@ "timm": None, # vision backbones (TimmWrapperModel) "addict": None, # Deepseek-OCR config dicts "einops": None, # Deepseek-OCR deepencoder + many other vision models + "easydict": None, # Deepseek-OCR deepencoder.py:12 `from easydict import EasyDict` "matplotlib": None, # Deepseek-OCR + a few HF image utils "traitlets": None, # Jupyter/IPython widget chain "soundfile": None, # audio processors From d17e56870c46d5ceb82b4d00dfcb37f360b2fb81 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 25 May 2026 08:33:01 +0000 Subject: [PATCH 10/12] notebook_deps: add snac to the auto-install allow-list Orpheus TTS notebook imports `from snac import SNAC` after model load to decode the speech tokens. snac is a small (~1MB) pure-Python wheel with torch as the only heavy dep, so safe to auto-install. --- unsloth_zoo/temporary_patches/notebook_deps.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/temporary_patches/notebook_deps.py b/unsloth_zoo/temporary_patches/notebook_deps.py index 97e891d94..f66c3cbcc 100644 --- a/unsloth_zoo/temporary_patches/notebook_deps.py +++ b/unsloth_zoo/temporary_patches/notebook_deps.py @@ -29,6 +29,7 @@ "addict": None, # Deepseek-OCR config dicts "einops": None, # Deepseek-OCR deepencoder + many other vision models "easydict": None, # Deepseek-OCR deepencoder.py:12 `from easydict import EasyDict` + "snac": None, # Orpheus TTS neural audio codec "matplotlib": None, # Deepseek-OCR + a few HF image utils "traitlets": None, # Jupyter/IPython widget chain "soundfile": None, # audio processors From 4f7b57c6f78469fbe31462c00edecd4fba9ac8f1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 25 May 2026 08:45:44 +0000 Subject: [PATCH 11/12] load_vllm: block `import flashinfer` when nvcc/ninja are missing Setting VLLM_ATTENTION_BACKEND=FLASH_ATTN doesn't work on vLLM 0.19.1: envs.py reports "Unknown vLLM environment variable detected: VLLM_ATTENTION_BACKEND", and vLLM still picks FLASHINFER from its ['FLASHINFER', 'FLASH_ATTN', 'TRITON_ATTN', 'FLEX_ATTENTION'] default list on sm_100/sm_120 (cuda.py:334), then JIT-compiles trtllm-gen kernels and dies on `/usr/local/cuda/bin/nvcc: not found`. Block the flashinfer module instead. Drop any cached imports from sys.modules, then set `sys.modules["flashinfer"] = None` so subsequent `import flashinfer` raises ImportError. vLLM's attention-backend selector falls back to FLASH_ATTN -> no JIT, no crash. This is the documented Python idiom for "this module is unavailable" (see Python language reference, "The module cache" section). Reversible per-process; the only caller is the load_vllm pre-flight, which runs once. Repro: nb/Qwen3_(4B)-GRPO.ipynb on unsloth-blackwell:test where the runtime image ships /usr/local/cuda runtime libs but no nvcc. --- unsloth_zoo/vllm_utils.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 26bcc57d6..1e89ae9e0 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -2101,17 +2101,30 @@ def load_vllm( f" To silence this warning: set UNSLOTH_VLLM_NO_FLASHINFER=1" ) # Force vLLM off FlashInfer when nvcc/ninja are missing. - # `del`-ing the vars wasn't enough: vLLM's nightly v1 engine - # picks FlashInfer as the *default* attention backend on - # sm_100/sm_120 (Blackwell), then still JIT-compiles the - # trtllm-gen kernels and crashes inside `vllm.LLM()`. Pin the - # backend to FLASH_ATTN and the sampler off so vLLM never even - # tries to import FlashInfer's JIT path. Also propagate to - # UNSLOTH_VLLM_NO_FLASHINFER so any downstream code path in - # unsloth_zoo / load_vllm sees a consistent disable signal. - os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN" + # Env-var nudging is not enough: `VLLM_ATTENTION_BACKEND` is + # not recognised by vllm 0.19.1 (envs.py reports "Unknown + # vLLM environment variable detected") and vLLM still picks + # FLASHINFER from `['FLASHINFER', 'FLASH_ATTN', 'TRITON_ATTN', + # 'FLEX_ATTENTION']` on sm_100/sm_120, then JIT-compiles the + # trtllm-gen kernels and crashes inside `vllm.LLM()`. Block + # `import flashinfer` at the module level so vLLM's + # `try: import flashinfer except ImportError` branch in + # `vllm.platforms.cuda.get_attn_backend_cls` picks FLASH_ATTN + # instead. Also clear any user-set env vars and propagate to + # UNSLOTH_VLLM_NO_FLASHINFER for the rest of unsloth_zoo. os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "0" os.environ["UNSLOTH_VLLM_NO_FLASHINFER"] = "1" + try: + # Drop any cached flashinfer module then mark it None so + # `import flashinfer` raises ImportError. None-in-sys.modules + # is the documented Python idiom for "this module fails to + # import"; see https://docs.python.org/3/reference/import.html. + for _name in list(sys.modules): + if _name == "flashinfer" or _name.startswith("flashinfer."): + del sys.modules[_name] + sys.modules["flashinfer"] = None + except Exception: + pass else: # Check if FLASHINFER is supported - for eg Qwen3-VL and Qwen2-VL do not work if "VLLM_ATTENTION_BACKEND" in os.environ and os.environ["VLLM_ATTENTION_BACKEND"] == "": From b8e681cd2a3eacd1878e004545e3853cf3ee193e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 25 May 2026 08:50:05 +0000 Subject: [PATCH 12/12] notebook_deps: add torchcodec to the auto-install allow-list (HF datasets 4.x audio Feature decoder) --- unsloth_zoo/temporary_patches/notebook_deps.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/temporary_patches/notebook_deps.py b/unsloth_zoo/temporary_patches/notebook_deps.py index f66c3cbcc..1bfa04b13 100644 --- a/unsloth_zoo/temporary_patches/notebook_deps.py +++ b/unsloth_zoo/temporary_patches/notebook_deps.py @@ -30,6 +30,7 @@ "einops": None, # Deepseek-OCR deepencoder + many other vision models "easydict": None, # Deepseek-OCR deepencoder.py:12 `from easydict import EasyDict` "snac": None, # Orpheus TTS neural audio codec + "torchcodec": None, # HF datasets audio Feature decoder (>= datasets 4.x) "matplotlib": None, # Deepseek-OCR + a few HF image utils "traitlets": None, # Jupyter/IPython widget chain "soundfile": None, # audio processors