From 4f8e15902d3608f931a9b3bc5707d2274f681912 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Sun, 15 Mar 2026 13:49:31 +0000 Subject: [PATCH 1/8] [WIP] Fp8 training for Moe --- unsloth_zoo/temporary_patches/glm4_moe.py | 112 ++++++ unsloth_zoo/temporary_patches/moe_utils.py | 379 +++++++++++++++++++++ 2 files changed, 491 insertions(+) diff --git a/unsloth_zoo/temporary_patches/glm4_moe.py b/unsloth_zoo/temporary_patches/glm4_moe.py index 1b485013d..8947a70af 100644 --- a/unsloth_zoo/temporary_patches/glm4_moe.py +++ b/unsloth_zoo/temporary_patches/glm4_moe.py @@ -14,13 +14,125 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <https://www.gnu.org/licenses/>. +import os import torch +import torch.nn as nn from .common import TEMPORARY_PATCHES, torch_compile, UNSLOTH_ENABLE_LOGGING from .utils import patch_function, raise_error, logger from .moe_utils import ( patch_param_wrapper_for_moe, get_forward_moe_backend, ) + + +def maybe_patch_glm4_moe_expert_fp8_scales( + model, + model_name: str, + token = None, + revision = None, +): + """ + GLM-4.7-Flash FP8 Dynamic stores routed expert weights as raw float8 tensors + plus per-expert weight_scale tensors. Transformers currently leaves those + scales as UNEXPECTED keys because Glm4MoeLiteNaiveMoe uses stacked + nn.Parameters instead of Linear modules, so we patch the expert tensors here. + We must preserve the FP8 parameters and attach the scale tensors for runtime + dequantization only when a fallback path needs high-precision weights. 
+ """ + config = getattr(model, "config", None) + if config is None or getattr(config, "model_type", None) != "glm4_moe_lite": + return False + + quantization_config = getattr(config, "quantization_config", None) + if isinstance(quantization_config, dict): + quant_method = quantization_config.get("quant_method", None) + else: + quant_method = getattr(quantization_config, "quant_method", None) + if quant_method != "compressed-tensors": + return False + + inner_model = getattr(model, "model", None) + if inner_model is None or not hasattr(inner_model, "layers"): + return False + + routed_layers = [] + for layer_idx, layer in enumerate(inner_model.layers): + experts = getattr(getattr(layer, "mlp", None), "experts", None) + if experts is None or not hasattr(experts, "gate_up_proj"): + continue + if getattr(experts.gate_up_proj, "dtype", None) == torch.float8_e4m3fn: + routed_layers.append((layer_idx, experts)) + if len(routed_layers) == 0: + return False + + if os.path.isdir(model_name): + safetensors_path = os.path.join(model_name, "model.safetensors") + else: + from huggingface_hub import hf_hub_download + + safetensors_path = hf_hub_download( + repo_id = model_name, + filename = "model.safetensors", + token = token, + revision = revision, + ) + + import safetensors.torch + + with safetensors.torch.safe_open(safetensors_path, framework = "pt") as file: + for layer_idx, experts in routed_layers: + device = experts.gate_up_proj.device + num_experts = experts.gate_up_proj.shape[0] + gate_up_rows = [] + down_rows = [] + gate_up_scales = [] + down_scales = [] + + for expert_idx in range(num_experts): + gate = file.get_tensor( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight" + ) + gate_scale = file.get_tensor( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight_scale" + ) + up = file.get_tensor( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight" + ) + up_scale = file.get_tensor( + 
f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight_scale" + ) + down = file.get_tensor( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight" + ) + down_scale = file.get_tensor( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight_scale" + ) + + gate_up_rows.append(torch.cat([gate, up], dim = 0)) + down_rows.append(down) + gate_up_scales.append(torch.cat([gate_scale, up_scale], dim = 0)) + down_scales.append(down_scale) + + experts.gate_up_proj = nn.Parameter( + torch.stack(gate_up_rows, dim = 0).to(device = device), + requires_grad = experts.gate_up_proj.requires_grad, + ) + experts.down_proj = nn.Parameter( + torch.stack(down_rows, dim = 0).to(device = device), + requires_grad = experts.down_proj.requires_grad, + ) + experts.gate_up_proj_weight_scale = nn.Parameter( + torch.stack(gate_up_scales, dim = 0).to(device = device), + requires_grad = False, + ) + experts.down_proj_weight_scale = nn.Parameter( + torch.stack(down_scales, dim = 0).to(device = device), + requires_grad = False, + ) + + return True + + def patch_glm4_moe(): """ Patches GLM4 MoE to support Split LoRA using grouped GEMM. 
diff --git a/unsloth_zoo/temporary_patches/moe_utils.py b/unsloth_zoo/temporary_patches/moe_utils.py index 1954522a9..767be0d95 100644 --- a/unsloth_zoo/temporary_patches/moe_utils.py +++ b/unsloth_zoo/temporary_patches/moe_utils.py @@ -69,6 +69,14 @@ def install_to_cache(source_path, destination_filename=None): _CACHED_FORWARD_MOE_BACKEND = None _CACHED_MOE_UTILS_MODULE = None +_WARNED_MOE_MESSAGES = set() + + +def _log_warn_once(message: str): + if message in _WARNED_MOE_MESSAGES: + return + _WARNED_MOE_MESSAGES.add(message) + print(message) def _load_cached_moe_utils_module(): @@ -268,6 +276,25 @@ def select_moe_backend(): return "native_torch" +def _is_float8_tensor(tensor: Optional[torch.Tensor]) -> bool: + return tensor is not None and getattr(tensor, "dtype", None) == torch.float8_e4m3fn + + +def _get_fp8_dequant_target_dtype(hidden_states: torch.Tensor) -> torch.dtype: + if hidden_states.dtype in (torch.float32, torch.float16, torch.bfloat16): + return hidden_states.dtype + if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + return torch.bfloat16 + return torch.float16 + + +def _build_active_expert_grouping(num_tokens_per_expert: torch.Tensor): + active_expert_ids = torch.nonzero(num_tokens_per_expert > 0, as_tuple=False).squeeze(-1) + active_counts = num_tokens_per_expert.index_select(0, active_expert_ids).to(torch.int32) + offsets = torch.cumsum(active_counts, dim=0, dtype=torch.int32) + return active_expert_ids, active_counts, offsets + + def forward_moe_backend( self, hidden_states: torch.Tensor, @@ -477,6 +504,169 @@ def _get_base_weight(param): return param +def _get_base_weight_and_quant_state(param): + """Get base weight together with any attached FP8 quant metadata.""" + # This Unsloth Zoo code section is licensed under AGPL3 + + base_layer = param + while hasattr(base_layer, "base_layer"): + base_layer = base_layer.base_layer + + if hasattr(base_layer, "get_param"): + weight = base_layer.get_param() + elif hasattr(base_layer, 
"weight"): + weight = base_layer.weight + else: + weight = base_layer + + quant_state = getattr(weight, "quant_state", None) + if quant_state is None: + quant_state = getattr(base_layer, "weight_scale_inv", None) + if quant_state is None: + quant_state = getattr(base_layer, "weight_scale", None) + + block_size = getattr(base_layer, "block_size", None) + if block_size is not None: + try: + weight.block_size = block_size + except Exception: + pass + if quant_state is not None: + try: + quant_state.block_size = block_size + except Exception: + pass + + return weight, quant_state + + +def _get_moe_weight_and_quant_state(experts_module, param_name: str): + """Get expert weight plus FP8 quant metadata, including stacked-parameter attrs.""" + # This Unsloth Zoo code section is licensed under AGPL3 + + param = getattr(experts_module, param_name) + weight, quant_state = _get_base_weight_and_quant_state(param) + + if quant_state is None: + quant_state = getattr(experts_module, f"{param_name}_weight_scale_inv", None) + if quant_state is None: + quant_state = getattr(experts_module, f"{param_name}_weight_scale", None) + + block_size = getattr(param, "block_size", None) + if block_size is None: + block_size = getattr(experts_module, f"{param_name}_block_size", None) + if block_size is not None: + try: + weight.block_size = block_size + except Exception: + pass + if quant_state is not None: + try: + quant_state.block_size = block_size + except Exception: + pass + + return weight, quant_state + + +def _slice_fp8_quant_state(weight: torch.Tensor, quant_state, expert_idx: int): + """Best-effort extraction of per-expert FP8 quant metadata.""" + if quant_state is None or not isinstance(quant_state, torch.Tensor): + return quant_state + + if quant_state.numel() == 1: + sliced = quant_state + elif quant_state.shape[0] == weight.shape[0]: + sliced = quant_state[expert_idx] + elif quant_state.shape[0] % weight.shape[0] == 0: + chunk_size = quant_state.shape[0] // weight.shape[0] + start 
= expert_idx * chunk_size + end = start + chunk_size + sliced = quant_state[start:end] + else: + return None + + block_size = getattr(weight, "block_size", None) or getattr(quant_state, "block_size", None) + if block_size is not None: + try: + sliced.block_size = block_size + except Exception: + pass + return sliced + + +def _dequantize_expert_slice( + expert_weight: torch.Tensor, + expert_quant_state, + target_dtype: torch.dtype, +) -> Optional[torch.Tensor]: + """Dequantize one expert tensor into grouped_mm-compatible precision.""" + if expert_weight.dtype != torch.float8_e4m3fn: + return expert_weight.to(target_dtype) + + if expert_quant_state is None: + return expert_weight.to(target_dtype) + + try: + from unsloth.kernels.fp8 import weight_dequant + except Exception: + return None + + block_size = getattr(expert_weight, "block_size", None) or getattr(expert_quant_state, "block_size", None) + if block_size is not None: + try: + expert_weight.block_size = block_size + except Exception: + pass + try: + expert_quant_state.block_size = block_size + except Exception: + pass + + return weight_dequant(expert_weight, expert_quant_state, dtype=target_dtype) + + +def _dequantize_active_expert_weights( + weight: torch.Tensor, + quant_state, + active_expert_ids: torch.Tensor, + target_dtype: torch.dtype, + proj_type: str, + hidden_dim: int, + model_type = None, +) -> Optional[torch.Tensor]: + """Dequantize only the routed experts and then preprocess for grouped_mm.""" + if weight.ndim != 3: + return None + + active_slices = [] + block_size = getattr(weight, "block_size", None) + for expert_idx in active_expert_ids.tolist(): + expert_weight = weight[expert_idx].contiguous() + if block_size is not None: + try: + expert_weight.block_size = block_size + except Exception: + pass + expert_quant_state = _slice_fp8_quant_state(weight, quant_state, expert_idx) + expert_dequant = _dequantize_expert_slice(expert_weight, expert_quant_state, target_dtype) + if expert_dequant is None: + 
return None + active_slices.append(expert_dequant) + + packed_weight = torch.stack(active_slices, dim=0) + return preprocess_weight(packed_weight, proj_type, hidden_dim, model_type) + + +def _moe_uses_fp8_expert_weights(self) -> bool: + if not hasattr(self, "gate_up_proj") or not hasattr(self, "down_proj"): + return False + + gate_weight, _ = _get_moe_weight_and_quant_state(self, "gate_up_proj") + down_weight, _ = _get_moe_weight_and_quant_state(self, "down_proj") + return _is_float8_tensor(gate_weight) or _is_float8_tensor(down_weight) + + def _get_lora_wrapper_for_param(experts_module, param_name): """ Get the PEFT ParamWrapper for a specific parameter (gate_up_proj or down_proj). @@ -768,6 +958,179 @@ def patch_param_wrapper_for_moe(): return False +def _forward_native_grouped_mm_active_dequant( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> Optional[torch.Tensor]: + """ + FP8 compatibility path: dequantize only routed experts, then run grouped_mm. + Falls back to None when the expert quant metadata cannot be interpreted safely. 
+ """ + # This Unsloth Zoo code section is licensed under AGPL3 + + if not hasattr(self, "gate_up_proj") or not hasattr(self, "down_proj"): + return None + + is_2d_input = hidden_states.dim() == 2 + if is_2d_input: + sequence_length, hidden_dim = hidden_states.shape + batch_size = 1 + else: + batch_size, sequence_length, hidden_dim = hidden_states.shape + + input_dtype = hidden_states.dtype + hidden_states = hidden_states.view(-1, hidden_dim) + + flat_top_k = top_k_index.view(-1) + num_tokens_per_expert = torch.bincount(flat_top_k, minlength=self.num_experts).int() + sorted_indices = torch.argsort(flat_top_k, stable=True) + token_indices = sorted_indices // top_k_index.shape[-1] + permuted_input = hidden_states[token_indices] + + active_expert_ids, active_counts, offsets = _build_active_expert_grouping(num_tokens_per_expert) + if active_expert_ids.numel() == 0: + return torch.zeros_like(hidden_states) if is_2d_input else hidden_states.new_zeros(batch_size, sequence_length, hidden_dim) + + target_dtype = _get_fp8_dequant_target_dtype(permuted_input) + model_type = getattr(self, "_unsloth_model_type", None) + use_separated_lora = _should_use_separated_lora() + + gate_up_base, gate_up_quant = _get_moe_weight_and_quant_state(self, "gate_up_proj") + down_base, down_quant = _get_moe_weight_and_quant_state(self, "down_proj") + gate_up_weight = _dequantize_active_expert_weights( + gate_up_base, + gate_up_quant, + active_expert_ids, + target_dtype, + "gate_up", + hidden_dim, + model_type, + ) + down_weight = _dequantize_active_expert_weights( + down_base, + down_quant, + active_expert_ids, + target_dtype, + "down", + hidden_dim, + model_type, + ) + if gate_up_weight is None or down_weight is None: + return None + + permuted_input = permuted_input.to(target_dtype) + mm1_out = _grouped_mm_with_backward_fix(permuted_input, gate_up_weight.contiguous(), offsets) + + gate_up_lora = None + if getattr(self, "_unsloth_lora_gate_up_proj", None) is not None: + gate_up_lora = 
self._unsloth_lora_gate_up_proj[:3] + elif use_separated_lora and _has_lora_adapters(self.gate_up_proj): + gate_up_lora = _extract_lora_weights( + self.gate_up_proj, num_experts=self.num_experts, experts_module=self + ) + + if gate_up_lora is not None: + first_weight, second_weight, scaling = gate_up_lora + active_expert_ids_device = active_expert_ids.to(first_weight.device) + first_weight = first_weight.index_select(0, active_expert_ids_device).to(target_dtype).contiguous() + second_weight = second_weight.index_select(0, active_expert_ids_device).to(target_dtype).contiguous() + lora_out = _grouped_mm_with_backward_fix(permuted_input, first_weight, offsets).contiguous() + try: + lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets) + except RuntimeError: + lora_delta = torch.empty( + (lora_out.shape[0], second_weight.shape[-1]), + dtype=lora_out.dtype, + device=lora_out.device, + ) + cpu_offsets = offsets.cpu().tolist() + prev_offset = 0 + for i, end in enumerate(cpu_offsets): + if prev_offset < end: + lora_delta[prev_offset:end] = torch.matmul( + lora_out[prev_offset:end], second_weight[i] + ) + prev_offset = end + mm1_out = mm1_out + lora_delta * scaling + + if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None: + bias_indices = active_expert_ids.to(self.gate_up_proj_bias.device) + bias_expanded = self.gate_up_proj_bias.index_select(0, bias_indices).repeat_interleave( + active_counts.to(self.gate_up_proj_bias.device), dim=0 + ) + mm1_out = mm1_out + bias_expanded.to(mm1_out.dtype) + + if "GptOssExperts" in self.__class__.__name__: + gate = mm1_out[..., ::2] + up = mm1_out[..., 1::2] + limit = getattr(self, "limit", 7.0) + alpha = getattr(self, "alpha", 1.702) + gate = gate.clamp(min=None, max=limit) + up = up.clamp(min=-limit, max=limit) + inter = (up + 1.0) * (gate * torch.sigmoid(gate * alpha)) + else: + gate, up = mm1_out.chunk(2, dim=-1) + inter = F.silu(gate) * up + + mm2_out = 
_grouped_mm_with_backward_fix(inter, down_weight.contiguous(), offsets) + + down_lora = None + if getattr(self, "_unsloth_lora_down_proj", None) is not None: + down_lora = self._unsloth_lora_down_proj[:3] + elif use_separated_lora and _has_lora_adapters(self.down_proj): + down_lora = _extract_lora_weights( + self.down_proj, num_experts=self.num_experts, experts_module=self + ) + + if down_lora is not None: + first_weight, second_weight, scaling = down_lora + active_expert_ids_device = active_expert_ids.to(first_weight.device) + first_weight = first_weight.index_select(0, active_expert_ids_device).to(target_dtype).contiguous() + second_weight = second_weight.index_select(0, active_expert_ids_device).to(target_dtype).contiguous() + lora_out = _grouped_mm_with_backward_fix(inter, first_weight, offsets).contiguous() + try: + lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets) + except RuntimeError: + lora_delta = torch.empty( + (lora_out.shape[0], second_weight.shape[-1]), + dtype=lora_out.dtype, + device=lora_out.device, + ) + cpu_offsets = offsets.cpu().tolist() + prev_offset = 0 + for i, end in enumerate(cpu_offsets): + if prev_offset < end: + lora_delta[prev_offset:end] = torch.matmul( + lora_out[prev_offset:end], second_weight[i] + ) + prev_offset = end + mm2_out = mm2_out + lora_delta * scaling + + if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None: + bias_indices = active_expert_ids.to(self.down_proj_bias.device) + bias_expanded = self.down_proj_bias.index_select(0, bias_indices).repeat_interleave( + active_counts.to(self.down_proj_bias.device), dim=0 + ) + mm2_out = mm2_out + bias_expanded.to(mm2_out.dtype) + + flat_weights = top_k_weights.view(-1) + permuted_weights = flat_weights[sorted_indices] + mm2_out = mm2_out * permuted_weights.unsqueeze(-1) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=input_dtype, + device=hidden_states.device, + ) + 
final_hidden_states.index_add_(0, token_indices, mm2_out.to(input_dtype)) + + if is_2d_input: + return final_hidden_states + return final_hidden_states.view(batch_size, sequence_length, hidden_dim) + + def forward_native_grouped_mm( self, hidden_states: torch.Tensor, @@ -789,6 +1152,22 @@ def forward_native_grouped_mm( f"Set UNSLOTH_MOE_BACKEND='unsloth_triton' or 'native_torch' to use a compatible backend." ) + if _moe_uses_fp8_expert_weights(self): + _log_warn_once( + "Unsloth: MoE grouped_mm detected FP8 expert weights; dequantizing only routed experts " + "to a temporary high-precision grouped_mm buffer." + ) + active_dequant_output = _forward_native_grouped_mm_active_dequant( + self, hidden_states, top_k_index, top_k_weights + ) + if active_dequant_output is not None: + return active_dequant_output + _log_warn_once( + "Unsloth: FP8 expert metadata was insufficient for active grouped_mm dequantization. " + "Falling back to native_torch MoE loop." + ) + return forward_native_moe_loop(self, hidden_states, top_k_index, top_k_weights) + is_2d_input = hidden_states.dim() == 2 if is_2d_input: sequence_length, hidden_dim = hidden_states.shape From b4bcdef7a42d08350c4e8564f40477f72408a3b6 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Sun, 15 Mar 2026 14:49:33 +0000 Subject: [PATCH 2/8] [WIP] cleanup --- unsloth_zoo/temporary_patches/glm4_moe.py | 110 ---- unsloth_zoo/temporary_patches/moe_utils.py | 547 +++++++++++++++++- .../temporary_patches/moe_utils_fp8.py | 148 +++++ 3 files changed, 670 insertions(+), 135 deletions(-) create mode 100644 unsloth_zoo/temporary_patches/moe_utils_fp8.py diff --git a/unsloth_zoo/temporary_patches/glm4_moe.py b/unsloth_zoo/temporary_patches/glm4_moe.py index 8947a70af..c4b0a00e2 100644 --- a/unsloth_zoo/temporary_patches/glm4_moe.py +++ b/unsloth_zoo/temporary_patches/glm4_moe.py @@ -14,9 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <https://www.gnu.org/licenses/>. 
-import os import torch -import torch.nn as nn from .common import TEMPORARY_PATCHES, torch_compile, UNSLOTH_ENABLE_LOGGING from .utils import patch_function, raise_error, logger from .moe_utils import ( @@ -25,114 +23,6 @@ ) -def maybe_patch_glm4_moe_expert_fp8_scales( - model, - model_name: str, - token = None, - revision = None, -): - """ - GLM-4.7-Flash FP8 Dynamic stores routed expert weights as raw float8 tensors - plus per-expert weight_scale tensors. Transformers currently leaves those - scales as UNEXPECTED keys because Glm4MoeLiteNaiveMoe uses stacked - nn.Parameters instead of Linear modules, so we patch the expert tensors here. - We must preserve the FP8 parameters and attach the scale tensors for runtime - dequantization only when a fallback path needs high-precision weights. - """ - config = getattr(model, "config", None) - if config is None or getattr(config, "model_type", None) != "glm4_moe_lite": - return False - - quantization_config = getattr(config, "quantization_config", None) - if isinstance(quantization_config, dict): - quant_method = quantization_config.get("quant_method", None) - else: - quant_method = getattr(quantization_config, "quant_method", None) - if quant_method != "compressed-tensors": - return False - - inner_model = getattr(model, "model", None) - if inner_model is None or not hasattr(inner_model, "layers"): - return False - - routed_layers = [] - for layer_idx, layer in enumerate(inner_model.layers): - experts = getattr(getattr(layer, "mlp", None), "experts", None) - if experts is None or not hasattr(experts, "gate_up_proj"): - continue - if getattr(experts.gate_up_proj, "dtype", None) == torch.float8_e4m3fn: - routed_layers.append((layer_idx, experts)) - if len(routed_layers) == 0: - return False - - if os.path.isdir(model_name): - safetensors_path = os.path.join(model_name, "model.safetensors") - else: - from huggingface_hub import hf_hub_download - - safetensors_path = hf_hub_download( - repo_id = model_name, - filename = 
"model.safetensors", - token = token, - revision = revision, - ) - - import safetensors.torch - - with safetensors.torch.safe_open(safetensors_path, framework = "pt") as file: - for layer_idx, experts in routed_layers: - device = experts.gate_up_proj.device - num_experts = experts.gate_up_proj.shape[0] - gate_up_rows = [] - down_rows = [] - gate_up_scales = [] - down_scales = [] - - for expert_idx in range(num_experts): - gate = file.get_tensor( - f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight" - ) - gate_scale = file.get_tensor( - f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight_scale" - ) - up = file.get_tensor( - f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight" - ) - up_scale = file.get_tensor( - f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight_scale" - ) - down = file.get_tensor( - f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight" - ) - down_scale = file.get_tensor( - f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight_scale" - ) - - gate_up_rows.append(torch.cat([gate, up], dim = 0)) - down_rows.append(down) - gate_up_scales.append(torch.cat([gate_scale, up_scale], dim = 0)) - down_scales.append(down_scale) - - experts.gate_up_proj = nn.Parameter( - torch.stack(gate_up_rows, dim = 0).to(device = device), - requires_grad = experts.gate_up_proj.requires_grad, - ) - experts.down_proj = nn.Parameter( - torch.stack(down_rows, dim = 0).to(device = device), - requires_grad = experts.down_proj.requires_grad, - ) - experts.gate_up_proj_weight_scale = nn.Parameter( - torch.stack(gate_up_scales, dim = 0).to(device = device), - requires_grad = False, - ) - experts.down_proj_weight_scale = nn.Parameter( - torch.stack(down_scales, dim = 0).to(device = device), - requires_grad = False, - ) - - return True - - def patch_glm4_moe(): """ Patches GLM4 MoE to support Split LoRA using grouped GEMM. 
diff --git a/unsloth_zoo/temporary_patches/moe_utils.py b/unsloth_zoo/temporary_patches/moe_utils.py index 767be0d95..be4e7314d 100644 --- a/unsloth_zoo/temporary_patches/moe_utils.py +++ b/unsloth_zoo/temporary_patches/moe_utils.py @@ -142,9 +142,11 @@ def _grouped_mm_with_backward_fix( # Global flag to check if grouped GEMM is available _GROUPED_GEMM_AVAILABLE = None _TORCH_GROUPED_MM_AVAILABLE = hasattr(torch, "_grouped_mm") +_TORCH_SCALED_GROUPED_MM_AVAILABLE = hasattr(torch, "_scaled_grouped_mm") # Check if GPU supports torch._grouped_mm (verified via runtime check) _TORCH_GROUPED_MM_SUPPORTED = None +_TORCH_SCALED_GROUPED_MM_SUPPORTED = None def _check_torch_grouped_mm_supported(): @@ -185,6 +187,54 @@ def _check_torch_grouped_mm_supported(): return _TORCH_GROUPED_MM_SUPPORTED +def _check_torch_scaled_grouped_mm_supported(): + """ + Check if torch._scaled_grouped_mm is actually supported on the current GPU. + """ + global _TORCH_SCALED_GROUPED_MM_SUPPORTED + if _TORCH_SCALED_GROUPED_MM_SUPPORTED is not None: + return _TORCH_SCALED_GROUPED_MM_SUPPORTED + + if not _TORCH_SCALED_GROUPED_MM_AVAILABLE: + _TORCH_SCALED_GROUPED_MM_SUPPORTED = False + return False + + if not torch.cuda.is_available(): + _TORCH_SCALED_GROUPED_MM_SUPPORTED = False + return False + major, _minor = torch.cuda.get_device_capability(torch.cuda.current_device()) + if major != 9: + _TORCH_SCALED_GROUPED_MM_SUPPORTED = False + return False + + try: + device = torch.cuda.current_device() + x = torch.randn((16, 16), device=device, dtype=torch.bfloat16) + w_hp = torch.randn((1, 16, 16), device=device, dtype=torch.bfloat16) + + x_fp8, x_scale = _manual_fp8_rowwise_quantize(x) + w_fp8, w_scale = _manual_fp8_rowwise_quantize(w_hp.view(-1, w_hp.shape[-1])) + w_fp8 = w_fp8.view_as(w_hp) + w_fp8 = w_fp8.transpose(-2, -1).contiguous().transpose(-2, -1) + w_scale = w_scale.view(w_hp.shape[0], w_hp.shape[1]) + offs = torch.tensor([16], device=device, dtype=torch.int32) + + torch._scaled_grouped_mm( + 
x_fp8.contiguous(), + w_fp8, + x_scale.contiguous(), + w_scale.contiguous(), + offs=offs, + out_dtype=torch.bfloat16, + use_fast_accum=True, + ) + _TORCH_SCALED_GROUPED_MM_SUPPORTED = True + except Exception: + _TORCH_SCALED_GROUPED_MM_SUPPORTED = False + + return _TORCH_SCALED_GROUPED_MM_SUPPORTED + + _TRITON_ALLOCATOR_INITIALIZED = False _PERSISTENT_BUFFER = None @@ -569,6 +619,60 @@ def _get_moe_weight_and_quant_state(experts_module, param_name: str): return weight, quant_state +def _get_moe_weight_and_quant_info(experts_module, param_name: str): + """Get expert weight, quant metadata, and whether it is scale or scale_inv.""" + param = getattr(experts_module, param_name) + + base_layer = param + while hasattr(base_layer, "base_layer"): + base_layer = base_layer.base_layer + + if hasattr(base_layer, "get_param"): + weight = base_layer.get_param() + elif hasattr(base_layer, "weight"): + weight = base_layer.weight + else: + weight = base_layer + + quant_state = getattr(weight, "quant_state", None) + quant_kind = "quant_state" if quant_state is not None else None + if quant_state is None: + quant_state = getattr(base_layer, "weight_scale_inv", None) + if quant_state is not None: + quant_kind = "weight_scale_inv" + else: + quant_state = getattr(base_layer, "weight_scale", None) + if quant_state is not None: + quant_kind = "weight_scale" + + if quant_state is None: + quant_state = getattr(experts_module, f"{param_name}_weight_scale_inv", None) + if quant_state is not None: + quant_kind = "weight_scale_inv" + else: + quant_state = getattr(experts_module, f"{param_name}_weight_scale", None) + if quant_state is not None: + quant_kind = "weight_scale" + + block_size = getattr(param, "block_size", None) + if block_size is None: + block_size = getattr(experts_module, f"{param_name}_block_size", None) + if block_size is None: + block_size = getattr(base_layer, "block_size", None) + if block_size is not None: + try: + weight.block_size = block_size + except Exception: + pass 
+ if quant_state is not None: + try: + quant_state.block_size = block_size + except Exception: + pass + + return weight, quant_state, quant_kind + + def _slice_fp8_quant_state(weight: torch.Tensor, quant_state, expert_idx: int): """Best-effort extraction of per-expert FP8 quant metadata.""" if quant_state is None or not isinstance(quant_state, torch.Tensor): @@ -623,9 +727,153 @@ def _dequantize_expert_slice( except Exception: pass + if isinstance(expert_quant_state, torch.Tensor) and expert_quant_state.ndim == 1: + expert_quant_state = expert_quant_state.view(-1, 1) + return weight_dequant(expert_weight, expert_quant_state, dtype=target_dtype) +def _dequantize_full_expert_weights( + weight: torch.Tensor, + quant_state, + target_dtype: torch.dtype, + proj_type: str, + hidden_dim: int, + model_type = None, +) -> Optional[torch.Tensor]: + if weight.ndim != 3: + return None + expert_ids = torch.arange(weight.shape[0], device=weight.device, dtype=torch.long) + return _dequantize_active_expert_weights( + weight, + quant_state, + expert_ids, + target_dtype, + proj_type, + hidden_dim, + model_type, + ) + + +def _make_grouped_mm_rhs_column_major(weight: torch.Tensor) -> torch.Tensor: + return weight.transpose(-2, -1).contiguous().transpose(-2, -1) + + +def _extract_scaled_grouped_mm_weight_scale( + original_weight: torch.Tensor, + processed_weight: torch.Tensor, + quant_state, + quant_kind: Optional[str], +) -> Optional[torch.Tensor]: + if quant_state is None or not isinstance(quant_state, torch.Tensor): + return None + if quant_kind == "quant_state": + return None + + if getattr(original_weight, "block_size", None) is not None: + return None + if getattr(quant_state, "block_size", None) is not None: + return None + + scale = quant_state + if scale.ndim == 0: + scale = scale.view(1, 1).expand(processed_weight.shape[0], processed_weight.shape[-1]) + elif scale.ndim == 1: + if scale.shape[0] != processed_weight.shape[-1]: + return None + scale = scale.view(1, 
-1).expand(processed_weight.shape[0], -1) + elif scale.ndim == 2: + pass + elif scale.ndim == 3: + if scale.shape[1] == 1: + scale = scale.squeeze(1) + elif scale.shape[2] == 1: + scale = scale.squeeze(2) + else: + return None + else: + return None + + if scale.ndim != 2: + return None + if scale.shape[0] != processed_weight.shape[0]: + return None + if scale.shape[1] != processed_weight.shape[-1]: + return None + + scale = scale.to(torch.float32) + if quant_kind == "weight_scale_inv": + scale = scale.reciprocal() + return scale.contiguous() + + +def _prepare_scaled_grouped_mm_weight( + experts_module, + param_name: str, + proj_type: str, + hidden_dim: int, + model_type = None, +) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + weight, quant_state, quant_kind = _get_moe_weight_and_quant_info(experts_module, param_name) + if not _is_float8_tensor(weight): + return None + + processed_weight = preprocess_weight(weight, proj_type, hidden_dim, model_type) + processed_weight = _make_grouped_mm_rhs_column_major(processed_weight) + scale = _extract_scaled_grouped_mm_weight_scale( + weight, + processed_weight, + quant_state, + quant_kind, + ) + if scale is None: + return None + + return processed_weight, scale + + +def _quantize_inputs_for_scaled_grouped_mm( + inputs: torch.Tensor, +) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + return _manual_fp8_rowwise_quantize(inputs) + + +def _scaled_grouped_mm_with_backward_fix( + inputs: torch.Tensor, + weight: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, + offsets: torch.Tensor, + out_dtype: torch.dtype, +) -> torch.Tensor: + return torch._scaled_grouped_mm( + inputs, + weight, + input_scale, + weight_scale, + offs=offsets, + out_dtype=out_dtype, + use_fast_accum=True, + ) + + +def _manual_fp8_rowwise_quantize( + inputs: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize each row independently into float8 and return decode scales for + torch._scaled_grouped_mm. 
This avoids relying on optional torchao/fbgemm + quantization helpers that are not always available in runtime environments. + """ + inputs_fp32 = inputs.to(torch.float32) + max_fp8 = torch.finfo(torch.float8_e4m3fn).max + amax = inputs_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) + quant_scale = (max_fp8 / amax) + quantized = (inputs_fp32 * quant_scale).to(torch.float8_e4m3fn) + decode_scale = quant_scale.reciprocal().squeeze(-1).to(torch.float32) + return quantized.contiguous(), decode_scale.contiguous() + + def _dequantize_active_expert_weights( weight: torch.Tensor, quant_state, @@ -1131,6 +1379,193 @@ def _forward_native_grouped_mm_active_dequant( return final_hidden_states.view(batch_size, sequence_length, hidden_dim) +def _forward_native_grouped_mm_scaled_fp8( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> Optional[torch.Tensor]: + """ + FP8 fast path: use torch._scaled_grouped_mm directly when the expert scales + are compatible with a simple rowwise/tensorwise grouped matmul. 
+ """ + if not hasattr(self, "gate_up_proj") or not hasattr(self, "down_proj"): + return None + if not _check_torch_scaled_grouped_mm_supported(): + return None + + is_2d_input = hidden_states.dim() == 2 + if is_2d_input: + sequence_length, hidden_dim = hidden_states.shape + batch_size = 1 + else: + batch_size, sequence_length, hidden_dim = hidden_states.shape + + input_dtype = hidden_states.dtype + hidden_states = hidden_states.view(-1, hidden_dim) + + flat_top_k = top_k_index.view(-1) + num_tokens_per_expert = torch.bincount(flat_top_k, minlength=self.num_experts).int() + sorted_indices = torch.argsort(flat_top_k, stable=True) + token_indices = sorted_indices // top_k_index.shape[-1] + permuted_input = hidden_states[token_indices] + + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + use_separated_lora = _should_use_separated_lora() + model_type = getattr(self, "_unsloth_model_type", None) + + gate_up_prepared = _prepare_scaled_grouped_mm_weight( + self, "gate_up_proj", "gate_up", hidden_dim, model_type + ) + down_prepared = _prepare_scaled_grouped_mm_weight( + self, "down_proj", "down", hidden_dim, model_type + ) + if gate_up_prepared is None or down_prepared is None: + return None + + gate_up_weight, gate_up_scale = gate_up_prepared + down_weight, down_scale = down_prepared + + quantized_input = _quantize_inputs_for_scaled_grouped_mm( + permuted_input.to(_get_fp8_dequant_target_dtype(permuted_input)) + ) + if quantized_input is None: + return None + permuted_input_fp8, input_scale = quantized_input + + try: + mm1_out = _scaled_grouped_mm_with_backward_fix( + permuted_input_fp8, + gate_up_weight, + input_scale, + gate_up_scale, + offsets, + out_dtype=_get_fp8_dequant_target_dtype(permuted_input), + ) + except RuntimeError: + return None + + gate_up_lora = None + if getattr(self, "_unsloth_lora_gate_up_proj", None) is not None: + gate_up_lora = self._unsloth_lora_gate_up_proj[:3] + elif use_separated_lora and 
_has_lora_adapters(self.gate_up_proj): + gate_up_lora = _extract_lora_weights( + self.gate_up_proj, num_experts=self.num_experts, experts_module=self + ) + + if gate_up_lora is not None: + first_weight, second_weight, scaling = gate_up_lora + first_weight = first_weight.to(mm1_out.dtype).contiguous() + second_weight = second_weight.to(mm1_out.dtype).contiguous() + lora_out = _grouped_mm_with_backward_fix( + permuted_input.to(mm1_out.dtype), first_weight, offsets + ).contiguous() + try: + lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets) + except RuntimeError: + lora_delta = torch.empty( + (lora_out.shape[0], second_weight.shape[-1]), + dtype=lora_out.dtype, + device=lora_out.device, + ) + cpu_offsets = offsets.cpu().tolist() + prev_offset = 0 + for i, end in enumerate(cpu_offsets): + if prev_offset < end: + lora_delta[prev_offset:end] = torch.matmul( + lora_out[prev_offset:end], second_weight[i] + ) + prev_offset = end + mm1_out = mm1_out + lora_delta * scaling + + if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None: + bias_expanded = self.gate_up_proj_bias.repeat_interleave( + num_tokens_per_expert.to(self.gate_up_proj_bias.device), dim=0 + ) + mm1_out = mm1_out + bias_expanded.to(mm1_out.dtype) + + if "GptOssExperts" in self.__class__.__name__: + gate = mm1_out[..., ::2] + up = mm1_out[..., 1::2] + limit = getattr(self, "limit", 7.0) + alpha = getattr(self, "alpha", 1.702) + gate = gate.clamp(min=None, max=limit) + up = up.clamp(min=-limit, max=limit) + inter = (up + 1.0) * (gate * torch.sigmoid(gate * alpha)) + else: + gate, up = mm1_out.chunk(2, dim=-1) + inter = F.silu(gate) * up + + inter_quantized = _quantize_inputs_for_scaled_grouped_mm(inter) + if inter_quantized is None: + return None + inter_fp8, inter_scale = inter_quantized + + try: + mm2_out = _scaled_grouped_mm_with_backward_fix( + inter_fp8, + down_weight, + inter_scale, + down_scale, + offsets, + out_dtype=mm1_out.dtype, + ) + except RuntimeError: 
+ return None + + down_lora = None + if getattr(self, "_unsloth_lora_down_proj", None) is not None: + down_lora = self._unsloth_lora_down_proj[:3] + elif use_separated_lora and _has_lora_adapters(self.down_proj): + down_lora = _extract_lora_weights( + self.down_proj, num_experts=self.num_experts, experts_module=self + ) + + if down_lora is not None: + first_weight, second_weight, scaling = down_lora + first_weight = first_weight.to(inter.dtype).contiguous() + second_weight = second_weight.to(inter.dtype).contiguous() + lora_out = _grouped_mm_with_backward_fix(inter, first_weight, offsets).contiguous() + try: + lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets) + except RuntimeError: + lora_delta = torch.empty( + (lora_out.shape[0], second_weight.shape[-1]), + dtype=lora_out.dtype, + device=lora_out.device, + ) + cpu_offsets = offsets.cpu().tolist() + prev_offset = 0 + for i, end in enumerate(cpu_offsets): + if prev_offset < end: + lora_delta[prev_offset:end] = torch.matmul( + lora_out[prev_offset:end], second_weight[i] + ) + prev_offset = end + mm2_out = mm2_out + lora_delta * scaling + + if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None: + bias_expanded = self.down_proj_bias.repeat_interleave( + num_tokens_per_expert.to(self.down_proj_bias.device), dim=0 + ).to(mm2_out.device) + mm2_out = mm2_out + bias_expanded.to(mm2_out.dtype) + + flat_weights = top_k_weights.view(-1) + permuted_weights = flat_weights[sorted_indices] + mm2_out = mm2_out * permuted_weights.unsqueeze(-1) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=input_dtype, + device=hidden_states.device, + ) + final_hidden_states.index_add_(0, token_indices, mm2_out.to(input_dtype)) + + if is_2d_input: + return final_hidden_states + return final_hidden_states.view(batch_size, sequence_length, hidden_dim) + + def forward_native_grouped_mm( self, hidden_states: torch.Tensor, @@ -1153,6 +1588,15 @@ def 
forward_native_grouped_mm( ) if _moe_uses_fp8_expert_weights(self): + scaled_output = _forward_native_grouped_mm_scaled_fp8( + self, hidden_states, top_k_index, top_k_weights + ) + if scaled_output is not None: + _log_warn_once( + "Unsloth: MoE grouped_mm detected compatible FP8 expert weights; using torch._scaled_grouped_mm." + ) + return scaled_output + _log_warn_once( "Unsloth: MoE grouped_mm detected FP8 expert weights; dequantizing only routed experts " "to a temporary high-precision grouped_mm buffer." @@ -1452,33 +1896,21 @@ def forward_native_grouped_mm( return final_hidden_states.view(batch_size, sequence_length, hidden_dim) -def forward_triton_grouped_gemm( +def _forward_triton_grouped_gemm_impl( self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, + gate_up_proj: Optional[torch.Tensor] = None, + down_proj: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """ - Grouped GEMM MoE forward pass using Triton kernels. - Compatible with torch.compile (recommended mode="max-autotune" with cudagraph_mark_step_begin). - """ - # This Unsloth Zoo code section is licensed under AGPL3 + gate_up_proj = self.gate_up_proj if gate_up_proj is None else gate_up_proj + down_proj = self.down_proj if down_proj is None else down_proj - # Import grouped GEMM interface - from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm - # Import autotune cache + from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm from unsloth.kernels.moe.autotune_cache import get_or_autotune_moe_kernels - # Helper to check TMA support - assumes helper function or just check directly - # In original: it was a cached closure. Here we can use _supports_tma() directly - - # nonlocal _MODEL_DIMS_AND_CONFIGS # We need a way to store this! 
- # For now, let's attach it to self if possible, or use a global usage - # Attaching to self is cleaner: self._unsloth_moe_configs - - # Create expert mask and find which experts have tokens - if not hasattr(self, "_unsloth_moe_configs"): self._unsloth_moe_configs = None @@ -1503,7 +1935,7 @@ def forward_triton_grouped_gemm( # Cache model dimensions and kernel configs on first call if self._unsloth_moe_configs is None: - intermediate_dim = self.gate_up_proj.shape[1] // 2 + intermediate_dim = gate_up_proj.shape[1] // 2 # Autotune first GEMM gemm1_configs = get_or_autotune_moe_kernels( @@ -1541,10 +1973,10 @@ def forward_triton_grouped_gemm( ) offsets = torch.cumsum(token_counts_by_expert, dim=0, dtype=torch.int32) - if self.gate_up_proj.shape[-1] == hidden_dim: - w1 = self.gate_up_proj + if gate_up_proj.shape[-1] == hidden_dim: + w1 = gate_up_proj else: - w1 = self.gate_up_proj.transpose(-2, -1).contiguous() + w1 = gate_up_proj.transpose(-2, -1).contiguous() # First grouped GEMM: gate_up projection first_gemm_output = grouped_gemm( @@ -1579,10 +2011,10 @@ def forward_triton_grouped_gemm( ): down_lora = _extract_lora_weights(self.down_proj, num_experts=self.num_experts) - if self.down_proj.shape[-1] == intermediate.shape[-1]: - w2 = self.down_proj + if down_proj.shape[-1] == intermediate.shape[-1]: + w2 = down_proj else: - w2 = self.down_proj.transpose(-2, -1).contiguous() + w2 = down_proj.transpose(-2, -1).contiguous() second_gemm_output = grouped_gemm( X=intermediate, @@ -1636,6 +2068,71 @@ def forward_triton_grouped_gemm( return final_hidden_states +def forward_triton_grouped_gemm( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> torch.Tensor: + """ + Grouped GEMM MoE forward pass using Triton kernels. + Compatible with torch.compile (recommended mode="max-autotune" with cudagraph_mark_step_begin). 
+ """ + # This Unsloth Zoo code section is licensed under AGPL3 + + if _moe_uses_fp8_expert_weights(self): + if _check_torch_grouped_mm_supported(): + _log_warn_once( + "Unsloth: MoE Triton backend detected FP8 expert weights; routing through grouped_mm FP8 handling." + ) + return forward_native_grouped_mm( + self, hidden_states, top_k_index, top_k_weights + ) + + target_dtype = _get_fp8_dequant_target_dtype(hidden_states) + hidden_dim = hidden_states.shape[-1] + model_type = getattr(self, "_unsloth_model_type", None) + gate_up_base, gate_up_quant = _get_moe_weight_and_quant_state(self, "gate_up_proj") + down_base, down_quant = _get_moe_weight_and_quant_state(self, "down_proj") + gate_up_weight = _dequantize_full_expert_weights( + gate_up_base, + gate_up_quant, + target_dtype, + "gate_up", + hidden_dim, + model_type, + ) + down_weight = _dequantize_full_expert_weights( + down_base, + down_quant, + target_dtype, + "down", + hidden_dim, + model_type, + ) + if gate_up_weight is None or down_weight is None: + _log_warn_once( + "Unsloth: FP8 expert metadata was insufficient for Triton dequant fallback. Falling back to native_torch MoE loop." + ) + return forward_native_moe_loop(self, hidden_states, top_k_index, top_k_weights) + + _log_warn_once( + "Unsloth: MoE Triton backend detected FP8 expert weights; dequantizing experts on the fly for Triton grouped GEMM." 
+ ) + return _forward_triton_grouped_gemm_impl( + self, + hidden_states.to(target_dtype), + top_k_index, + top_k_weights, + gate_up_proj=gate_up_weight, + down_proj=down_weight, + ) + + return _forward_triton_grouped_gemm_impl( + self, hidden_states, top_k_index, top_k_weights + ) + + @torch.compiler.disable def forward_native_moe_loop( self, diff --git a/unsloth_zoo/temporary_patches/moe_utils_fp8.py b/unsloth_zoo/temporary_patches/moe_utils_fp8.py new file mode 100644 index 000000000..2022f2e33 --- /dev/null +++ b/unsloth_zoo/temporary_patches/moe_utils_fp8.py @@ -0,0 +1,148 @@ +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import os +import torch +import torch.nn as nn + + +def _maybe_patch_glm4_stacked_moe_fp8_scales( + model, + model_name: str, + token = None, + revision = None, +): + """ + Attach missing FP8 scale tensors to stacked routed-expert parameters. + + This currently handles GLM4-MoE Lite style experts where transformers loads + the float8 expert weights but leaves the per-expert weight_scale tensors as + unexpected keys because the experts are stacked nn.Parameters rather than + Linear modules. 
+ """ + config = getattr(model, "config", None) + if config is None or getattr(config, "model_type", None) != "glm4_moe_lite": + return False + + quantization_config = getattr(config, "quantization_config", None) + if isinstance(quantization_config, dict): + quant_method = quantization_config.get("quant_method", None) + else: + quant_method = getattr(quantization_config, "quant_method", None) + if quant_method != "compressed-tensors": + return False + + inner_model = getattr(model, "model", None) + if inner_model is None or not hasattr(inner_model, "layers"): + return False + + routed_layers = [] + for layer_idx, layer in enumerate(inner_model.layers): + experts = getattr(getattr(layer, "mlp", None), "experts", None) + if experts is None or not hasattr(experts, "gate_up_proj"): + continue + if getattr(experts.gate_up_proj, "dtype", None) == torch.float8_e4m3fn: + routed_layers.append((layer_idx, experts)) + if len(routed_layers) == 0: + return False + + if os.path.isdir(model_name): + safetensors_path = os.path.join(model_name, "model.safetensors") + else: + from huggingface_hub import hf_hub_download + + safetensors_path = hf_hub_download( + repo_id = model_name, + filename = "model.safetensors", + token = token, + revision = revision, + ) + + import safetensors.torch + + with safetensors.torch.safe_open(safetensors_path, framework = "pt") as file: + for layer_idx, experts in routed_layers: + device = experts.gate_up_proj.device + num_experts = experts.gate_up_proj.shape[0] + gate_up_rows = [] + down_rows = [] + gate_up_scales = [] + down_scales = [] + + for expert_idx in range(num_experts): + gate = file.get_tensor( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight" + ) + gate_scale = file.get_tensor( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight_scale" + ) + up = file.get_tensor( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight" + ) + up_scale = file.get_tensor( + 
f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight_scale" + ) + down = file.get_tensor( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight" + ) + down_scale = file.get_tensor( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight_scale" + ) + + gate_up_rows.append(torch.cat([gate, up], dim = 0)) + down_rows.append(down) + gate_up_scales.append(torch.cat([gate_scale, up_scale], dim = 0)) + down_scales.append(down_scale) + + experts.gate_up_proj = nn.Parameter( + torch.stack(gate_up_rows, dim = 0).to(device = device), + requires_grad = experts.gate_up_proj.requires_grad, + ) + experts.down_proj = nn.Parameter( + torch.stack(down_rows, dim = 0).to(device = device), + requires_grad = experts.down_proj.requires_grad, + ) + experts.gate_up_proj_weight_scale = nn.Parameter( + torch.stack(gate_up_scales, dim = 0).to(device = device), + requires_grad = False, + ) + experts.down_proj_weight_scale = nn.Parameter( + torch.stack(down_scales, dim = 0).to(device = device), + requires_grad = False, + ) + + return True + + +def maybe_patch_stacked_moe_expert_fp8_scales( + model, + model_name: str, + token = None, + revision = None, +): + """ + Best-effort hook for prequantized FP8 MoE checkpoints that use stacked expert + parameters and need extra runtime quant metadata attached after loading. + + This is intentionally generic at the callsite. Model-specific handlers can + be added here as new stacked-FP8 MoE formats appear. 
+ """ + return _maybe_patch_glm4_stacked_moe_fp8_scales( + model, + model_name, + token = token, + revision = revision, + ) From 64491aa59df444354e300d33ac9ef41e5175e97f Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 16 Mar 2026 05:51:49 +0000 Subject: [PATCH 3/8] patch for fp8 moe to use unsloth kerenls --- unsloth_zoo/temporary_patches/moe_utils.py | 37 ++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/unsloth_zoo/temporary_patches/moe_utils.py b/unsloth_zoo/temporary_patches/moe_utils.py index be4e7314d..110a6733a 100644 --- a/unsloth_zoo/temporary_patches/moe_utils.py +++ b/unsloth_zoo/temporary_patches/moe_utils.py @@ -601,6 +601,10 @@ def _get_moe_weight_and_quant_state(experts_module, param_name: str): quant_state = getattr(experts_module, f"{param_name}_weight_scale_inv", None) if quant_state is None: quant_state = getattr(experts_module, f"{param_name}_weight_scale", None) + if quant_state is None: + quant_state = getattr(experts_module, f"{param_name}_scale_inv", None) + if quant_state is None: + quant_state = getattr(experts_module, f"{param_name}_scale", None) block_size = getattr(param, "block_size", None) if block_size is None: @@ -653,6 +657,14 @@ def _get_moe_weight_and_quant_info(experts_module, param_name: str): quant_state = getattr(experts_module, f"{param_name}_weight_scale", None) if quant_state is not None: quant_kind = "weight_scale" + if quant_state is None: + quant_state = getattr(experts_module, f"{param_name}_scale_inv", None) + if quant_state is not None: + quant_kind = "weight_scale_inv" + else: + quant_state = getattr(experts_module, f"{param_name}_scale", None) + if quant_state is not None: + quant_kind = "weight_scale" block_size = getattr(param, "block_size", None) if block_size is None: @@ -713,6 +725,7 @@ def _dequantize_expert_slice( try: from unsloth.kernels.fp8 import weight_dequant + import triton except Exception: return None @@ -727,6 +740,30 @@ def _dequantize_expert_slice( except Exception: 
pass + # Match the handling used by unsloth.kernels.fp8.FP8BlockQuantLinear: + # some FP8 checkpoints store block scales transposed relative to the + # expert weight layout. In that case, transpose the scale grid before + # dequantizing so the recovered bf16/fp16 weights are numerically sane. + if ( + isinstance(expert_quant_state, torch.Tensor) + and expert_quant_state.ndim == 2 + and len(block_size) == 2 + ): + m, n = expert_weight.shape + p, q = expert_quant_state.shape + if triton.cdiv(m, block_size[0]) != p or triton.cdiv(n, block_size[1]) != q: + if ( + triton.cdiv(m, block_size[0]) == q + and triton.cdiv(n, block_size[1]) == p + ): + expert_quant_state = expert_quant_state.T.contiguous() + try: + expert_quant_state.block_size = block_size + except Exception: + pass + else: + return None + if isinstance(expert_quant_state, torch.Tensor) and expert_quant_state.ndim == 1: expert_quant_state = expert_quant_state.view(-1, 1) From c3195d302e7620adff4dd69e79255534a6c81d30 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 16 Mar 2026 05:59:24 +0000 Subject: [PATCH 4/8] cleanup --- unsloth_zoo/temporary_patches/moe_utils.py | 23 ---------------------- 1 file changed, 23 deletions(-) diff --git a/unsloth_zoo/temporary_patches/moe_utils.py b/unsloth_zoo/temporary_patches/moe_utils.py index 110a6733a..d75a76eba 100644 --- a/unsloth_zoo/temporary_patches/moe_utils.py +++ b/unsloth_zoo/temporary_patches/moe_utils.py @@ -555,9 +555,6 @@ def _get_base_weight(param): def _get_base_weight_and_quant_state(param): - """Get base weight together with any attached FP8 quant metadata.""" - # This Unsloth Zoo code section is licensed under AGPL3 - base_layer = param while hasattr(base_layer, "base_layer"): base_layer = base_layer.base_layer @@ -591,9 +588,6 @@ def _get_base_weight_and_quant_state(param): def _get_moe_weight_and_quant_state(experts_module, param_name: str): - """Get expert weight plus FP8 quant metadata, including stacked-parameter attrs.""" - # This Unsloth 
Zoo code section is licensed under AGPL3 - param = getattr(experts_module, param_name) weight, quant_state = _get_base_weight_and_quant_state(param) @@ -624,7 +618,6 @@ def _get_moe_weight_and_quant_state(experts_module, param_name: str): def _get_moe_weight_and_quant_info(experts_module, param_name: str): - """Get expert weight, quant metadata, and whether it is scale or scale_inv.""" param = getattr(experts_module, param_name) base_layer = param @@ -686,7 +679,6 @@ def _get_moe_weight_and_quant_info(experts_module, param_name: str): def _slice_fp8_quant_state(weight: torch.Tensor, quant_state, expert_idx: int): - """Best-effort extraction of per-expert FP8 quant metadata.""" if quant_state is None or not isinstance(quant_state, torch.Tensor): return quant_state @@ -716,7 +708,6 @@ def _dequantize_expert_slice( expert_quant_state, target_dtype: torch.dtype, ) -> Optional[torch.Tensor]: - """Dequantize one expert tensor into grouped_mm-compatible precision.""" if expert_weight.dtype != torch.float8_e4m3fn: return expert_weight.to(target_dtype) @@ -740,10 +731,6 @@ def _dequantize_expert_slice( except Exception: pass - # Match the handling used by unsloth.kernels.fp8.FP8BlockQuantLinear: - # some FP8 checkpoints store block scales transposed relative to the - # expert weight layout. In that case, transpose the scale grid before - # dequantizing so the recovered bf16/fp16 weights are numerically sane. 
if ( isinstance(expert_quant_state, torch.Tensor) and expert_quant_state.ndim == 2 @@ -1189,17 +1176,12 @@ def _patched_param_wrapper_forward( lora_data = _extract_lora_from_wrapper(self) if lora_data is not None and param_name: - # Store LoRA data on the EXPERTS MODULE (not base_layer) - # e.g., _unsloth_lora_gate_up_proj or _unsloth_lora_down_proj lora_attr = f"_unsloth_lora_{param_name}" setattr(experts_module, lora_attr, lora_data) try: - # Call IMMEDIATE base_layer to preserve wrapper chain - # (down_proj wrapper calls gate_up_proj wrapper calls Qwen3MoeExperts) result = immediate_base_layer(x, *args, **kwargs) finally: - # Clean up if param_name: lora_attr = f"_unsloth_lora_{param_name}" if hasattr(experts_module, lora_attr): @@ -1701,14 +1683,9 @@ def forward_native_grouped_mm( # Get model type for preprocessing (if registered) model_type = getattr(self, "_unsloth_model_type", None) - # Handle different weight shapes using preprocessor - # torch._grouped_mm backward requires weights to be contiguous; preprocessing may return a transposed view. 
w1 = preprocess_weight(gate_up_base, "gate_up", hidden_dim, model_type) - # Base forward: X @ W mm1_out = _grouped_mm_with_backward_fix(permuted_input, w1, offsets) - # Add separated LoRA contribution: + ((X @ first) @ second) * scaling - # _extract_lora_from_wrapper returns (first_weight, second_weight, scaling) if gate_up_lora is not None: first_weight, second_weight, scaling = gate_up_lora From 9a0192447ae808352638b939e3b3b6ca106d6d3a Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 16 Mar 2026 12:44:30 +0000 Subject: [PATCH 5/8] Cleanup and refactor --- unsloth_zoo/patching_utils.py | 3 + unsloth_zoo/temporary_patches/glm4_moe.py | 10 +- unsloth_zoo/temporary_patches/moe_utils.py | 1035 +++-------------- .../temporary_patches/moe_utils_fp8.py | 627 +++++++++- 4 files changed, 775 insertions(+), 900 deletions(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index ba77af906..350c1bd4f 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -30,6 +30,7 @@ from .compiler import UNSLOTH_COMPILE_LOCATION from .utils import _get_dtype, Version from .hf_utils import dtype_from_config, set_dtype_in_config, HAS_TORCH_DTYPE +from .temporary_patches.moe_utils_fp8 import maybe_patch_stacked_moe_expert_fp8_scales # Also disable compiling on bitsandbytes def patch_compiling_bitsandbytes(): @@ -396,6 +397,8 @@ def __fix_dtype(config): # string when trying to save the config or serialize it patch_to_dict() + maybe_patch_stacked_moe_expert_fp8_scales(model) + # Check all params and patch! 
for name, module in model.named_modules(): if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): diff --git a/unsloth_zoo/temporary_patches/glm4_moe.py b/unsloth_zoo/temporary_patches/glm4_moe.py index c4b0a00e2..c5c47fd1f 100644 --- a/unsloth_zoo/temporary_patches/glm4_moe.py +++ b/unsloth_zoo/temporary_patches/glm4_moe.py @@ -21,8 +21,6 @@ patch_param_wrapper_for_moe, get_forward_moe_backend, ) - - def patch_glm4_moe(): """ Patches GLM4 MoE to support Split LoRA using grouped GEMM. @@ -123,8 +121,12 @@ def moe_block_forward(self, hidden_states) -> torch.Tensor: return hidden_states + shared_output # Apply patches - patch_function(Glm4MoeLiteNaiveMoe, "forward", get_forward_moe_backend()) - patch_function(Glm4MoeLiteMoE, "forward", moe_block_forward) + # Recent transformers wraps the expert forward with use_experts_implementation + # and drops some annotations, so strict signature matching rejects the patch. + # For GLM4 we want to bypass that wrapper entirely and route into Unsloth's + # MoE backend on purpose. 
+ patch_function(Glm4MoeLiteNaiveMoe, "forward", get_forward_moe_backend(), force = True) + patch_function(Glm4MoeLiteMoE, "forward", moe_block_forward, force = True) if UNSLOTH_ENABLE_LOGGING: logger.info("Unsloth: Patched GLM4 MoE for Split LoRA support.") diff --git a/unsloth_zoo/temporary_patches/moe_utils.py b/unsloth_zoo/temporary_patches/moe_utils.py index d75a76eba..b0912b8f7 100644 --- a/unsloth_zoo/temporary_patches/moe_utils.py +++ b/unsloth_zoo/temporary_patches/moe_utils.py @@ -66,17 +66,14 @@ def install_to_cache(source_path, destination_filename=None): install_to_cache(__file__, "moe_utils.py") +_CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +_MOE_UTILS_FP8_PATH = os.path.join(_CURRENT_DIR, "moe_utils_fp8.py") +if os.path.isfile(_MOE_UTILS_FP8_PATH): + install_to_cache(_MOE_UTILS_FP8_PATH, "moe_utils_fp8.py") _CACHED_FORWARD_MOE_BACKEND = None _CACHED_MOE_UTILS_MODULE = None -_WARNED_MOE_MESSAGES = set() - - -def _log_warn_once(message: str): - if message in _WARNED_MOE_MESSAGES: - return - _WARNED_MOE_MESSAGES.add(message) - print(message) +_CACHED_MOE_UTILS_FP8_MODULE = None def _load_cached_moe_utils_module(): @@ -88,7 +85,7 @@ def _load_cached_moe_utils_module(): return None try: - module_name = "unsloth_cached_moe_utils" + module_name = "unsloth_zoo.temporary_patches._cached_moe_utils" module = sys.modules.get(module_name, None) if module is not None and os.path.abspath(getattr(module, "__file__", "")) == cache_file: _CACHED_MOE_UTILS_MODULE = module @@ -98,6 +95,7 @@ def _load_cached_moe_utils_module(): if spec is None or spec.loader is None: return None module = importlib.util.module_from_spec(spec) + module.__package__ = "unsloth_zoo.temporary_patches" sys.modules[module_name] = module spec.loader.exec_module(module) _CACHED_MOE_UTILS_MODULE = module @@ -106,18 +104,52 @@ def _load_cached_moe_utils_module(): return None +def _load_cached_moe_utils_fp8_module(): + global _CACHED_MOE_UTILS_FP8_MODULE + + cache_file = 
os.path.abspath(os.path.join(_get_compile_location(), "moe_utils_fp8.py")) + current_file = os.path.abspath(_MOE_UTILS_FP8_PATH) + if not os.path.isfile(cache_file) or cache_file == current_file: + return None + + try: + module_name = "unsloth_zoo.temporary_patches._cached_moe_utils_fp8" + module = sys.modules.get(module_name, None) + if module is not None and os.path.abspath(getattr(module, "__file__", "")) == cache_file: + _CACHED_MOE_UTILS_FP8_MODULE = module + return module + + spec = importlib.util.spec_from_file_location(module_name, cache_file) + if spec is None or spec.loader is None: + return None + module = importlib.util.module_from_spec(spec) + module.__package__ = "unsloth_zoo.temporary_patches" + sys.modules[module_name] = module + spec.loader.exec_module(module) + _CACHED_MOE_UTILS_FP8_MODULE = module + return module + except Exception: + return None + + def get_forward_moe_backend(): """ Resolve forward_moe_backend from the compiled cache copy when available. Falls back to the local module definition. 
""" global _CACHED_FORWARD_MOE_BACKEND + fp8_module = _load_cached_moe_utils_fp8_module() + if fp8_module is not None and hasattr(fp8_module, "forward_moe_backend_fp8"): + _CACHED_FORWARD_MOE_BACKEND = fp8_module.forward_moe_backend_fp8 + return _CACHED_FORWARD_MOE_BACKEND + module = _load_cached_moe_utils_module() if module is not None and hasattr(module, "forward_moe_backend"): _CACHED_FORWARD_MOE_BACKEND = module.forward_moe_backend return _CACHED_FORWARD_MOE_BACKEND - _CACHED_FORWARD_MOE_BACKEND = forward_moe_backend + from .moe_utils_fp8 import forward_moe_backend_fp8 + _CACHED_FORWARD_MOE_BACKEND = forward_moe_backend_fp8 return _CACHED_FORWARD_MOE_BACKEND # ============================================================================ @@ -142,11 +174,8 @@ def _grouped_mm_with_backward_fix( # Global flag to check if grouped GEMM is available _GROUPED_GEMM_AVAILABLE = None _TORCH_GROUPED_MM_AVAILABLE = hasattr(torch, "_grouped_mm") -_TORCH_SCALED_GROUPED_MM_AVAILABLE = hasattr(torch, "_scaled_grouped_mm") - # Check if GPU supports torch._grouped_mm (verified via runtime check) _TORCH_GROUPED_MM_SUPPORTED = None -_TORCH_SCALED_GROUPED_MM_SUPPORTED = None def _check_torch_grouped_mm_supported(): @@ -187,56 +216,17 @@ def _check_torch_grouped_mm_supported(): return _TORCH_GROUPED_MM_SUPPORTED -def _check_torch_scaled_grouped_mm_supported(): - """ - Check if torch._scaled_grouped_mm is actually supported on the current GPU. 
- """ - global _TORCH_SCALED_GROUPED_MM_SUPPORTED - if _TORCH_SCALED_GROUPED_MM_SUPPORTED is not None: - return _TORCH_SCALED_GROUPED_MM_SUPPORTED +_TRITON_ALLOCATOR_INITIALIZED = False +_PERSISTENT_BUFFER = None - if not _TORCH_SCALED_GROUPED_MM_AVAILABLE: - _TORCH_SCALED_GROUPED_MM_SUPPORTED = False - return False - - if not torch.cuda.is_available(): - _TORCH_SCALED_GROUPED_MM_SUPPORTED = False - return False - major, _minor = torch.cuda.get_device_capability(torch.cuda.current_device()) - if major != 9: - _TORCH_SCALED_GROUPED_MM_SUPPORTED = False - return False +def _try_attach_block_size(tensor_like, block_size) -> None: + if block_size is None or tensor_like is None: + return try: - device = torch.cuda.current_device() - x = torch.randn((16, 16), device=device, dtype=torch.bfloat16) - w_hp = torch.randn((1, 16, 16), device=device, dtype=torch.bfloat16) - - x_fp8, x_scale = _manual_fp8_rowwise_quantize(x) - w_fp8, w_scale = _manual_fp8_rowwise_quantize(w_hp.view(-1, w_hp.shape[-1])) - w_fp8 = w_fp8.view_as(w_hp) - w_fp8 = w_fp8.transpose(-2, -1).contiguous().transpose(-2, -1) - w_scale = w_scale.view(w_hp.shape[0], w_hp.shape[1]) - offs = torch.tensor([16], device=device, dtype=torch.int32) - - torch._scaled_grouped_mm( - x_fp8.contiguous(), - w_fp8, - x_scale.contiguous(), - w_scale.contiguous(), - offs=offs, - out_dtype=torch.bfloat16, - use_fast_accum=True, - ) - _TORCH_SCALED_GROUPED_MM_SUPPORTED = True - except Exception: - _TORCH_SCALED_GROUPED_MM_SUPPORTED = False - - return _TORCH_SCALED_GROUPED_MM_SUPPORTED - - -_TRITON_ALLOCATOR_INITIALIZED = False -_PERSISTENT_BUFFER = None + tensor_like.block_size = block_size + except (AttributeError, RuntimeError): + pass def _init_triton_allocator(): @@ -326,25 +316,6 @@ def select_moe_backend(): return "native_torch" -def _is_float8_tensor(tensor: Optional[torch.Tensor]) -> bool: - return tensor is not None and getattr(tensor, "dtype", None) == torch.float8_e4m3fn - - -def 
_get_fp8_dequant_target_dtype(hidden_states: torch.Tensor) -> torch.dtype: - if hidden_states.dtype in (torch.float32, torch.float16, torch.bfloat16): - return hidden_states.dtype - if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): - return torch.bfloat16 - return torch.float16 - - -def _build_active_expert_grouping(num_tokens_per_expert: torch.Tensor): - active_expert_ids = torch.nonzero(num_tokens_per_expert > 0, as_tuple=False).squeeze(-1) - active_counts = num_tokens_per_expert.index_select(0, active_expert_ids).to(torch.int32) - offsets = torch.cumsum(active_counts, dim=0, dtype=torch.int32) - return active_expert_ids, active_counts, offsets - - def forward_moe_backend( self, hidden_states: torch.Tensor, @@ -356,6 +327,14 @@ def forward_moe_backend( Centralizes backend selection to keep model-specific patches minimal. """ # This Unsloth Zoo code section is licensed under AGPL3 + try: + from .moe_utils_fp8 import _moe_uses_fp8_expert_weights, forward_moe_backend_fp8 + if _moe_uses_fp8_expert_weights(self): + return forward_moe_backend_fp8( + self, hidden_states, top_k_index, top_k_weights + ) + except Exception: + pass backend = select_moe_backend() if backend == "grouped_mm": @@ -574,15 +553,9 @@ def _get_base_weight_and_quant_state(param): block_size = getattr(base_layer, "block_size", None) if block_size is not None: - try: - weight.block_size = block_size - except Exception: - pass + _try_attach_block_size(weight, block_size) if quant_state is not None: - try: - quant_state.block_size = block_size - except Exception: - pass + _try_attach_block_size(quant_state, block_size) return weight, quant_state @@ -604,339 +577,84 @@ def _get_moe_weight_and_quant_state(experts_module, param_name: str): if block_size is None: block_size = getattr(experts_module, f"{param_name}_block_size", None) if block_size is not None: - try: - weight.block_size = block_size - except Exception: - pass + _try_attach_block_size(weight, block_size) if quant_state is not 
None: - try: - quant_state.block_size = block_size - except Exception: - pass + _try_attach_block_size(quant_state, block_size) return weight, quant_state -def _get_moe_weight_and_quant_info(experts_module, param_name: str): - param = getattr(experts_module, param_name) - - base_layer = param - while hasattr(base_layer, "base_layer"): - base_layer = base_layer.base_layer - - if hasattr(base_layer, "get_param"): - weight = base_layer.get_param() - elif hasattr(base_layer, "weight"): - weight = base_layer.weight - else: - weight = base_layer - - quant_state = getattr(weight, "quant_state", None) - quant_kind = "quant_state" if quant_state is not None else None - if quant_state is None: - quant_state = getattr(base_layer, "weight_scale_inv", None) - if quant_state is not None: - quant_kind = "weight_scale_inv" - else: - quant_state = getattr(base_layer, "weight_scale", None) - if quant_state is not None: - quant_kind = "weight_scale" - - if quant_state is None: - quant_state = getattr(experts_module, f"{param_name}_weight_scale_inv", None) - if quant_state is not None: - quant_kind = "weight_scale_inv" - else: - quant_state = getattr(experts_module, f"{param_name}_weight_scale", None) - if quant_state is not None: - quant_kind = "weight_scale" - if quant_state is None: - quant_state = getattr(experts_module, f"{param_name}_scale_inv", None) - if quant_state is not None: - quant_kind = "weight_scale_inv" - else: - quant_state = getattr(experts_module, f"{param_name}_scale", None) - if quant_state is not None: - quant_kind = "weight_scale" - - block_size = getattr(param, "block_size", None) - if block_size is None: - block_size = getattr(experts_module, f"{param_name}_block_size", None) - if block_size is None: - block_size = getattr(base_layer, "block_size", None) - if block_size is not None: - try: - weight.block_size = block_size - except Exception: - pass - if quant_state is not None: - try: - quant_state.block_size = block_size - except Exception: - pass - - return 
weight, quant_state, quant_kind - - -def _slice_fp8_quant_state(weight: torch.Tensor, quant_state, expert_idx: int): - if quant_state is None or not isinstance(quant_state, torch.Tensor): - return quant_state - - if quant_state.numel() == 1: - sliced = quant_state - elif quant_state.shape[0] == weight.shape[0]: - sliced = quant_state[expert_idx] - elif quant_state.shape[0] % weight.shape[0] == 0: - chunk_size = quant_state.shape[0] // weight.shape[0] - start = expert_idx * chunk_size - end = start + chunk_size - sliced = quant_state[start:end] - else: - return None +def _get_grouped_lora(self, projection_name: str, cache_attr: str, use_separated_lora: bool): + cached_lora = getattr(self, cache_attr, None) + if cached_lora is not None: + return cached_lora[:3] - block_size = getattr(weight, "block_size", None) or getattr(quant_state, "block_size", None) - if block_size is not None: - try: - sliced.block_size = block_size - except Exception: - pass - return sliced + projection = getattr(self, projection_name, None) + if use_separated_lora and projection is not None and _has_lora_adapters(projection): + return _extract_lora_weights( + projection, num_experts=self.num_experts, experts_module=self + ) + return None -def _dequantize_expert_slice( - expert_weight: torch.Tensor, - expert_quant_state, +def _apply_grouped_lora( + grouped_input: torch.Tensor, + lora_weights, + offsets: torch.Tensor, target_dtype: torch.dtype, -) -> Optional[torch.Tensor]: - if expert_weight.dtype != torch.float8_e4m3fn: - return expert_weight.to(target_dtype) - - if expert_quant_state is None: - return expert_weight.to(target_dtype) - + active_expert_ids: Optional[torch.Tensor] = None, +) -> torch.Tensor: + first_weight, second_weight, scaling = lora_weights + if active_expert_ids is not None: + active_expert_ids = active_expert_ids.to(first_weight.device) + first_weight = first_weight.index_select(0, active_expert_ids) + second_weight = second_weight.index_select(0, active_expert_ids) + + 
first_weight = first_weight.to(target_dtype).contiguous() + second_weight = second_weight.to(target_dtype).contiguous() + lora_out = _grouped_mm_with_backward_fix( + grouped_input.to(target_dtype), first_weight, offsets + ).contiguous() try: - from unsloth.kernels.fp8 import weight_dequant - import triton - except Exception: - return None - - block_size = getattr(expert_weight, "block_size", None) or getattr(expert_quant_state, "block_size", None) - if block_size is not None: - try: - expert_weight.block_size = block_size - except Exception: - pass - try: - expert_quant_state.block_size = block_size - except Exception: - pass - - if ( - isinstance(expert_quant_state, torch.Tensor) - and expert_quant_state.ndim == 2 - and len(block_size) == 2 - ): - m, n = expert_weight.shape - p, q = expert_quant_state.shape - if triton.cdiv(m, block_size[0]) != p or triton.cdiv(n, block_size[1]) != q: - if ( - triton.cdiv(m, block_size[0]) == q - and triton.cdiv(n, block_size[1]) == p - ): - expert_quant_state = expert_quant_state.T.contiguous() - try: - expert_quant_state.block_size = block_size - except Exception: - pass - else: - return None - - if isinstance(expert_quant_state, torch.Tensor) and expert_quant_state.ndim == 1: - expert_quant_state = expert_quant_state.view(-1, 1) - - return weight_dequant(expert_weight, expert_quant_state, dtype=target_dtype) - - -def _dequantize_full_expert_weights( - weight: torch.Tensor, - quant_state, - target_dtype: torch.dtype, - proj_type: str, - hidden_dim: int, - model_type = None, -) -> Optional[torch.Tensor]: - if weight.ndim != 3: - return None - expert_ids = torch.arange(weight.shape[0], device=weight.device, dtype=torch.long) - return _dequantize_active_expert_weights( - weight, - quant_state, - expert_ids, - target_dtype, - proj_type, - hidden_dim, - model_type, - ) - - -def _make_grouped_mm_rhs_column_major(weight: torch.Tensor) -> torch.Tensor: - return weight.transpose(-2, -1).contiguous().transpose(-2, -1) - - -def 
_extract_scaled_grouped_mm_weight_scale( - original_weight: torch.Tensor, - processed_weight: torch.Tensor, - quant_state, - quant_kind: Optional[str], -) -> Optional[torch.Tensor]: - if quant_state is None or not isinstance(quant_state, torch.Tensor): - return None - if quant_kind == "quant_state": - return None - - if getattr(original_weight, "block_size", None) is not None: - return None - if getattr(quant_state, "block_size", None) is not None: - return None - - scale = quant_state - if scale.ndim == 0: - scale = scale.view(1, 1).expand(processed_weight.shape[0], processed_weight.shape[-1]) - elif scale.ndim == 1: - if scale.shape[0] != processed_weight.shape[-1]: - return None - scale = scale.view(1, -1).expand(processed_weight.shape[0], -1) - elif scale.ndim == 2: - pass - elif scale.ndim == 3: - if scale.shape[1] == 1: - scale = scale.squeeze(1) - elif scale.shape[2] == 1: - scale = scale.squeeze(2) + if second_weight.shape[-1] % 8 != 0: + pad_size = 8 - (second_weight.shape[-1] % 8) + second_weight_padded = F.pad(second_weight, (0, pad_size)).contiguous() + lora_delta = _grouped_mm_with_backward_fix( + lora_out, second_weight_padded, offsets + ) + lora_delta = lora_delta[:, :-pad_size] else: - return None - else: - return None - - if scale.ndim != 2: - return None - if scale.shape[0] != processed_weight.shape[0]: - return None - if scale.shape[1] != processed_weight.shape[-1]: - return None - - scale = scale.to(torch.float32) - if quant_kind == "weight_scale_inv": - scale = scale.reciprocal() - return scale.contiguous() - - -def _prepare_scaled_grouped_mm_weight( - experts_module, - param_name: str, - proj_type: str, - hidden_dim: int, - model_type = None, -) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: - weight, quant_state, quant_kind = _get_moe_weight_and_quant_info(experts_module, param_name) - if not _is_float8_tensor(weight): - return None - - processed_weight = preprocess_weight(weight, proj_type, hidden_dim, model_type) - processed_weight = 
_make_grouped_mm_rhs_column_major(processed_weight) - scale = _extract_scaled_grouped_mm_weight_scale( - weight, - processed_weight, - quant_state, - quant_kind, - ) - if scale is None: - return None - - return processed_weight, scale - - -def _quantize_inputs_for_scaled_grouped_mm( - inputs: torch.Tensor, -) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: - return _manual_fp8_rowwise_quantize(inputs) + lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets) + except RuntimeError: + lora_delta = torch.empty( + (lora_out.shape[0], second_weight.shape[-1]), + dtype=lora_out.dtype, + device=lora_out.device, + ) + cpu_offsets = offsets.cpu().tolist() + prev_offset = 0 + for i, end in enumerate(cpu_offsets): + if prev_offset < end: + lora_delta[prev_offset:end] = torch.matmul( + lora_out[prev_offset:end], second_weight[i] + ) + prev_offset = end + return lora_delta * scaling -def _scaled_grouped_mm_with_backward_fix( - inputs: torch.Tensor, - weight: torch.Tensor, - input_scale: torch.Tensor, - weight_scale: torch.Tensor, - offsets: torch.Tensor, - out_dtype: torch.dtype, +def _expand_grouped_bias( + bias: torch.Tensor, + counts: torch.Tensor, + expert_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return torch._scaled_grouped_mm( - inputs, - weight, - input_scale, - weight_scale, - offs=offsets, - out_dtype=out_dtype, - use_fast_accum=True, - ) - - -def _manual_fp8_rowwise_quantize( - inputs: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize each row independently into float8 and return decode scales for - torch._scaled_grouped_mm. This avoids relying on optional torchao/fbgemm - quantization helpers that are not always available in runtime environments. 
- """ - inputs_fp32 = inputs.to(torch.float32) - max_fp8 = torch.finfo(torch.float8_e4m3fn).max - amax = inputs_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) - quant_scale = (max_fp8 / amax) - quantized = (inputs_fp32 * quant_scale).to(torch.float8_e4m3fn) - decode_scale = quant_scale.reciprocal().squeeze(-1).to(torch.float32) - return quantized.contiguous(), decode_scale.contiguous() - - -def _dequantize_active_expert_weights( - weight: torch.Tensor, - quant_state, - active_expert_ids: torch.Tensor, - target_dtype: torch.dtype, - proj_type: str, - hidden_dim: int, - model_type = None, -) -> Optional[torch.Tensor]: - """Dequantize only the routed experts and then preprocess for grouped_mm.""" - if weight.ndim != 3: - return None - - active_slices = [] - block_size = getattr(weight, "block_size", None) - for expert_idx in active_expert_ids.tolist(): - expert_weight = weight[expert_idx].contiguous() - if block_size is not None: - try: - expert_weight.block_size = block_size - except Exception: - pass - expert_quant_state = _slice_fp8_quant_state(weight, quant_state, expert_idx) - expert_dequant = _dequantize_expert_slice(expert_weight, expert_quant_state, target_dtype) - if expert_dequant is None: - return None - active_slices.append(expert_dequant) - - packed_weight = torch.stack(active_slices, dim=0) - return preprocess_weight(packed_weight, proj_type, hidden_dim, model_type) - - -def _moe_uses_fp8_expert_weights(self) -> bool: - if not hasattr(self, "gate_up_proj") or not hasattr(self, "down_proj"): - return False - - gate_weight, _ = _get_moe_weight_and_quant_state(self, "gate_up_proj") - down_weight, _ = _get_moe_weight_and_quant_state(self, "down_proj") - return _is_float8_tensor(gate_weight) or _is_float8_tensor(down_weight) + if expert_ids is None: + expanded = bias.repeat_interleave(counts.to(bias.device), dim=0) + else: + expert_ids = expert_ids.to(bias.device) + expanded = bias.index_select(0, expert_ids).repeat_interleave( + 
counts.to(bias.device), dim=0 + ) + return expanded def _get_lora_wrapper_for_param(experts_module, param_name): @@ -1134,27 +852,18 @@ def _patched_param_wrapper_forward( """ # This Unsloth Zoo code section is licensed under AGPL3 - # CRITICAL: Use self.base_layer for forward call (immediate parent) - # NOT self.get_base_layer() which recursively traverses to deepest layer! - # The wrapper chain must be preserved: down_proj -> gate_up_proj -> Qwen3MoeExperts immediate_base_layer = self.base_layer - # For storing LoRA data, we DO need the actual experts module - # Use get_base_layer() to find it (recursive traversal is correct here) experts_module = self.get_base_layer() use_separated = _should_use_separated_lora() param_name = getattr(self, "parameter_name", None) - # Check if this is an MoE experts module that should use separated LoRA if ( use_separated and param_name in ("gate_up_proj", "down_proj") and _is_moe_experts_module(experts_module) ): - # MoE experts: bypass PEFT's _activate_lora, use separated computation - - # Check adapter state if self.disable_adapters: if self.merged: self.unmerge() @@ -1163,7 +872,6 @@ def _patched_param_wrapper_forward( if self.merged: return immediate_base_layer(x, *args, **kwargs) - # Ensure wrapper.num_experts is set for LoRA weight reshaping if not hasattr(self, "num_experts"): if hasattr(experts_module, "num_experts"): self.num_experts = experts_module.num_experts @@ -1172,7 +880,6 @@ def _patched_param_wrapper_forward( if hasattr(p, "shape") and len(p.shape) >= 1: self.num_experts = p.shape[0] - # Extract LoRA for this specific parameter lora_data = _extract_lora_from_wrapper(self) if lora_data is not None and param_name: @@ -1189,7 +896,6 @@ def _patched_param_wrapper_forward( return result - # Non-MoE: use original PEFT forward with _activate_lora return _original_param_wrapper_forward(self, x, *args, **kwargs) @@ -1225,366 +931,6 @@ def patch_param_wrapper_for_moe(): return False -def 
_forward_native_grouped_mm_active_dequant( - self, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, -) -> Optional[torch.Tensor]: - """ - FP8 compatibility path: dequantize only routed experts, then run grouped_mm. - Falls back to None when the expert quant metadata cannot be interpreted safely. - """ - # This Unsloth Zoo code section is licensed under AGPL3 - - if not hasattr(self, "gate_up_proj") or not hasattr(self, "down_proj"): - return None - - is_2d_input = hidden_states.dim() == 2 - if is_2d_input: - sequence_length, hidden_dim = hidden_states.shape - batch_size = 1 - else: - batch_size, sequence_length, hidden_dim = hidden_states.shape - - input_dtype = hidden_states.dtype - hidden_states = hidden_states.view(-1, hidden_dim) - - flat_top_k = top_k_index.view(-1) - num_tokens_per_expert = torch.bincount(flat_top_k, minlength=self.num_experts).int() - sorted_indices = torch.argsort(flat_top_k, stable=True) - token_indices = sorted_indices // top_k_index.shape[-1] - permuted_input = hidden_states[token_indices] - - active_expert_ids, active_counts, offsets = _build_active_expert_grouping(num_tokens_per_expert) - if active_expert_ids.numel() == 0: - return torch.zeros_like(hidden_states) if is_2d_input else hidden_states.new_zeros(batch_size, sequence_length, hidden_dim) - - target_dtype = _get_fp8_dequant_target_dtype(permuted_input) - model_type = getattr(self, "_unsloth_model_type", None) - use_separated_lora = _should_use_separated_lora() - - gate_up_base, gate_up_quant = _get_moe_weight_and_quant_state(self, "gate_up_proj") - down_base, down_quant = _get_moe_weight_and_quant_state(self, "down_proj") - gate_up_weight = _dequantize_active_expert_weights( - gate_up_base, - gate_up_quant, - active_expert_ids, - target_dtype, - "gate_up", - hidden_dim, - model_type, - ) - down_weight = _dequantize_active_expert_weights( - down_base, - down_quant, - active_expert_ids, - target_dtype, - "down", - hidden_dim, - 
model_type, - ) - if gate_up_weight is None or down_weight is None: - return None - - permuted_input = permuted_input.to(target_dtype) - mm1_out = _grouped_mm_with_backward_fix(permuted_input, gate_up_weight.contiguous(), offsets) - - gate_up_lora = None - if getattr(self, "_unsloth_lora_gate_up_proj", None) is not None: - gate_up_lora = self._unsloth_lora_gate_up_proj[:3] - elif use_separated_lora and _has_lora_adapters(self.gate_up_proj): - gate_up_lora = _extract_lora_weights( - self.gate_up_proj, num_experts=self.num_experts, experts_module=self - ) - - if gate_up_lora is not None: - first_weight, second_weight, scaling = gate_up_lora - active_expert_ids_device = active_expert_ids.to(first_weight.device) - first_weight = first_weight.index_select(0, active_expert_ids_device).to(target_dtype).contiguous() - second_weight = second_weight.index_select(0, active_expert_ids_device).to(target_dtype).contiguous() - lora_out = _grouped_mm_with_backward_fix(permuted_input, first_weight, offsets).contiguous() - try: - lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets) - except RuntimeError: - lora_delta = torch.empty( - (lora_out.shape[0], second_weight.shape[-1]), - dtype=lora_out.dtype, - device=lora_out.device, - ) - cpu_offsets = offsets.cpu().tolist() - prev_offset = 0 - for i, end in enumerate(cpu_offsets): - if prev_offset < end: - lora_delta[prev_offset:end] = torch.matmul( - lora_out[prev_offset:end], second_weight[i] - ) - prev_offset = end - mm1_out = mm1_out + lora_delta * scaling - - if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None: - bias_indices = active_expert_ids.to(self.gate_up_proj_bias.device) - bias_expanded = self.gate_up_proj_bias.index_select(0, bias_indices).repeat_interleave( - active_counts.to(self.gate_up_proj_bias.device), dim=0 - ) - mm1_out = mm1_out + bias_expanded.to(mm1_out.dtype) - - if "GptOssExperts" in self.__class__.__name__: - gate = mm1_out[..., ::2] - up = mm1_out[..., 1::2] - 
limit = getattr(self, "limit", 7.0) - alpha = getattr(self, "alpha", 1.702) - gate = gate.clamp(min=None, max=limit) - up = up.clamp(min=-limit, max=limit) - inter = (up + 1.0) * (gate * torch.sigmoid(gate * alpha)) - else: - gate, up = mm1_out.chunk(2, dim=-1) - inter = F.silu(gate) * up - - mm2_out = _grouped_mm_with_backward_fix(inter, down_weight.contiguous(), offsets) - - down_lora = None - if getattr(self, "_unsloth_lora_down_proj", None) is not None: - down_lora = self._unsloth_lora_down_proj[:3] - elif use_separated_lora and _has_lora_adapters(self.down_proj): - down_lora = _extract_lora_weights( - self.down_proj, num_experts=self.num_experts, experts_module=self - ) - - if down_lora is not None: - first_weight, second_weight, scaling = down_lora - active_expert_ids_device = active_expert_ids.to(first_weight.device) - first_weight = first_weight.index_select(0, active_expert_ids_device).to(target_dtype).contiguous() - second_weight = second_weight.index_select(0, active_expert_ids_device).to(target_dtype).contiguous() - lora_out = _grouped_mm_with_backward_fix(inter, first_weight, offsets).contiguous() - try: - lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets) - except RuntimeError: - lora_delta = torch.empty( - (lora_out.shape[0], second_weight.shape[-1]), - dtype=lora_out.dtype, - device=lora_out.device, - ) - cpu_offsets = offsets.cpu().tolist() - prev_offset = 0 - for i, end in enumerate(cpu_offsets): - if prev_offset < end: - lora_delta[prev_offset:end] = torch.matmul( - lora_out[prev_offset:end], second_weight[i] - ) - prev_offset = end - mm2_out = mm2_out + lora_delta * scaling - - if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None: - bias_indices = active_expert_ids.to(self.down_proj_bias.device) - bias_expanded = self.down_proj_bias.index_select(0, bias_indices).repeat_interleave( - active_counts.to(self.down_proj_bias.device), dim=0 - ) - mm2_out = mm2_out + bias_expanded.to(mm2_out.dtype) - - 
flat_weights = top_k_weights.view(-1) - permuted_weights = flat_weights[sorted_indices] - mm2_out = mm2_out * permuted_weights.unsqueeze(-1) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), - dtype=input_dtype, - device=hidden_states.device, - ) - final_hidden_states.index_add_(0, token_indices, mm2_out.to(input_dtype)) - - if is_2d_input: - return final_hidden_states - return final_hidden_states.view(batch_size, sequence_length, hidden_dim) - - -def _forward_native_grouped_mm_scaled_fp8( - self, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, -) -> Optional[torch.Tensor]: - """ - FP8 fast path: use torch._scaled_grouped_mm directly when the expert scales - are compatible with a simple rowwise/tensorwise grouped matmul. - """ - if not hasattr(self, "gate_up_proj") or not hasattr(self, "down_proj"): - return None - if not _check_torch_scaled_grouped_mm_supported(): - return None - - is_2d_input = hidden_states.dim() == 2 - if is_2d_input: - sequence_length, hidden_dim = hidden_states.shape - batch_size = 1 - else: - batch_size, sequence_length, hidden_dim = hidden_states.shape - - input_dtype = hidden_states.dtype - hidden_states = hidden_states.view(-1, hidden_dim) - - flat_top_k = top_k_index.view(-1) - num_tokens_per_expert = torch.bincount(flat_top_k, minlength=self.num_experts).int() - sorted_indices = torch.argsort(flat_top_k, stable=True) - token_indices = sorted_indices // top_k_index.shape[-1] - permuted_input = hidden_states[token_indices] - - offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) - use_separated_lora = _should_use_separated_lora() - model_type = getattr(self, "_unsloth_model_type", None) - - gate_up_prepared = _prepare_scaled_grouped_mm_weight( - self, "gate_up_proj", "gate_up", hidden_dim, model_type - ) - down_prepared = _prepare_scaled_grouped_mm_weight( - self, "down_proj", "down", hidden_dim, model_type - ) - if gate_up_prepared is None or 
down_prepared is None: - return None - - gate_up_weight, gate_up_scale = gate_up_prepared - down_weight, down_scale = down_prepared - - quantized_input = _quantize_inputs_for_scaled_grouped_mm( - permuted_input.to(_get_fp8_dequant_target_dtype(permuted_input)) - ) - if quantized_input is None: - return None - permuted_input_fp8, input_scale = quantized_input - - try: - mm1_out = _scaled_grouped_mm_with_backward_fix( - permuted_input_fp8, - gate_up_weight, - input_scale, - gate_up_scale, - offsets, - out_dtype=_get_fp8_dequant_target_dtype(permuted_input), - ) - except RuntimeError: - return None - - gate_up_lora = None - if getattr(self, "_unsloth_lora_gate_up_proj", None) is not None: - gate_up_lora = self._unsloth_lora_gate_up_proj[:3] - elif use_separated_lora and _has_lora_adapters(self.gate_up_proj): - gate_up_lora = _extract_lora_weights( - self.gate_up_proj, num_experts=self.num_experts, experts_module=self - ) - - if gate_up_lora is not None: - first_weight, second_weight, scaling = gate_up_lora - first_weight = first_weight.to(mm1_out.dtype).contiguous() - second_weight = second_weight.to(mm1_out.dtype).contiguous() - lora_out = _grouped_mm_with_backward_fix( - permuted_input.to(mm1_out.dtype), first_weight, offsets - ).contiguous() - try: - lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets) - except RuntimeError: - lora_delta = torch.empty( - (lora_out.shape[0], second_weight.shape[-1]), - dtype=lora_out.dtype, - device=lora_out.device, - ) - cpu_offsets = offsets.cpu().tolist() - prev_offset = 0 - for i, end in enumerate(cpu_offsets): - if prev_offset < end: - lora_delta[prev_offset:end] = torch.matmul( - lora_out[prev_offset:end], second_weight[i] - ) - prev_offset = end - mm1_out = mm1_out + lora_delta * scaling - - if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None: - bias_expanded = self.gate_up_proj_bias.repeat_interleave( - num_tokens_per_expert.to(self.gate_up_proj_bias.device), dim=0 - ) - 
mm1_out = mm1_out + bias_expanded.to(mm1_out.dtype) - - if "GptOssExperts" in self.__class__.__name__: - gate = mm1_out[..., ::2] - up = mm1_out[..., 1::2] - limit = getattr(self, "limit", 7.0) - alpha = getattr(self, "alpha", 1.702) - gate = gate.clamp(min=None, max=limit) - up = up.clamp(min=-limit, max=limit) - inter = (up + 1.0) * (gate * torch.sigmoid(gate * alpha)) - else: - gate, up = mm1_out.chunk(2, dim=-1) - inter = F.silu(gate) * up - - inter_quantized = _quantize_inputs_for_scaled_grouped_mm(inter) - if inter_quantized is None: - return None - inter_fp8, inter_scale = inter_quantized - - try: - mm2_out = _scaled_grouped_mm_with_backward_fix( - inter_fp8, - down_weight, - inter_scale, - down_scale, - offsets, - out_dtype=mm1_out.dtype, - ) - except RuntimeError: - return None - - down_lora = None - if getattr(self, "_unsloth_lora_down_proj", None) is not None: - down_lora = self._unsloth_lora_down_proj[:3] - elif use_separated_lora and _has_lora_adapters(self.down_proj): - down_lora = _extract_lora_weights( - self.down_proj, num_experts=self.num_experts, experts_module=self - ) - - if down_lora is not None: - first_weight, second_weight, scaling = down_lora - first_weight = first_weight.to(inter.dtype).contiguous() - second_weight = second_weight.to(inter.dtype).contiguous() - lora_out = _grouped_mm_with_backward_fix(inter, first_weight, offsets).contiguous() - try: - lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets) - except RuntimeError: - lora_delta = torch.empty( - (lora_out.shape[0], second_weight.shape[-1]), - dtype=lora_out.dtype, - device=lora_out.device, - ) - cpu_offsets = offsets.cpu().tolist() - prev_offset = 0 - for i, end in enumerate(cpu_offsets): - if prev_offset < end: - lora_delta[prev_offset:end] = torch.matmul( - lora_out[prev_offset:end], second_weight[i] - ) - prev_offset = end - mm2_out = mm2_out + lora_delta * scaling - - if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None: - 
bias_expanded = self.down_proj_bias.repeat_interleave( - num_tokens_per_expert.to(self.down_proj_bias.device), dim=0 - ).to(mm2_out.device) - mm2_out = mm2_out + bias_expanded.to(mm2_out.dtype) - - flat_weights = top_k_weights.view(-1) - permuted_weights = flat_weights[sorted_indices] - mm2_out = mm2_out * permuted_weights.unsqueeze(-1) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), - dtype=input_dtype, - device=hidden_states.device, - ) - final_hidden_states.index_add_(0, token_indices, mm2_out.to(input_dtype)) - - if is_2d_input: - return final_hidden_states - return final_hidden_states.view(batch_size, sequence_length, hidden_dim) - - def forward_native_grouped_mm( self, hidden_states: torch.Tensor, @@ -1606,31 +952,6 @@ def forward_native_grouped_mm( f"Set UNSLOTH_MOE_BACKEND='unsloth_triton' or 'native_torch' to use a compatible backend." ) - if _moe_uses_fp8_expert_weights(self): - scaled_output = _forward_native_grouped_mm_scaled_fp8( - self, hidden_states, top_k_index, top_k_weights - ) - if scaled_output is not None: - _log_warn_once( - "Unsloth: MoE grouped_mm detected compatible FP8 expert weights; using torch._scaled_grouped_mm." - ) - return scaled_output - - _log_warn_once( - "Unsloth: MoE grouped_mm detected FP8 expert weights; dequantizing only routed experts " - "to a temporary high-precision grouped_mm buffer." - ) - active_dequant_output = _forward_native_grouped_mm_active_dequant( - self, hidden_states, top_k_index, top_k_weights - ) - if active_dequant_output is not None: - return active_dequant_output - _log_warn_once( - "Unsloth: FP8 expert metadata was insufficient for active grouped_mm dequantization. " - "Falling back to native_torch MoE loop." 
- ) - return forward_native_moe_loop(self, hidden_states, top_k_index, top_k_weights) - is_2d_input = hidden_states.dim() == 2 if is_2d_input: sequence_length, hidden_dim = hidden_states.shape @@ -1683,9 +1004,14 @@ def forward_native_grouped_mm( # Get model type for preprocessing (if registered) model_type = getattr(self, "_unsloth_model_type", None) + # Handle different weight shapes using preprocessor + # torch._grouped_mm backward requires weights to be contiguous; preprocessing may return a transposed view. w1 = preprocess_weight(gate_up_base, "gate_up", hidden_dim, model_type) + # Base forward: X @ W mm1_out = _grouped_mm_with_backward_fix(permuted_input, w1, offsets) + # Add separated LoRA contribution: + ((X @ first) @ second) * scaling + # _extract_lora_from_wrapper returns (first_weight, second_weight, scaling) if gate_up_lora is not None: first_weight, second_weight, scaling = gate_up_lora @@ -1910,21 +1236,33 @@ def forward_native_grouped_mm( return final_hidden_states.view(batch_size, sequence_length, hidden_dim) -def _forward_triton_grouped_gemm_impl( +def forward_triton_grouped_gemm( self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, - gate_up_proj: Optional[torch.Tensor] = None, - down_proj: Optional[torch.Tensor] = None, ) -> torch.Tensor: - gate_up_proj = self.gate_up_proj if gate_up_proj is None else gate_up_proj - down_proj = self.down_proj if down_proj is None else down_proj - + """ + Grouped GEMM MoE forward pass using Triton kernels. + Compatible with torch.compile (recommended mode="max-autotune" with cudagraph_mark_step_begin). 
+ """ + # This Unsloth Zoo code section is licensed under AGPL3 + # Import grouped GEMM interface from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm + + # Import autotune cache from unsloth.kernels.moe.autotune_cache import get_or_autotune_moe_kernels + # Helper to check TMA support - assumes helper function or just check directly + # In original: it was a cached closure. Here we can use _supports_tma() directly + + # nonlocal _MODEL_DIMS_AND_CONFIGS # We need a way to store this! + # For now, let's attach it to self if possible, or use a global usage + # Attaching to self is cleaner: self._unsloth_moe_configs + + # Create expert mask and find which experts have tokens + if not hasattr(self, "_unsloth_moe_configs"): self._unsloth_moe_configs = None @@ -1949,7 +1287,7 @@ def _forward_triton_grouped_gemm_impl( # Cache model dimensions and kernel configs on first call if self._unsloth_moe_configs is None: - intermediate_dim = gate_up_proj.shape[1] // 2 + intermediate_dim = self.gate_up_proj.shape[1] // 2 # Autotune first GEMM gemm1_configs = get_or_autotune_moe_kernels( @@ -1987,10 +1325,10 @@ def _forward_triton_grouped_gemm_impl( ) offsets = torch.cumsum(token_counts_by_expert, dim=0, dtype=torch.int32) - if gate_up_proj.shape[-1] == hidden_dim: - w1 = gate_up_proj + if self.gate_up_proj.shape[-1] == hidden_dim: + w1 = self.gate_up_proj else: - w1 = gate_up_proj.transpose(-2, -1).contiguous() + w1 = self.gate_up_proj.transpose(-2, -1).contiguous() # First grouped GEMM: gate_up projection first_gemm_output = grouped_gemm( @@ -2025,10 +1363,10 @@ def _forward_triton_grouped_gemm_impl( ): down_lora = _extract_lora_weights(self.down_proj, num_experts=self.num_experts) - if down_proj.shape[-1] == intermediate.shape[-1]: - w2 = down_proj + if self.down_proj.shape[-1] == intermediate.shape[-1]: + w2 = self.down_proj else: - w2 = down_proj.transpose(-2, -1).contiguous() + w2 = self.down_proj.transpose(-2, -1).contiguous() second_gemm_output = grouped_gemm( 
X=intermediate, @@ -2082,71 +1420,6 @@ def _forward_triton_grouped_gemm_impl( return final_hidden_states -def forward_triton_grouped_gemm( - self, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, -) -> torch.Tensor: - """ - Grouped GEMM MoE forward pass using Triton kernels. - Compatible with torch.compile (recommended mode="max-autotune" with cudagraph_mark_step_begin). - """ - # This Unsloth Zoo code section is licensed under AGPL3 - - if _moe_uses_fp8_expert_weights(self): - if _check_torch_grouped_mm_supported(): - _log_warn_once( - "Unsloth: MoE Triton backend detected FP8 expert weights; routing through grouped_mm FP8 handling." - ) - return forward_native_grouped_mm( - self, hidden_states, top_k_index, top_k_weights - ) - - target_dtype = _get_fp8_dequant_target_dtype(hidden_states) - hidden_dim = hidden_states.shape[-1] - model_type = getattr(self, "_unsloth_model_type", None) - gate_up_base, gate_up_quant = _get_moe_weight_and_quant_state(self, "gate_up_proj") - down_base, down_quant = _get_moe_weight_and_quant_state(self, "down_proj") - gate_up_weight = _dequantize_full_expert_weights( - gate_up_base, - gate_up_quant, - target_dtype, - "gate_up", - hidden_dim, - model_type, - ) - down_weight = _dequantize_full_expert_weights( - down_base, - down_quant, - target_dtype, - "down", - hidden_dim, - model_type, - ) - if gate_up_weight is None or down_weight is None: - _log_warn_once( - "Unsloth: FP8 expert metadata was insufficient for Triton dequant fallback. Falling back to native_torch MoE loop." - ) - return forward_native_moe_loop(self, hidden_states, top_k_index, top_k_weights) - - _log_warn_once( - "Unsloth: MoE Triton backend detected FP8 expert weights; dequantizing experts on the fly for Triton grouped GEMM." 
- ) - return _forward_triton_grouped_gemm_impl( - self, - hidden_states.to(target_dtype), - top_k_index, - top_k_weights, - gate_up_proj=gate_up_weight, - down_proj=down_weight, - ) - - return _forward_triton_grouped_gemm_impl( - self, hidden_states, top_k_index, top_k_weights - ) - - @torch.compiler.disable def forward_native_moe_loop( self, diff --git a/unsloth_zoo/temporary_patches/moe_utils_fp8.py b/unsloth_zoo/temporary_patches/moe_utils_fp8.py index 2022f2e33..c12d1776a 100644 --- a/unsloth_zoo/temporary_patches/moe_utils_fp8.py +++ b/unsloth_zoo/temporary_patches/moe_utils_fp8.py @@ -15,8 +15,15 @@ # along with this program. If not, see . import os +from typing import Optional + import torch import torch.nn as nn +import torch.nn.functional as F + + +_TORCH_SCALED_GROUPED_MM_AVAILABLE = hasattr(torch, "_scaled_grouped_mm") +_TORCH_SCALED_GROUPED_MM_SUPPORTED = None def _maybe_patch_glm4_stacked_moe_fp8_scales( @@ -25,14 +32,6 @@ def _maybe_patch_glm4_stacked_moe_fp8_scales( token = None, revision = None, ): - """ - Attach missing FP8 scale tensors to stacked routed-expert parameters. - - This currently handles GLM4-MoE Lite style experts where transformers loads - the float8 expert weights but leaves the per-expert weight_scale tensors as - unexpected keys because the experts are stacked nn.Parameters rather than - Linear modules. 
- """ config = getattr(model, "config", None) if config is None or getattr(config, "model_type", None) != "glm4_moe_lite": return False @@ -123,26 +122,624 @@ def _maybe_patch_glm4_stacked_moe_fp8_scales( torch.stack(down_scales, dim = 0).to(device = device), requires_grad = False, ) + experts.gate_up_proj_scale = experts.gate_up_proj_weight_scale + experts.down_proj_scale = experts.down_proj_weight_scale return True def maybe_patch_stacked_moe_expert_fp8_scales( model, - model_name: str, + model_name: Optional[str] = None, token = None, revision = None, ): - """ - Best-effort hook for prequantized FP8 MoE checkpoints that use stacked expert - parameters and need extra runtime quant metadata attached after loading. + if model_name is None: + config = getattr(model, "config", None) + model_name = getattr(config, "_name_or_path", None) + if not model_name: + return False - This is intentionally generic at the callsite. Model-specific handlers can - be added here as new stacked-FP8 MoE formats appear. 
- """ return _maybe_patch_glm4_stacked_moe_fp8_scales( model, model_name, token = token, revision = revision, ) + + +def _is_float8_tensor(tensor: Optional[torch.Tensor]) -> bool: + return tensor is not None and getattr(tensor, "dtype", None) == torch.float8_e4m3fn + + +def _get_fp8_dequant_target_dtype(hidden_states: torch.Tensor) -> torch.dtype: + if hidden_states.dtype in (torch.float32, torch.float16, torch.bfloat16): + return hidden_states.dtype + if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + return torch.bfloat16 + return torch.float16 + + +def _log_moe_fp8_backend_once(experts_module, message: str): + from .moe_utils import _log_info + + if getattr(experts_module, "_unsloth_logged_fp8_backend", None) == message: + return + experts_module._unsloth_logged_fp8_backend = message + _log_info(message) + + +def _check_torch_scaled_grouped_mm_supported(): + global _TORCH_SCALED_GROUPED_MM_SUPPORTED + if _TORCH_SCALED_GROUPED_MM_SUPPORTED is not None: + return _TORCH_SCALED_GROUPED_MM_SUPPORTED + + if not _TORCH_SCALED_GROUPED_MM_AVAILABLE: + _TORCH_SCALED_GROUPED_MM_SUPPORTED = False + return False + if not torch.cuda.is_available(): + _TORCH_SCALED_GROUPED_MM_SUPPORTED = False + return False + + # The symbol can exist on older GPUs, but probing it on unsupported + # hardware can trigger an async launch failure that poisons the CUDA + # context. Keep the FP8 scaled_grouped_mm path off on pre-Hopper parts. 
+ major, _minor = torch.cuda.get_device_capability(torch.cuda.current_device()) + if major < 9: + _TORCH_SCALED_GROUPED_MM_SUPPORTED = False + return False + + try: + device = torch.cuda.current_device() + x = torch.randn((16, 16), device=device, dtype=torch.bfloat16) + w_hp = torch.randn((1, 16, 16), device=device, dtype=torch.bfloat16) + x_fp8, x_scale = _manual_fp8_rowwise_quantize(x) + w_fp8, w_scale = _manual_fp8_rowwise_quantize(w_hp.view(-1, w_hp.shape[-1])) + w_fp8 = w_fp8.view_as(w_hp) + w_fp8 = _make_grouped_mm_rhs_column_major(w_fp8) + w_scale = w_scale.view(w_hp.shape[0], w_hp.shape[1]) + offs = torch.tensor([16], device=device, dtype=torch.int32) + torch._scaled_grouped_mm( + x_fp8.contiguous(), + w_fp8, + x_scale.contiguous(), + w_scale.contiguous(), + offs=offs, + out_dtype=torch.bfloat16, + use_fast_accum=True, + ) + _TORCH_SCALED_GROUPED_MM_SUPPORTED = True + torch.cuda.synchronize(device) + except Exception: + _TORCH_SCALED_GROUPED_MM_SUPPORTED = False + return _TORCH_SCALED_GROUPED_MM_SUPPORTED + + +def _slice_fp8_quant_state(weight: torch.Tensor, quant_state, expert_idx: int): + from .moe_utils import _try_attach_block_size + + if quant_state is None or not isinstance(quant_state, torch.Tensor): + return quant_state + + if quant_state.numel() == 1: + sliced = quant_state + elif quant_state.shape[0] == weight.shape[0]: + sliced = quant_state[expert_idx] + elif quant_state.shape[0] % weight.shape[0] == 0: + chunk_size = quant_state.shape[0] // weight.shape[0] + start = expert_idx * chunk_size + end = start + chunk_size + sliced = quant_state[start:end] + else: + return None + + block_size = getattr(weight, "block_size", None) or getattr(quant_state, "block_size", None) + if block_size is not None: + _try_attach_block_size(sliced, block_size) + return sliced + + +def _dequantize_expert_slice( + expert_weight: torch.Tensor, + expert_quant_state, + target_dtype: torch.dtype, +) -> Optional[torch.Tensor]: + from .moe_utils import 
_try_attach_block_size + + if expert_weight.dtype != torch.float8_e4m3fn: + return expert_weight.to(target_dtype) + + if expert_quant_state is None: + return None + + try: + from unsloth.kernels.fp8 import weight_dequant + import triton + except Exception: + return None + + block_size = getattr(expert_weight, "block_size", None) or getattr(expert_quant_state, "block_size", None) + if block_size is not None: + _try_attach_block_size(expert_weight, block_size) + _try_attach_block_size(expert_quant_state, block_size) + + if ( + isinstance(expert_quant_state, torch.Tensor) + and expert_quant_state.ndim == 2 + and len(block_size) == 2 + ): + m, n = expert_weight.shape + p, q = expert_quant_state.shape + if triton.cdiv(m, block_size[0]) != p or triton.cdiv(n, block_size[1]) != q: + if ( + triton.cdiv(m, block_size[0]) == q + and triton.cdiv(n, block_size[1]) == p + ): + expert_quant_state = expert_quant_state.T.contiguous() + _try_attach_block_size(expert_quant_state, block_size) + else: + return None + + if isinstance(expert_quant_state, torch.Tensor) and expert_quant_state.ndim == 1: + expert_quant_state = expert_quant_state.view(-1, 1) + + return weight_dequant(expert_weight, expert_quant_state, dtype=target_dtype) + + +def _dequantize_full_expert_weights(weight: torch.Tensor, quant_state, target_dtype: torch.dtype): + from .moe_utils import _try_attach_block_size + + if weight.ndim != 3: + return None + + dequantized = [] + block_size = getattr(weight, "block_size", None) + for expert_idx in range(weight.shape[0]): + expert_weight = weight[expert_idx].contiguous() + if block_size is not None: + _try_attach_block_size(expert_weight, block_size) + expert_quant_state = _slice_fp8_quant_state(weight, quant_state, expert_idx) + expert_dequant = _dequantize_expert_slice(expert_weight, expert_quant_state, target_dtype) + if expert_dequant is None: + return None + dequantized.append(expert_dequant) + return torch.stack(dequantized, dim=0) + + +def 
_make_grouped_mm_rhs_column_major(weight: torch.Tensor) -> torch.Tensor: + return weight.transpose(-2, -1).contiguous().transpose(-2, -1) + + +def _get_moe_weight_and_quant_info(experts_module, param_name: str): + from .moe_utils import _get_base_weight_and_quant_state, _try_attach_block_size + + param = getattr(experts_module, param_name) + weight, quant_state = _get_base_weight_and_quant_state(param) + quant_kind = "quant_state" if getattr(weight, "quant_state", None) is not None else None + + if quant_state is None: + quant_state = getattr(experts_module, f"{param_name}_weight_scale_inv", None) + if quant_state is not None: + quant_kind = "weight_scale_inv" + else: + quant_state = getattr(experts_module, f"{param_name}_weight_scale", None) + if quant_state is not None: + quant_kind = "weight_scale" + if quant_state is None: + quant_state = getattr(experts_module, f"{param_name}_scale_inv", None) + if quant_state is not None: + quant_kind = "weight_scale_inv" + else: + quant_state = getattr(experts_module, f"{param_name}_scale", None) + if quant_state is not None: + quant_kind = "weight_scale" + + block_size = getattr(param, "block_size", None) + if block_size is None: + block_size = getattr(experts_module, f"{param_name}_block_size", None) + if block_size is not None: + _try_attach_block_size(weight, block_size) + if quant_state is not None: + _try_attach_block_size(quant_state, block_size) + return weight, quant_state, quant_kind + + +def _extract_scaled_grouped_mm_weight_scale(original_weight, processed_weight, quant_state, quant_kind): + if quant_state is None or not isinstance(quant_state, torch.Tensor): + return None + if quant_kind == "quant_state": + return None + if getattr(original_weight, "block_size", None) is not None: + return None + if getattr(quant_state, "block_size", None) is not None: + return None + + scale = quant_state + if scale.ndim == 0: + scale = scale.view(1, 1).expand(processed_weight.shape[0], processed_weight.shape[-1]) + elif 
scale.ndim == 1: + if scale.shape[0] != processed_weight.shape[-1]: + return None + scale = scale.view(1, -1).expand(processed_weight.shape[0], -1) + elif scale.ndim == 3: + if scale.shape[1] == 1: + scale = scale.squeeze(1) + elif scale.shape[2] == 1: + scale = scale.squeeze(2) + else: + return None + elif scale.ndim != 2: + return None + + if scale.ndim != 2: + return None + if scale.shape[0] != processed_weight.shape[0] or scale.shape[1] != processed_weight.shape[-1]: + return None + + scale = scale.to(torch.float32) + if quant_kind == "weight_scale_inv": + scale = scale.reciprocal() + return scale.contiguous() + + +def _prepare_scaled_grouped_mm_weight(experts_module, param_name: str, proj_type: str, hidden_dim: int, model_type=None): + from .moe_utils import preprocess_weight + + weight, quant_state, quant_kind = _get_moe_weight_and_quant_info(experts_module, param_name) + if not _is_float8_tensor(weight): + return None + processed_weight = preprocess_weight(weight, proj_type, hidden_dim, model_type) + processed_weight = _make_grouped_mm_rhs_column_major(processed_weight) + scale = _extract_scaled_grouped_mm_weight_scale(weight, processed_weight, quant_state, quant_kind) + if scale is None: + return None + return processed_weight, scale + + +def _manual_fp8_rowwise_quantize(inputs: torch.Tensor): + inputs_fp32 = inputs.to(torch.float32) + max_fp8 = torch.finfo(torch.float8_e4m3fn).max + amax = inputs_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) + quant_scale = max_fp8 / amax + quantized = (inputs_fp32 * quant_scale).to(torch.float8_e4m3fn) + decode_scale = quant_scale.reciprocal().squeeze(-1).to(torch.float32) + return quantized.contiguous(), decode_scale.contiguous() + + +def _forward_scaled_grouped_mm_fp8(self, hidden_states, top_k_index, top_k_weights): + from .moe_utils import _get_grouped_lora, _apply_grouped_lora, _expand_grouped_bias + + if not _check_torch_scaled_grouped_mm_supported(): + return None + if not hasattr(self, "gate_up_proj") or 
not hasattr(self, "down_proj"): + return None + + is_2d_input = hidden_states.dim() == 2 + if is_2d_input: + sequence_length, hidden_dim = hidden_states.shape + batch_size = 1 + else: + batch_size, sequence_length, hidden_dim = hidden_states.shape + + input_dtype = hidden_states.dtype + hidden_states = hidden_states.view(-1, hidden_dim) + flat_top_k = top_k_index.view(-1) + num_tokens_per_expert = torch.bincount(flat_top_k, minlength=self.num_experts).int() + sorted_indices = torch.argsort(flat_top_k, stable=True) + token_indices = sorted_indices // top_k_index.shape[-1] + permuted_input = hidden_states[token_indices] + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + use_separated_lora = True + model_type = getattr(self, "_unsloth_model_type", None) + + gate_up_prepared = _prepare_scaled_grouped_mm_weight(self, "gate_up_proj", "gate_up", hidden_dim, model_type) + down_prepared = _prepare_scaled_grouped_mm_weight(self, "down_proj", "down", hidden_dim, model_type) + if gate_up_prepared is None or down_prepared is None: + return None + gate_up_weight, gate_up_scale = gate_up_prepared + down_weight, down_scale = down_prepared + + target_dtype = _get_fp8_dequant_target_dtype(permuted_input) + permuted_input_fp8, input_scale = _manual_fp8_rowwise_quantize(permuted_input.to(target_dtype)) + try: + mm1_out = torch._scaled_grouped_mm( + permuted_input_fp8, + gate_up_weight, + input_scale, + gate_up_scale, + offs=offsets, + out_dtype=target_dtype, + use_fast_accum=True, + ) + except RuntimeError: + return None + + gate_up_lora = _get_grouped_lora(self, "gate_up_proj", "_unsloth_lora_gate_up_proj", use_separated_lora) + if gate_up_lora is not None: + mm1_out = mm1_out + _apply_grouped_lora(permuted_input, gate_up_lora, offsets, mm1_out.dtype) + if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None: + bias_expanded = _expand_grouped_bias(self.gate_up_proj_bias, num_tokens_per_expert) + mm1_out = mm1_out + 
bias_expanded.to(mm1_out.dtype) + + if "GptOssExperts" in self.__class__.__name__: + gate = mm1_out[..., ::2] + up = mm1_out[..., 1::2] + limit = getattr(self, "limit", 7.0) + alpha = getattr(self, "alpha", 1.702) + gate = gate.clamp(min=None, max=limit) + up = up.clamp(min=-limit, max=limit) + inter = (up + 1.0) * (gate * torch.sigmoid(gate * alpha)) + else: + gate, up = mm1_out.chunk(2, dim=-1) + inter = F.silu(gate) * up + + inter_fp8, inter_scale = _manual_fp8_rowwise_quantize(inter) + try: + mm2_out = torch._scaled_grouped_mm( + inter_fp8, + down_weight, + inter_scale, + down_scale, + offs=offsets, + out_dtype=mm1_out.dtype, + use_fast_accum=True, + ) + except RuntimeError: + return None + + down_lora = _get_grouped_lora(self, "down_proj", "_unsloth_lora_down_proj", use_separated_lora) + if down_lora is not None: + mm2_out = mm2_out + _apply_grouped_lora(inter, down_lora, offsets, inter.dtype) + if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None: + bias_expanded = _expand_grouped_bias(self.down_proj_bias, num_tokens_per_expert).to(mm2_out.device) + mm2_out = mm2_out + bias_expanded.to(mm2_out.dtype) + + flat_weights = top_k_weights.view(-1) + permuted_weights = flat_weights[sorted_indices] + mm2_out = mm2_out * permuted_weights.unsqueeze(-1) + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=input_dtype, + device=hidden_states.device, + ) + final_hidden_states.index_add_(0, token_indices, mm2_out.to(input_dtype)) + if is_2d_input: + return final_hidden_states + return final_hidden_states.view(batch_size, sequence_length, hidden_dim) + + +def _moe_uses_fp8_expert_weights(self) -> bool: + from .moe_utils import _get_moe_weight_and_quant_state + + if not hasattr(self, "gate_up_proj") or not hasattr(self, "down_proj"): + return False + gate_param = getattr(self, "gate_up_proj", None) + down_param = getattr(self, "down_proj", None) + if _is_float8_tensor(gate_param) or _is_float8_tensor(down_param): + 
return True + gate_weight, _ = _get_moe_weight_and_quant_state(self, "gate_up_proj") + down_weight, _ = _get_moe_weight_and_quant_state(self, "down_proj") + return _is_float8_tensor(gate_weight) or _is_float8_tensor(down_weight) + + +def _call_with_temporary_moe_weights(experts_module, gate_up_proj, down_proj, forward_fn, *args): + old_gate_up = getattr(experts_module, "gate_up_proj") + old_down = getattr(experts_module, "down_proj") + gate_up_param = nn.Parameter(gate_up_proj, requires_grad=old_gate_up.requires_grad) + down_param = nn.Parameter(down_proj, requires_grad=old_down.requires_grad) + setattr(experts_module, "gate_up_proj", gate_up_param) + setattr(experts_module, "down_proj", down_param) + try: + return forward_fn(experts_module, *args) + finally: + setattr(experts_module, "gate_up_proj", old_gate_up) + setattr(experts_module, "down_proj", old_down) + + +def _slice_fp8_linear_quant_state(experts_module, param_name: str, expert_idx: int): + weight, quant_state, quant_kind = _get_moe_weight_and_quant_info(experts_module, param_name) + expert_quant_state = _slice_fp8_quant_state(weight, quant_state, expert_idx) + if not isinstance(expert_quant_state, torch.Tensor): + return expert_quant_state + if quant_kind == "weight_scale_inv": + expert_quant_state = expert_quant_state.reciprocal() + if expert_quant_state.ndim == 1: + expert_quant_state = expert_quant_state.view(-1, 1) + return expert_quant_state + + +def _forward_native_fp8_expert_loop(self, hidden_states, top_k_index, top_k_weights): + from unsloth.kernels.fp8 import fp8_linear + + is_2d_input = hidden_states.dim() == 2 + if is_2d_input: + sequence_length, hidden_dim = hidden_states.shape + batch_size = 1 + else: + batch_size, sequence_length, hidden_dim = hidden_states.shape + + input_dtype = hidden_states.dtype + hidden_states = hidden_states.view(-1, hidden_dim) + final_hidden_states = torch.zeros_like(hidden_states) + + gate_up_weight, _, _ = _get_moe_weight_and_quant_info(self, "gate_up_proj") + 
down_weight, _, _ = _get_moe_weight_and_quant_info(self, "down_proj") + + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + gate_up_bias = getattr(self, "gate_up_proj_bias", None) + down_bias = getattr(self, "down_proj_bias", None) + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + + expert_gate_up = gate_up_weight[expert_idx] + gate_up_qstate = _slice_fp8_linear_quant_state(self, "gate_up_proj", expert_idx) + gate_up_bias_expert = None if gate_up_bias is None else gate_up_bias[expert_idx] + if _is_float8_tensor(expert_gate_up): + gate_up_out = fp8_linear(current_state, expert_gate_up, gate_up_qstate, gate_up_bias_expert) + else: + gate_up_out = F.linear(current_state, expert_gate_up, gate_up_bias_expert) + + if "GptOssExperts" in self.__class__.__name__: + gate = gate_up_out[..., ::2] + up = gate_up_out[..., 1::2] + limit = getattr(self, "limit", 7.0) + alpha = getattr(self, "alpha", 1.702) + gate = gate.clamp(min=None, max=limit) + up = up.clamp(min=-limit, max=limit) + current_hidden_states = (up + 1.0) * (gate * torch.sigmoid(gate * alpha)) + else: + gate, up = gate_up_out.chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + + expert_down = down_weight[expert_idx] + down_qstate = _slice_fp8_linear_quant_state(self, "down_proj", expert_idx) + down_bias_expert = None if down_bias is None else down_bias[expert_idx] + if _is_float8_tensor(expert_down): + current_hidden_states = fp8_linear( + current_hidden_states, + expert_down, + down_qstate, + down_bias_expert, + ) + else: + current_hidden_states = F.linear(current_hidden_states, expert_down, down_bias_expert) + + current_hidden_states = 
current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(input_dtype)) + + if is_2d_input: + return final_hidden_states + return final_hidden_states.view(batch_size, sequence_length, hidden_dim) + + +def _forward_native_moe_loop_fp8(self, hidden_states, top_k_index, top_k_weights): + from .moe_utils import _get_moe_weight_and_quant_state, forward_native_moe_loop + + _log_moe_fp8_backend_once( + self, + "Unsloth: MoE FP8 fallback is using the direct FP8 expert loop backend.", + ) + try: + return _forward_native_fp8_expert_loop(self, hidden_states, top_k_index, top_k_weights) + except Exception: + pass + + target_dtype = _get_fp8_dequant_target_dtype(hidden_states) + gate_up_base, gate_up_quant = _get_moe_weight_and_quant_state(self, "gate_up_proj") + down_base, down_quant = _get_moe_weight_and_quant_state(self, "down_proj") + gate_up_weight = _dequantize_full_expert_weights(gate_up_base, gate_up_quant, target_dtype) + down_weight = _dequantize_full_expert_weights(down_base, down_quant, target_dtype) + if gate_up_weight is None or down_weight is None: + raise RuntimeError( + "Unable to dequantize FP8 MoE expert weights for the eager fallback path." 
+ ) + + return _call_with_temporary_moe_weights( + self, + gate_up_weight, + down_weight, + forward_native_moe_loop, + hidden_states.to(target_dtype), + top_k_index, + top_k_weights, + ) + + +def forward_grouped_mm_fp8(self, hidden_states, top_k_index, top_k_weights): + from .moe_utils import ( + _get_moe_weight_and_quant_state, + forward_native_grouped_mm, + forward_native_moe_loop, + ) + + if not _moe_uses_fp8_expert_weights(self): + return forward_native_grouped_mm(self, hidden_states, top_k_index, top_k_weights) + + scaled_output = _forward_scaled_grouped_mm_fp8( + self, hidden_states, top_k_index, top_k_weights + ) + if scaled_output is not None: + _log_moe_fp8_backend_once( + self, + "Unsloth: MoE FP8 is using torch._scaled_grouped_mm.", + ) + return scaled_output + + target_dtype = _get_fp8_dequant_target_dtype(hidden_states) + gate_up_base, gate_up_quant = _get_moe_weight_and_quant_state(self, "gate_up_proj") + down_base, down_quant = _get_moe_weight_and_quant_state(self, "down_proj") + gate_up_weight = _dequantize_full_expert_weights(gate_up_base, gate_up_quant, target_dtype) + down_weight = _dequantize_full_expert_weights(down_base, down_quant, target_dtype) + if gate_up_weight is None or down_weight is None: + return _forward_native_moe_loop_fp8(self, hidden_states, top_k_index, top_k_weights) + + _log_moe_fp8_backend_once( + self, + "Unsloth: MoE FP8 is using dequantize-plus-grouped_mm.", + ) + return _call_with_temporary_moe_weights( + self, + gate_up_weight, + down_weight, + forward_native_grouped_mm, + hidden_states.to(target_dtype), + top_k_index, + top_k_weights, + ) + + +def forward_moe_backend_fp8(self, hidden_states, top_k_index, top_k_weights): + from .moe_utils import select_moe_backend + + backend = select_moe_backend() + if backend == "grouped_mm": + return forward_grouped_mm_fp8(self, hidden_states, top_k_index, top_k_weights) + if backend == "unsloth_triton": + return forward_triton_grouped_gemm_fp8(self, hidden_states, top_k_index, 
top_k_weights) + return _forward_native_moe_loop_fp8(self, hidden_states, top_k_index, top_k_weights) + + +def forward_triton_grouped_gemm_fp8(self, hidden_states, top_k_index, top_k_weights): + from .moe_utils import ( + _check_torch_grouped_mm_supported, + _get_moe_weight_and_quant_state, + forward_native_moe_loop, + forward_triton_grouped_gemm, + ) + + if not _moe_uses_fp8_expert_weights(self): + return forward_triton_grouped_gemm(self, hidden_states, top_k_index, top_k_weights) + if _check_torch_grouped_mm_supported(): + return forward_grouped_mm_fp8(self, hidden_states, top_k_index, top_k_weights) + + target_dtype = _get_fp8_dequant_target_dtype(hidden_states) + gate_up_base, gate_up_quant = _get_moe_weight_and_quant_state(self, "gate_up_proj") + down_base, down_quant = _get_moe_weight_and_quant_state(self, "down_proj") + gate_up_weight = _dequantize_full_expert_weights(gate_up_base, gate_up_quant, target_dtype) + down_weight = _dequantize_full_expert_weights(down_base, down_quant, target_dtype) + if gate_up_weight is None or down_weight is None: + return _forward_native_moe_loop_fp8(self, hidden_states, top_k_index, top_k_weights) + + _log_moe_fp8_backend_once( + self, + "Unsloth: MoE FP8 is using dequantize-plus-Triton grouped GEMM.", + ) + return _call_with_temporary_moe_weights( + self, + gate_up_weight, + down_weight, + forward_triton_grouped_gemm, + hidden_states.to(target_dtype), + top_k_index, + top_k_weights, + ) From a21ad7e2c1ec999876173bd70b6a3174136932b7 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 16 Mar 2026 14:13:07 +0000 Subject: [PATCH 6/8] Fixup dequant --- .../temporary_patches/moe_utils_fp8.py | 215 ++++++++---------- 1 file changed, 96 insertions(+), 119 deletions(-) diff --git a/unsloth_zoo/temporary_patches/moe_utils_fp8.py b/unsloth_zoo/temporary_patches/moe_utils_fp8.py index c12d1776a..4d8696cff 100644 --- a/unsloth_zoo/temporary_patches/moe_utils_fp8.py +++ b/unsloth_zoo/temporary_patches/moe_utils_fp8.py @@ -161,11 
+161,13 @@ def _get_fp8_dequant_target_dtype(hidden_states: torch.Tensor) -> torch.dtype: def _log_moe_fp8_backend_once(experts_module, message: str): + from .common import logger from .moe_utils import _log_info if getattr(experts_module, "_unsloth_logged_fp8_backend", None) == message: return experts_module._unsloth_logged_fp8_backend = message + logger.info(message) _log_info(message) @@ -239,51 +241,79 @@ def _slice_fp8_quant_state(weight: torch.Tensor, quant_state, expert_idx: int): return sliced +def _ceil_div(a, b): + return (a + b - 1) // b + + def _dequantize_expert_slice( expert_weight: torch.Tensor, expert_quant_state, target_dtype: torch.dtype, ) -> Optional[torch.Tensor]: + """Dequantize one expert's FP8 weight to target_dtype using pure PyTorch.""" from .moe_utils import _try_attach_block_size if expert_weight.dtype != torch.float8_e4m3fn: return expert_weight.to(target_dtype) if expert_quant_state is None: - return None + return expert_weight.to(target_dtype) - try: - from unsloth.kernels.fp8 import weight_dequant - import triton - except Exception: - return None + s = expert_quant_state + if not isinstance(s, torch.Tensor): + return expert_weight.to(target_dtype) - block_size = getattr(expert_weight, "block_size", None) or getattr(expert_quant_state, "block_size", None) - if block_size is not None: - _try_attach_block_size(expert_weight, block_size) - _try_attach_block_size(expert_quant_state, block_size) + w = expert_weight.to(target_dtype) + # Per-tensor scale + if s.numel() == 1: + return w * s.view(1, 1).to(target_dtype) + + # Reshape 1D to column vector for per-row handling + if s.ndim == 1: + s = s.view(-1, 1) + + # Per-row scale: (m, 1) + if s.ndim == 2 and s.shape[1] == 1: + # Per-sub-projection scalar scales (e.g. 2 scales for gate+up stacked weight). 
if ( - isinstance(expert_quant_state, torch.Tensor) - and expert_quant_state.ndim == 2 - and len(block_size) == 2 + s.shape[0] > 1 + and s.shape[0] < w.shape[0] + and w.shape[0] % s.shape[0] == 0 ): - m, n = expert_weight.shape - p, q = expert_quant_state.shape - if triton.cdiv(m, block_size[0]) != p or triton.cdiv(n, block_size[1]) != q: - if ( - triton.cdiv(m, block_size[0]) == q - and triton.cdiv(n, block_size[1]) == p - ): - expert_quant_state = expert_quant_state.T.contiguous() - _try_attach_block_size(expert_quant_state, block_size) + repeat_factor = w.shape[0] // s.shape[0] + s = s.repeat_interleave(repeat_factor, dim=0) + + if w.shape[0] == s.shape[0]: + return w * s.to(target_dtype) + elif w.shape[1] == s.shape[0]: + return (w.t() * s.to(target_dtype)).t() + return w * s.to(target_dtype) + + # Block scale: (ceil(m/bm), ceil(n/bn)) — expand to weight shape + if s.ndim == 2: + block_size = getattr(expert_weight, "block_size", None) or getattr(s, "block_size", None) + M, N = w.shape + p, q = s.shape + + if block_size is not None and len(block_size) == 2: + bm, bn = block_size + # Check if scale is transposed + if _ceil_div(M, bm) != p or _ceil_div(N, bn) != q: + if _ceil_div(M, bm) == q and _ceil_div(N, bn) == p: + s = s.T.contiguous() + p, q = s.shape else: - return None + return expert_weight.to(target_dtype) + else: + # Infer block size from scale grid + bm = _ceil_div(M, p) + bn = _ceil_div(N, q) - if isinstance(expert_quant_state, torch.Tensor) and expert_quant_state.ndim == 1: - expert_quant_state = expert_quant_state.view(-1, 1) + s_expanded = s.to(target_dtype).repeat_interleave(bm, dim=0)[:M].repeat_interleave(bn, dim=1)[:, :N] + return w * s_expanded - return weight_dequant(expert_weight, expert_quant_state, dtype=target_dtype) + return expert_weight.to(target_dtype) def _dequantize_full_expert_weights(weight: torch.Tensor, quant_state, target_dtype: torch.dtype): @@ -441,18 +471,15 @@ def _forward_scaled_grouped_mm_fp8(self, hidden_states, 
top_k_index, top_k_weigh target_dtype = _get_fp8_dequant_target_dtype(permuted_input) permuted_input_fp8, input_scale = _manual_fp8_rowwise_quantize(permuted_input.to(target_dtype)) - try: - mm1_out = torch._scaled_grouped_mm( - permuted_input_fp8, - gate_up_weight, - input_scale, - gate_up_scale, - offs=offsets, - out_dtype=target_dtype, - use_fast_accum=True, - ) - except RuntimeError: - return None + mm1_out = torch._scaled_grouped_mm( + permuted_input_fp8, + gate_up_weight, + input_scale, + gate_up_scale, + offs=offsets, + out_dtype=target_dtype, + use_fast_accum=True, + ) gate_up_lora = _get_grouped_lora(self, "gate_up_proj", "_unsloth_lora_gate_up_proj", use_separated_lora) if gate_up_lora is not None: @@ -474,18 +501,15 @@ def _forward_scaled_grouped_mm_fp8(self, hidden_states, top_k_index, top_k_weigh inter = F.silu(gate) * up inter_fp8, inter_scale = _manual_fp8_rowwise_quantize(inter) - try: - mm2_out = torch._scaled_grouped_mm( - inter_fp8, - down_weight, - inter_scale, - down_scale, - offs=offsets, - out_dtype=mm1_out.dtype, - use_fast_accum=True, - ) - except RuntimeError: - return None + mm2_out = torch._scaled_grouped_mm( + inter_fp8, + down_weight, + inter_scale, + down_scale, + offs=offsets, + out_dtype=mm1_out.dtype, + use_fast_accum=True, + ) down_lora = _get_grouped_lora(self, "down_proj", "_unsloth_lora_down_proj", use_separated_lora) if down_lora is not None: @@ -627,12 +651,8 @@ def _forward_native_moe_loop_fp8(self, hidden_states, top_k_index, top_k_weights _log_moe_fp8_backend_once( self, - "Unsloth: MoE FP8 fallback is using the direct FP8 expert loop backend.", + "Unsloth: MoE FP8 is using the native_torch fallback backend.", ) - try: - return _forward_native_fp8_expert_loop(self, hidden_states, top_k_index, top_k_weights) - except Exception: - pass target_dtype = _get_fp8_dequant_target_dtype(hidden_states) gate_up_base, gate_up_quant = _get_moe_weight_and_quant_state(self, "gate_up_proj") @@ -655,90 +675,47 @@ def 
_forward_native_moe_loop_fp8(self, hidden_states, top_k_index, top_k_weights ) -def forward_grouped_mm_fp8(self, hidden_states, top_k_index, top_k_weights): +def forward_moe_backend_fp8(self, hidden_states, top_k_index, top_k_weights): from .moe_utils import ( _get_moe_weight_and_quant_state, + select_moe_backend, forward_native_grouped_mm, + forward_triton_grouped_gemm, forward_native_moe_loop, ) - if not _moe_uses_fp8_expert_weights(self): - return forward_native_grouped_mm(self, hidden_states, top_k_index, top_k_weights) - - scaled_output = _forward_scaled_grouped_mm_fp8( - self, hidden_states, top_k_index, top_k_weights - ) - if scaled_output is not None: - _log_moe_fp8_backend_once( - self, - "Unsloth: MoE FP8 is using torch._scaled_grouped_mm.", - ) - return scaled_output - + # Dequant FP8 weights to bf16, then use preferred non-FP8 backend. + # Note: _scaled_grouped_mm is NOT attempted here because the small-matrix + # probe can pass while real-sized matrices generate incompatible MMA + # instructions (e.g. on B200/SM100), which poisons the CUDA context + # asynchronously and cannot be caught without try/except. 
target_dtype = _get_fp8_dequant_target_dtype(hidden_states) gate_up_base, gate_up_quant = _get_moe_weight_and_quant_state(self, "gate_up_proj") down_base, down_quant = _get_moe_weight_and_quant_state(self, "down_proj") gate_up_weight = _dequantize_full_expert_weights(gate_up_base, gate_up_quant, target_dtype) down_weight = _dequantize_full_expert_weights(down_base, down_quant, target_dtype) - if gate_up_weight is None or down_weight is None: - return _forward_native_moe_loop_fp8(self, hidden_states, top_k_index, top_k_weights) - - _log_moe_fp8_backend_once( - self, - "Unsloth: MoE FP8 is using dequantize-plus-grouped_mm.", - ) - return _call_with_temporary_moe_weights( - self, - gate_up_weight, - down_weight, - forward_native_grouped_mm, - hidden_states.to(target_dtype), - top_k_index, - top_k_weights, - ) - -def forward_moe_backend_fp8(self, hidden_states, top_k_index, top_k_weights): - from .moe_utils import select_moe_backend + if gate_up_weight is None or down_weight is None: + raise RuntimeError( + "Unable to dequantize FP8 MoE expert weights." 
+ ) backend = select_moe_backend() if backend == "grouped_mm": - return forward_grouped_mm_fp8(self, hidden_states, top_k_index, top_k_weights) - if backend == "unsloth_triton": - return forward_triton_grouped_gemm_fp8(self, hidden_states, top_k_index, top_k_weights) - return _forward_native_moe_loop_fp8(self, hidden_states, top_k_index, top_k_weights) - - -def forward_triton_grouped_gemm_fp8(self, hidden_states, top_k_index, top_k_weights): - from .moe_utils import ( - _check_torch_grouped_mm_supported, - _get_moe_weight_and_quant_state, - forward_native_moe_loop, - forward_triton_grouped_gemm, - ) - - if not _moe_uses_fp8_expert_weights(self): - return forward_triton_grouped_gemm(self, hidden_states, top_k_index, top_k_weights) - if _check_torch_grouped_mm_supported(): - return forward_grouped_mm_fp8(self, hidden_states, top_k_index, top_k_weights) - - target_dtype = _get_fp8_dequant_target_dtype(hidden_states) - gate_up_base, gate_up_quant = _get_moe_weight_and_quant_state(self, "gate_up_proj") - down_base, down_quant = _get_moe_weight_and_quant_state(self, "down_proj") - gate_up_weight = _dequantize_full_expert_weights(gate_up_base, gate_up_quant, target_dtype) - down_weight = _dequantize_full_expert_weights(down_base, down_quant, target_dtype) - if gate_up_weight is None or down_weight is None: - return _forward_native_moe_loop_fp8(self, hidden_states, top_k_index, top_k_weights) + _log_moe_fp8_backend_once(self, "Unsloth: MoE FP8 is using dequantize-plus-grouped_mm.") + forward_fn = forward_native_grouped_mm + elif backend == "unsloth_triton": + _log_moe_fp8_backend_once(self, "Unsloth: MoE FP8 is using dequantize-plus-Triton grouped GEMM.") + forward_fn = forward_triton_grouped_gemm + else: + _log_moe_fp8_backend_once(self, "Unsloth: MoE FP8 is using dequantize-plus-native_torch loop.") + forward_fn = forward_native_moe_loop - _log_moe_fp8_backend_once( - self, - "Unsloth: MoE FP8 is using dequantize-plus-Triton grouped GEMM.", - ) return 
_call_with_temporary_moe_weights( self, gate_up_weight, down_weight, - forward_triton_grouped_gemm, + forward_fn, hidden_states.to(target_dtype), top_k_index, top_k_weights, From 20f56f8fea0906c2b9573a43edd591af872c63af Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 16 Mar 2026 18:47:16 +0000 Subject: [PATCH 7/8] fix qwen3moe --- unsloth_zoo/temporary_patches/misc.py | 20 +++-- unsloth_zoo/temporary_patches/moe_utils.py | 24 +++--- .../temporary_patches/moe_utils_fp8.py | 80 ++++++++++++------- unsloth_zoo/temporary_patches/qwen3_moe.py | 12 ++- 4 files changed, 86 insertions(+), 50 deletions(-) diff --git a/unsloth_zoo/temporary_patches/misc.py b/unsloth_zoo/temporary_patches/misc.py index d1f3c3735..af53c39f1 100644 --- a/unsloth_zoo/temporary_patches/misc.py +++ b/unsloth_zoo/temporary_patches/misc.py @@ -446,16 +446,22 @@ def patch_transformers_masks(): except Exception: torch_create_block_mask = None + # Always disable _compile flag to avoid double compilation issues + # When unsloth compiles create_causal_mask, the internal create_block_mask + # should NOT also compile itself as it causes dimension issues + # We need to patch both masking_utils and the original torch module if torch_create_block_mask is not None: + def create_block_mask_wrapper(*args, **kwargs): + kwargs["_compile"] = False + return torch_create_block_mask(*args, **kwargs) + # Patch masking_utils (for direct access) + masking_utils.create_block_mask = create_block_mask_wrapper + # Also patch the torch module directly (used by flex_attention_mask via import) try: - supports_compile = "_compile" in inspect.signature(torch_create_block_mask).parameters + import torch.nn.attention.flex_attention as flex_attention + flex_attention.create_block_mask = create_block_mask_wrapper except Exception: - supports_compile = True - if not supports_compile: - def create_block_mask_wrapper(*args, **kwargs): - kwargs.pop("_compile", None) - return torch_create_block_mask(*args, **kwargs) - 
masking_utils.create_block_mask = create_block_mask_wrapper + pass original_create_causal_mask = getattr( masking_utils, diff --git a/unsloth_zoo/temporary_patches/moe_utils.py b/unsloth_zoo/temporary_patches/moe_utils.py index b0912b8f7..63ccc0723 100644 --- a/unsloth_zoo/temporary_patches/moe_utils.py +++ b/unsloth_zoo/temporary_patches/moe_utils.py @@ -333,7 +333,7 @@ def forward_moe_backend( return forward_moe_backend_fp8( self, hidden_states, top_k_index, top_k_weights ) - except Exception: + except ImportError: pass backend = select_moe_backend() @@ -576,6 +576,9 @@ def _get_moe_weight_and_quant_state(experts_module, param_name: str): block_size = getattr(param, "block_size", None) if block_size is None: block_size = getattr(experts_module, f"{param_name}_block_size", None) + if block_size is None: + # FP8Experts stores block_size on the module itself + block_size = getattr(experts_module, "block_size", None) if block_size is not None: _try_attach_block_size(weight, block_size) if quant_state is not None: @@ -961,21 +964,22 @@ def forward_native_grouped_mm( hidden_states = hidden_states.view(-1, hidden_dim) - # 1. Calculate routing - flat_top_k = top_k_index.view(-1) - num_tokens_per_expert = torch.bincount(flat_top_k, minlength=self.num_experts).int() + # 1. Calculate routing (no grad needed for routing indices - they come from router's topk) + with torch.no_grad(): + flat_top_k = top_k_index.view(-1) + num_tokens_per_expert = torch.bincount(flat_top_k, minlength=self.num_experts).int() + + # 2. Sort indices to group tokens by expert + sorted_indices = torch.argsort(flat_top_k, stable=True) + token_indices = sorted_indices // top_k_index.shape[-1] - # 2. Sort indices to group tokens by expert - sorted_indices = torch.argsort(flat_top_k, stable=True) - token_indices = sorted_indices // top_k_index.shape[-1] + # 4. Prepare Grouped MM arguments + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) # 3. 
Permute Input # We need to gather inputs. Since we may have expanded top_k, we use token_indices to map back to original input permuted_input = hidden_states[token_indices] - # 4. Prepare Grouped MM arguments - offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) - # ======================================================================== # Gate + Up projection with optional separated LoRA (DEFAULT) # ======================================================================== diff --git a/unsloth_zoo/temporary_patches/moe_utils_fp8.py b/unsloth_zoo/temporary_patches/moe_utils_fp8.py index 4d8696cff..671a4580c 100644 --- a/unsloth_zoo/temporary_patches/moe_utils_fp8.py +++ b/unsloth_zoo/temporary_patches/moe_utils_fp8.py @@ -183,11 +183,12 @@ def _check_torch_scaled_grouped_mm_supported(): _TORCH_SCALED_GROUPED_MM_SUPPORTED = False return False - # The symbol can exist on older GPUs, but probing it on unsupported - # hardware can trigger an async launch failure that poisons the CUDA - # context. Keep the FP8 scaled_grouped_mm path off on pre-Hopper parts. + # The symbol can exist on unsupported GPUs, and the light probe can still + # pass on SM100 while real MoE kernels later emit incompatible MMA + # instructions and poison the CUDA context asynchronously. Restrict the + # FP8 scaled_grouped_mm path to Hopper (SM 9.x) only for now. 
major, _minor = torch.cuda.get_device_capability(torch.cuda.current_device()) - if major < 9: + if major != 9: _TORCH_SCALED_GROUPED_MM_SUPPORTED = False return False @@ -367,6 +368,9 @@ def _get_moe_weight_and_quant_info(experts_module, param_name: str): block_size = getattr(param, "block_size", None) if block_size is None: block_size = getattr(experts_module, f"{param_name}_block_size", None) + if block_size is None: + # FP8Experts stores block_size on the module itself + block_size = getattr(experts_module, "block_size", None) if block_size is not None: _try_attach_block_size(weight, block_size) if quant_state is not None: @@ -675,6 +679,7 @@ def _forward_native_moe_loop_fp8(self, hidden_states, top_k_index, top_k_weights ) +@torch.compiler.disable def forward_moe_backend_fp8(self, hidden_states, top_k_index, top_k_weights): from .moe_utils import ( _get_moe_weight_and_quant_state, @@ -684,39 +689,54 @@ def forward_moe_backend_fp8(self, hidden_states, top_k_index, top_k_weights): forward_native_moe_loop, ) - # Dequant FP8 weights to bf16, then use preferred non-FP8 backend. - # Note: _scaled_grouped_mm is NOT attempted here because the small-matrix - # probe can pass while real-sized matrices generate incompatible MMA - # instructions (e.g. on B200/SM100), which poisons the CUDA context - # asynchronously and cannot be caught without try/except. + backend = select_moe_backend() + + # 1. Try _scaled_grouped_mm (fast FP8 path on Hopper/Blackwell) + if backend == "grouped_mm" and _check_torch_scaled_grouped_mm_supported(): + scaled_grouped_mm_output = _forward_scaled_grouped_mm_fp8( + self, + hidden_states, + top_k_index, + top_k_weights, + ) + if scaled_grouped_mm_output is not None: + _log_moe_fp8_backend_once( + self, + "Unsloth: MoE FP8 is using _scaled_grouped_mm.", + ) + return scaled_grouped_mm_output + + # 2. 
Dequant FP8 weights to bf16/fp16 and run through normal MoE forward target_dtype = _get_fp8_dequant_target_dtype(hidden_states) gate_up_base, gate_up_quant = _get_moe_weight_and_quant_state(self, "gate_up_proj") down_base, down_quant = _get_moe_weight_and_quant_state(self, "down_proj") gate_up_weight = _dequantize_full_expert_weights(gate_up_base, gate_up_quant, target_dtype) down_weight = _dequantize_full_expert_weights(down_base, down_quant, target_dtype) - if gate_up_weight is None or down_weight is None: - raise RuntimeError( - "Unable to dequantize FP8 MoE expert weights." + if gate_up_weight is not None and down_weight is not None: + if backend == "grouped_mm": + _log_moe_fp8_backend_once(self, "Unsloth: MoE FP8 is using dequantize-plus-grouped_mm.") + forward_fn = forward_native_grouped_mm + elif backend == "unsloth_triton": + _log_moe_fp8_backend_once(self, "Unsloth: MoE FP8 is using dequantize-plus-Triton grouped GEMM.") + forward_fn = forward_triton_grouped_gemm + else: + _log_moe_fp8_backend_once(self, "Unsloth: MoE FP8 is using dequantize-plus-native_torch loop.") + forward_fn = forward_native_moe_loop + + return _call_with_temporary_moe_weights( + self, + gate_up_weight, + down_weight, + forward_fn, + hidden_states.to(target_dtype), + top_k_index, + top_k_weights, ) - backend = select_moe_backend() - if backend == "grouped_mm": - _log_moe_fp8_backend_once(self, "Unsloth: MoE FP8 is using dequantize-plus-grouped_mm.") - forward_fn = forward_native_grouped_mm - elif backend == "unsloth_triton": - _log_moe_fp8_backend_once(self, "Unsloth: MoE FP8 is using dequantize-plus-Triton grouped GEMM.") - forward_fn = forward_triton_grouped_gemm - else: - _log_moe_fp8_backend_once(self, "Unsloth: MoE FP8 is using dequantize-plus-native_torch loop.") - forward_fn = forward_native_moe_loop - - return _call_with_temporary_moe_weights( + # 3. 
Last resort: per-expert fp8_linear loop + _log_moe_fp8_backend_once( self, - gate_up_weight, - down_weight, - forward_fn, - hidden_states.to(target_dtype), - top_k_index, - top_k_weights, + "Unsloth: MoE FP8 is using per-expert fp8_linear loop.", ) + return _forward_native_fp8_expert_loop(self, hidden_states, top_k_index, top_k_weights) diff --git a/unsloth_zoo/temporary_patches/qwen3_moe.py b/unsloth_zoo/temporary_patches/qwen3_moe.py index bd03c2cc5..16986caa9 100644 --- a/unsloth_zoo/temporary_patches/qwen3_moe.py +++ b/unsloth_zoo/temporary_patches/qwen3_moe.py @@ -116,6 +116,7 @@ def _make_qwen_moe_sparse_moe_block_forward(use_shared_expert: bool, module_name @torch.compiler.disable def sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: use_shared_expert = hasattr(self, "shared_expert") and hasattr(self, "shared_expert_gate") + input_dtype = hidden_states.dtype if hidden_states.dim() == 3: batch_size, sequence_length, hidden_dim = hidden_states.shape else: @@ -124,12 +125,16 @@ def sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: sequence_length = total_tokens hidden_states_reshaped = hidden_states.view(-1, hidden_dim) + router_input = hidden_states_reshaped + gate_weight = getattr(self.gate, "weight", None) + if gate_weight is not None and router_input.dtype != gate_weight.dtype: + router_input = router_input.to(gate_weight.dtype) shared_expert_output = None if use_shared_expert: shared_expert_output = self.shared_expert(hidden_states_reshaped) - gate_output = self.gate(hidden_states_reshaped) + gate_output = self.gate(router_input) if isinstance(gate_output, tuple): _, routing_weights, selected_experts = gate_output else: @@ -141,14 +146,15 @@ def sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) if norm_topk_prob: routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) - 
routing_weights = routing_weights.to(hidden_states.dtype) + routing_weights = routing_weights.to(router_input.dtype) - final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights) + final_hidden_states = self.experts(router_input, selected_experts, routing_weights) if use_shared_expert: shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output final_hidden_states = final_hidden_states + shared_expert_output + final_hidden_states = final_hidden_states.to(input_dtype) if hidden_states.dim() == 3: return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states From 7798e589b03e124281a733333c0a6bd3af5681aa Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 17 Mar 2026 08:22:26 +0000 Subject: [PATCH 8/8] Fix bugs in FP8 MoE support (#548) - B7: Support sharded safetensors (multi-shard) in GLM4 FP8 scale patching - B8: Add quant_kind param to _dequantize_expert_slice for weight_scale_inv reciprocal handling - B18: Guard fp8_linear import with try/except and dequant fallback - B19: Flatten 3D top_k_index/top_k_weights alongside hidden_states - B23: Fix _make_grouped_mm_rhs_column_major (was no-op double transpose, now weight.mT.contiguous()) - B25: Add act_fn fallback to F.silu when attribute missing - B26: Remove dead _forward_native_moe_loop_fp8 function - B10: Add 3D input reshape in forward_native_moe_loop - B12: Prefer generic backend over unconditional FP8 in get_forward_moe_backend; use forward_moe_backend as final fallback - Fix use_separated_lora to respect _should_use_separated_lora() instead of hardcoded True - Remove no-op try/except RuntimeError and duplicate comment --- unsloth_zoo/temporary_patches/moe_utils.py | 33 ++- .../temporary_patches/moe_utils_fp8.py | 275 +++++++++++------- 2 files changed, 186 insertions(+), 122 deletions(-) diff --git a/unsloth_zoo/temporary_patches/moe_utils.py b/unsloth_zoo/temporary_patches/moe_utils.py 
index 63ccc0723..1c51db1f7 100644 --- a/unsloth_zoo/temporary_patches/moe_utils.py +++ b/unsloth_zoo/temporary_patches/moe_utils.py @@ -135,21 +135,21 @@ def _load_cached_moe_utils_fp8_module(): def get_forward_moe_backend(): """ Resolve forward_moe_backend from the compiled cache copy when available. - Falls back to the local module definition. + Prefer the generic (non-FP8) backend; the generic backend can internally + detect FP8 weights and dispatch to the FP8 path if needed. """ global _CACHED_FORWARD_MOE_BACKEND - fp8_module = _load_cached_moe_utils_fp8_module() - if fp8_module is not None and hasattr(fp8_module, "forward_moe_backend_fp8"): - _CACHED_FORWARD_MOE_BACKEND = fp8_module.forward_moe_backend_fp8 - return _CACHED_FORWARD_MOE_BACKEND - module = _load_cached_moe_utils_module() if module is not None and hasattr(module, "forward_moe_backend"): _CACHED_FORWARD_MOE_BACKEND = module.forward_moe_backend return _CACHED_FORWARD_MOE_BACKEND - from .moe_utils_fp8 import forward_moe_backend_fp8 - _CACHED_FORWARD_MOE_BACKEND = forward_moe_backend_fp8 + fp8_module = _load_cached_moe_utils_fp8_module() + if fp8_module is not None and hasattr(fp8_module, "forward_moe_backend_fp8"): + _CACHED_FORWARD_MOE_BACKEND = fp8_module.forward_moe_backend_fp8 + return _CACHED_FORWARD_MOE_BACKEND + + _CACHED_FORWARD_MOE_BACKEND = forward_moe_backend return _CACHED_FORWARD_MOE_BACKEND # ============================================================================ @@ -1025,11 +1025,8 @@ def forward_native_grouped_mm( second_weight = second_weight.to(permuted_input.dtype).contiguous() # Step 1: permuted_input @ first_weight - try: - lora_out = _grouped_mm_with_backward_fix(permuted_input, first_weight, offsets) - lora_out = lora_out.contiguous() - except RuntimeError as e: - raise e + lora_out = _grouped_mm_with_backward_fix(permuted_input, first_weight, offsets) + lora_out = lora_out.contiguous() # Step 2: result @ second_weight # Handle unaligned O dimension or other grouped_mm 
failures @@ -1353,8 +1350,6 @@ def forward_triton_grouped_gemm( # Apply SiLU activation and multiply gate with up intermediate = _silu_and_mul(first_gemm_output) - # Grouped GEMM 2: down projection - # Grouped GEMM 2: down projection # Prepare LoRA data down_lora = None @@ -1436,6 +1431,12 @@ def forward_native_moe_loop( Explicitly disabled for torch.compile to prevent graph breaks/recompilation issues with dynamic control flow. """ # This Unsloth Zoo code section is licensed under AGPL3 + original_shape = hidden_states.shape + if hidden_states.dim() == 3: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + top_k_index = top_k_index.view(-1, top_k_index.shape[-1]) + top_k_weights = top_k_weights.view(-1, top_k_weights.shape[-1]) + final_hidden_states = torch.zeros_like(hidden_states) # Create expert mask and find which experts have tokens @@ -1484,4 +1485,4 @@ def forward_native_moe_loop( 0, token_idx, current_hidden_states.to(final_hidden_states.dtype) ) - return final_hidden_states + return final_hidden_states.view(original_shape) diff --git a/unsloth_zoo/temporary_patches/moe_utils_fp8.py b/unsloth_zoo/temporary_patches/moe_utils_fp8.py index 671a4580c..62eaddf8a 100644 --- a/unsloth_zoo/temporary_patches/moe_utils_fp8.py +++ b/unsloth_zoo/temporary_patches/moe_utils_fp8.py @@ -58,72 +58,151 @@ def _maybe_patch_glm4_stacked_moe_fp8_scales( if len(routed_layers) == 0: return False + import safetensors.torch + import json as _json + + # Collect all tensor keys we'll need so we can find the right shards + needed_keys = set() + for layer_idx, experts in routed_layers: + num_experts = experts.gate_up_proj.shape[0] + for expert_idx in range(num_experts): + for proj in ("gate_proj", "up_proj", "down_proj"): + needed_keys.add(f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.{proj}.weight") + needed_keys.add(f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.{proj}.weight_scale") + + # Resolve file path(s) -- single file or sharded + shard_paths = {} 
# tensor_key -> local file path if os.path.isdir(model_name): - safetensors_path = os.path.join(model_name, "model.safetensors") + single_path = os.path.join(model_name, "model.safetensors") + if os.path.exists(single_path): + shard_paths = {k: single_path for k in needed_keys} + else: + index_path = os.path.join(model_name, "model.safetensors.index.json") + if not os.path.exists(index_path): + return False + with open(index_path) as f: + weight_map = _json.load(f).get("weight_map", {}) + for k in needed_keys: + shard_file = weight_map.get(k) + if shard_file is None: + return False + shard_paths[k] = os.path.join(model_name, shard_file) else: from huggingface_hub import hf_hub_download - safetensors_path = hf_hub_download( - repo_id = model_name, - filename = "model.safetensors", - token = token, - revision = revision, - ) + try: + single_path = hf_hub_download( + repo_id = model_name, + filename = "model.safetensors", + token = token, + revision = revision, + ) + shard_paths = {k: single_path for k in needed_keys} + except Exception: + try: + index_path = hf_hub_download( + repo_id = model_name, + filename = "model.safetensors.index.json", + token = token, + revision = revision, + ) + with open(index_path) as f: + weight_map = _json.load(f).get("weight_map", {}) + # Download only the unique shards we need + needed_shards = set() + for k in needed_keys: + shard_file = weight_map.get(k) + if shard_file is None: + return False + needed_shards.add(shard_file) + downloaded = {} + for shard_file in needed_shards: + downloaded[shard_file] = hf_hub_download( + repo_id = model_name, + filename = shard_file, + token = token, + revision = revision, + ) + for k in needed_keys: + shard_paths[k] = downloaded[weight_map[k]] + except Exception: + return False + + if not shard_paths: + return False - import safetensors.torch + # Open all unique shard files and build a multi-shard reader + unique_paths = set(shard_paths.values()) + open_handles = {} + try: + for p in unique_paths: 
+ open_handles[p] = safetensors.torch.safe_open(p, framework = "pt") - with safetensors.torch.safe_open(safetensors_path, framework = "pt") as file: - for layer_idx, experts in routed_layers: - device = experts.gate_up_proj.device - num_experts = experts.gate_up_proj.shape[0] - gate_up_rows = [] - down_rows = [] - gate_up_scales = [] - down_scales = [] - - for expert_idx in range(num_experts): - gate = file.get_tensor( - f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight" - ) - gate_scale = file.get_tensor( - f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight_scale" - ) - up = file.get_tensor( - f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight" - ) - up_scale = file.get_tensor( - f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight_scale" - ) - down = file.get_tensor( - f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight" - ) - down_scale = file.get_tensor( - f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight_scale" - ) + class _MultiShardFile: + def get_tensor(self, key): + return open_handles[shard_paths[key]].get_tensor(key) - gate_up_rows.append(torch.cat([gate, up], dim = 0)) - down_rows.append(down) - gate_up_scales.append(torch.cat([gate_scale, up_scale], dim = 0)) - down_scales.append(down_scale) + file = _MultiShardFile() + _patch_result = _do_glm4_scale_patching(file, routed_layers) + finally: + open_handles.clear() + + return _patch_result + + +def _do_glm4_scale_patching(file, routed_layers): + """Inner helper that reads tensors from file and patches routed_layers.""" + for layer_idx, experts in routed_layers: + device = experts.gate_up_proj.device + num_experts = experts.gate_up_proj.shape[0] + gate_up_rows = [] + down_rows = [] + gate_up_scales = [] + down_scales = [] - experts.gate_up_proj = nn.Parameter( - torch.stack(gate_up_rows, dim = 0).to(device = device), - requires_grad = experts.gate_up_proj.requires_grad, + for expert_idx in 
range(num_experts): + gate = file.get_tensor( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight" ) - experts.down_proj = nn.Parameter( - torch.stack(down_rows, dim = 0).to(device = device), - requires_grad = experts.down_proj.requires_grad, + gate_scale = file.get_tensor( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight_scale" ) - experts.gate_up_proj_weight_scale = nn.Parameter( - torch.stack(gate_up_scales, dim = 0).to(device = device), - requires_grad = False, + up = file.get_tensor( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight" ) - experts.down_proj_weight_scale = nn.Parameter( - torch.stack(down_scales, dim = 0).to(device = device), - requires_grad = False, + up_scale = file.get_tensor( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight_scale" ) - experts.gate_up_proj_scale = experts.gate_up_proj_weight_scale - experts.down_proj_scale = experts.down_proj_weight_scale + down = file.get_tensor( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight" + ) + down_scale = file.get_tensor( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight_scale" + ) + + gate_up_rows.append(torch.cat([gate, up], dim = 0)) + down_rows.append(down) + gate_up_scales.append(torch.cat([gate_scale, up_scale], dim = 0)) + down_scales.append(down_scale) + + experts.gate_up_proj = nn.Parameter( + torch.stack(gate_up_rows, dim = 0).to(device = device), + requires_grad = experts.gate_up_proj.requires_grad, + ) + experts.down_proj = nn.Parameter( + torch.stack(down_rows, dim = 0).to(device = device), + requires_grad = experts.down_proj.requires_grad, + ) + experts.gate_up_proj_weight_scale = nn.Parameter( + torch.stack(gate_up_scales, dim = 0).to(device = device), + requires_grad = False, + ) + experts.down_proj_weight_scale = nn.Parameter( + torch.stack(down_scales, dim = 0).to(device = device), + requires_grad = False, + ) + experts.gate_up_proj_scale = 
experts.gate_up_proj_weight_scale + experts.down_proj_scale = experts.down_proj_weight_scale return True @@ -250,6 +329,7 @@ def _dequantize_expert_slice( expert_weight: torch.Tensor, expert_quant_state, target_dtype: torch.dtype, + quant_kind=None, ) -> Optional[torch.Tensor]: """Dequantize one expert's FP8 weight to target_dtype using pure PyTorch.""" from .moe_utils import _try_attach_block_size @@ -264,6 +344,9 @@ def _dequantize_expert_slice( if not isinstance(s, torch.Tensor): return expert_weight.to(target_dtype) + if quant_kind == "weight_scale_inv": + s = s.reciprocal() + w = expert_weight.to(target_dtype) # Per-tensor scale @@ -317,7 +400,7 @@ def _dequantize_expert_slice( return expert_weight.to(target_dtype) -def _dequantize_full_expert_weights(weight: torch.Tensor, quant_state, target_dtype: torch.dtype): +def _dequantize_full_expert_weights(weight: torch.Tensor, quant_state, target_dtype: torch.dtype, quant_kind=None): from .moe_utils import _try_attach_block_size if weight.ndim != 3: @@ -330,7 +413,7 @@ def _dequantize_full_expert_weights(weight: torch.Tensor, quant_state, target_dt if block_size is not None: _try_attach_block_size(expert_weight, block_size) expert_quant_state = _slice_fp8_quant_state(weight, quant_state, expert_idx) - expert_dequant = _dequantize_expert_slice(expert_weight, expert_quant_state, target_dtype) + expert_dequant = _dequantize_expert_slice(expert_weight, expert_quant_state, target_dtype, quant_kind=quant_kind) if expert_dequant is None: return None dequantized.append(expert_dequant) @@ -338,7 +421,7 @@ def _dequantize_full_expert_weights(weight: torch.Tensor, quant_state, target_dt def _make_grouped_mm_rhs_column_major(weight: torch.Tensor) -> torch.Tensor: - return weight.transpose(-2, -1).contiguous().transpose(-2, -1) + return weight.mT.contiguous() def _get_moe_weight_and_quant_info(experts_module, param_name: str): @@ -463,7 +546,8 @@ def _forward_scaled_grouped_mm_fp8(self, hidden_states, top_k_index, top_k_weigh 
token_indices = sorted_indices // top_k_index.shape[-1] permuted_input = hidden_states[token_indices] offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) - use_separated_lora = True + from .moe_utils import _should_use_separated_lora + use_separated_lora = _should_use_separated_lora() model_type = getattr(self, "_unsloth_model_type", None) gate_up_prepared = _prepare_scaled_grouped_mm_weight(self, "gate_up_proj", "gate_up", hidden_dim, model_type) @@ -577,17 +661,17 @@ def _slice_fp8_linear_quant_state(experts_module, param_name: str, expert_idx: i def _forward_native_fp8_expert_loop(self, hidden_states, top_k_index, top_k_weights): - from unsloth.kernels.fp8 import fp8_linear - - is_2d_input = hidden_states.dim() == 2 - if is_2d_input: - sequence_length, hidden_dim = hidden_states.shape - batch_size = 1 - else: - batch_size, sequence_length, hidden_dim = hidden_states.shape + try: + from unsloth.kernels.fp8 import fp8_linear + except ImportError: + fp8_linear = None + original_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] input_dtype = hidden_states.dtype hidden_states = hidden_states.view(-1, hidden_dim) + top_k_index = top_k_index.view(-1, top_k_index.shape[-1]) + top_k_weights = top_k_weights.view(-1, top_k_weights.shape[-1]) final_hidden_states = torch.zeros_like(hidden_states) gate_up_weight, _, _ = _get_moe_weight_and_quant_info(self, "gate_up_proj") @@ -612,8 +696,12 @@ def _forward_native_fp8_expert_loop(self, hidden_states, top_k_index, top_k_weig expert_gate_up = gate_up_weight[expert_idx] gate_up_qstate = _slice_fp8_linear_quant_state(self, "gate_up_proj", expert_idx) gate_up_bias_expert = None if gate_up_bias is None else gate_up_bias[expert_idx] - if _is_float8_tensor(expert_gate_up): + if _is_float8_tensor(expert_gate_up) and fp8_linear is not None: gate_up_out = fp8_linear(current_state, expert_gate_up, gate_up_qstate, gate_up_bias_expert) + elif _is_float8_tensor(expert_gate_up): + target_dtype = 
_get_fp8_dequant_target_dtype(current_state) + expert_dequant = _dequantize_expert_slice(expert_gate_up, gate_up_qstate, target_dtype) + gate_up_out = F.linear(current_state, expert_dequant, gate_up_bias_expert) else: gate_up_out = F.linear(current_state, expert_gate_up, gate_up_bias_expert) @@ -627,62 +715,37 @@ def _forward_native_fp8_expert_loop(self, hidden_states, top_k_index, top_k_weig current_hidden_states = (up + 1.0) * (gate * torch.sigmoid(gate * alpha)) else: gate, up = gate_up_out.chunk(2, dim=-1) - current_hidden_states = self.act_fn(gate) * up + act_fn = getattr(self, "act_fn", None) + if act_fn is None: + act_fn = F.silu + current_hidden_states = act_fn(gate) * up expert_down = down_weight[expert_idx] down_qstate = _slice_fp8_linear_quant_state(self, "down_proj", expert_idx) down_bias_expert = None if down_bias is None else down_bias[expert_idx] - if _is_float8_tensor(expert_down): + if _is_float8_tensor(expert_down) and fp8_linear is not None: current_hidden_states = fp8_linear( current_hidden_states, expert_down, down_qstate, down_bias_expert, ) + elif _is_float8_tensor(expert_down): + target_dtype = _get_fp8_dequant_target_dtype(current_hidden_states) + expert_dequant = _dequantize_expert_slice(expert_down, down_qstate, target_dtype) + current_hidden_states = F.linear(current_hidden_states, expert_dequant, down_bias_expert) else: current_hidden_states = F.linear(current_hidden_states, expert_down, down_bias_expert) current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(input_dtype)) - if is_2d_input: - return final_hidden_states - return final_hidden_states.view(batch_size, sequence_length, hidden_dim) - - -def _forward_native_moe_loop_fp8(self, hidden_states, top_k_index, top_k_weights): - from .moe_utils import _get_moe_weight_and_quant_state, forward_native_moe_loop - - _log_moe_fp8_backend_once( - self, - "Unsloth: MoE FP8 is using the 
native_torch fallback backend.", - ) - - target_dtype = _get_fp8_dequant_target_dtype(hidden_states) - gate_up_base, gate_up_quant = _get_moe_weight_and_quant_state(self, "gate_up_proj") - down_base, down_quant = _get_moe_weight_and_quant_state(self, "down_proj") - gate_up_weight = _dequantize_full_expert_weights(gate_up_base, gate_up_quant, target_dtype) - down_weight = _dequantize_full_expert_weights(down_base, down_quant, target_dtype) - if gate_up_weight is None or down_weight is None: - raise RuntimeError( - "Unable to dequantize FP8 MoE expert weights for the eager fallback path." - ) - - return _call_with_temporary_moe_weights( - self, - gate_up_weight, - down_weight, - forward_native_moe_loop, - hidden_states.to(target_dtype), - top_k_index, - top_k_weights, - ) + return final_hidden_states.view(original_shape) @torch.compiler.disable def forward_moe_backend_fp8(self, hidden_states, top_k_index, top_k_weights): from .moe_utils import ( - _get_moe_weight_and_quant_state, select_moe_backend, forward_native_grouped_mm, forward_triton_grouped_gemm, @@ -708,10 +771,10 @@ def forward_moe_backend_fp8(self, hidden_states, top_k_index, top_k_weights): # 2. 
Dequant FP8 weights to bf16/fp16 and run through normal MoE forward target_dtype = _get_fp8_dequant_target_dtype(hidden_states) - gate_up_base, gate_up_quant = _get_moe_weight_and_quant_state(self, "gate_up_proj") - down_base, down_quant = _get_moe_weight_and_quant_state(self, "down_proj") - gate_up_weight = _dequantize_full_expert_weights(gate_up_base, gate_up_quant, target_dtype) - down_weight = _dequantize_full_expert_weights(down_base, down_quant, target_dtype) + gate_up_base, gate_up_quant, gate_up_qkind = _get_moe_weight_and_quant_info(self, "gate_up_proj") + down_base, down_quant, down_qkind = _get_moe_weight_and_quant_info(self, "down_proj") + gate_up_weight = _dequantize_full_expert_weights(gate_up_base, gate_up_quant, target_dtype, quant_kind=gate_up_qkind) + down_weight = _dequantize_full_expert_weights(down_base, down_quant, target_dtype, quant_kind=down_qkind) if gate_up_weight is not None and down_weight is not None: if backend == "grouped_mm":