Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions unsloth_zoo/patching_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .compiler import UNSLOTH_COMPILE_LOCATION
from .utils import _get_dtype, Version
from .hf_utils import dtype_from_config, set_dtype_in_config, HAS_TORCH_DTYPE
from .temporary_patches.moe_utils_fp8 import maybe_patch_stacked_moe_expert_fp8_scales

# Also disable compiling on bitsandbytes
def patch_compiling_bitsandbytes():
Expand Down Expand Up @@ -396,6 +397,8 @@ def __fix_dtype(config):
# string when trying to save the config or serialize it
patch_to_dict()

maybe_patch_stacked_moe_expert_fp8_scales(model)

# Check all params and patch!
for name, module in model.named_modules():
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
Expand Down
8 changes: 6 additions & 2 deletions unsloth_zoo/temporary_patches/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,12 @@ def moe_block_forward(self, hidden_states) -> torch.Tensor:
return hidden_states + shared_output

# Apply patches
patch_function(Glm4MoeLiteNaiveMoe, "forward", get_forward_moe_backend())
patch_function(Glm4MoeLiteMoE, "forward", moe_block_forward)
# Recent transformers versions wrap the expert forward with
# `use_experts_implementation` and drop some annotations, so strict signature
# matching rejects the patch. For GLM4 we deliberately bypass that wrapper and
# route into Unsloth's MoE backend, which is why the patches pass force = True.
patch_function(Glm4MoeLiteNaiveMoe, "forward", get_forward_moe_backend(), force = True)
patch_function(Glm4MoeLiteMoE, "forward", moe_block_forward, force = True)

if UNSLOTH_ENABLE_LOGGING:
logger.info("Unsloth: Patched GLM4 MoE for Split LoRA support.")
Expand Down
20 changes: 13 additions & 7 deletions unsloth_zoo/temporary_patches/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,16 +446,22 @@ def patch_transformers_masks():
except Exception:
torch_create_block_mask = None

# Always disable the `_compile` flag to avoid double-compilation issues.
# When Unsloth compiles create_causal_mask, the internal create_block_mask
# should NOT also compile itself, because that causes dimension issues.
# We need to patch both masking_utils and the original torch module.
if torch_create_block_mask is not None:
def create_block_mask_wrapper(*args, **kwargs):
kwargs["_compile"] = False
return torch_create_block_mask(*args, **kwargs)
# Patch masking_utils (for direct access)
masking_utils.create_block_mask = create_block_mask_wrapper
# Also patch the torch module directly (used by flex_attention_mask via import)
try:
supports_compile = "_compile" in inspect.signature(torch_create_block_mask).parameters
import torch.nn.attention.flex_attention as flex_attention
flex_attention.create_block_mask = create_block_mask_wrapper
except Exception:
supports_compile = True
if not supports_compile:
def create_block_mask_wrapper(*args, **kwargs):
kwargs.pop("_compile", None)
return torch_create_block_mask(*args, **kwargs)
masking_utils.create_block_mask = create_block_mask_wrapper
pass

original_create_causal_mask = getattr(
masking_utils,
Expand Down
Loading
Loading