Skip to content

Fix bugs in FP8 MoE support#554

Open
danielhanchen wants to merge 8 commits into
mainfrom
fix/fp8-moe-bug-fixes
Open

Fix bugs in FP8 MoE support#554
danielhanchen wants to merge 8 commits into
mainfrom
fix/fp8-moe-bug-fixes

Conversation

@danielhanchen

Copy link
Copy Markdown
Member

Summary

Fixes multiple bugs identified during review of the FP8 MoE PRs (unsloth #4312 + #548). These changes sit on top of the PR branch and address correctness, safety, and compatibility issues.

Companion PR: unslothai/unsloth#4348

Fixes in moe_utils_fp8.py

  • B7: Support sharded safetensors (multi-shard) in GLM4 FP8 scale patching. Pre-computes needed tensor keys, builds tensor-to-shard lookup, opens all needed shards via a multi-shard reader
  • B8: Add quant_kind parameter to _dequantize_expert_slice and _dequantize_full_expert_weights so weight_scale_inv scales get reciprocal() applied before dequantization
  • B18: Guard fp8_linear import with try/except ImportError and add dequant+F.linear fallback
  • B19: Flatten 3D top_k_index/top_k_weights alongside hidden_states to prevent shape mismatch in expert mask
  • B23: Fix _make_grouped_mm_rhs_column_major -- was a no-op double transpose (transpose(-2,-1).contiguous().transpose(-2,-1)), now correctly uses weight.mT.contiguous()
  • B25: Add act_fn fallback to F.silu when module attribute is missing
  • B26: Remove dead _forward_native_moe_loop_fp8 function (never called)
  • Fix use_separated_lora to call _should_use_separated_lora() instead of hardcoded True (respects UNSLOTH_MOE_LORA_MERGED=1)

Fixes in moe_utils.py

  • B10: Add 3D input reshape at start/end of forward_native_moe_loop
  • B12: Prefer generic backend over unconditional FP8 in get_forward_moe_backend; use local forward_moe_backend (which auto-detects FP8) as final fallback
  • Remove no-op try/except RuntimeError as e: raise e
  • Remove duplicate comment

Test plan

  • Syntax check all modified files
  • Verify imports work correctly
  • Run FP8 MoE model (GLM4) to verify scale patching works with sharded checkpoints
  • Run standard non-FP8 MoE model to verify no regression from B12 backend priority change
  • Verify UNSLOTH_MOE_LORA_MERGED=1 is respected in scaled_grouped_mm path

Datta0 and others added 8 commits March 15, 2026 13:49
- B7: Support sharded safetensors (multi-shard) in GLM4 FP8 scale patching
- B8: Add quant_kind param to _dequantize_expert_slice for weight_scale_inv
  reciprocal handling
- B18: Guard fp8_linear import with try/except and dequant fallback
- B19: Flatten 3D top_k_index/top_k_weights alongside hidden_states
- B23: Fix _make_grouped_mm_rhs_column_major (was no-op double transpose,
  now weight.mT.contiguous())
- B25: Add act_fn fallback to F.silu when attribute missing
- B26: Remove dead _forward_native_moe_loop_fp8 function
- B10: Add 3D input reshape in forward_native_moe_loop
- B12: Prefer generic backend over unconditional FP8 in
  get_forward_moe_backend; use forward_moe_backend as final fallback
- Fix use_separated_lora to respect _should_use_separated_lora() instead
  of hardcoded True
- Remove no-op try/except RuntimeError and duplicate comment
@chatgpt-codex-connector

Copy link
Copy Markdown

You have reached your Codex usage limits for code reviews. You can see your limits in the Codex usage dashboard.

@gemini-code-assist

Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refines and expands the FP8 Mixture-of-Experts (MoE) support within the Unsloth framework. It addresses several critical correctness, safety, and compatibility issues identified in previous FP8 MoE implementations. The changes introduce robust mechanisms for handling sharded FP8 weights, improve dequantization accuracy, enhance backend selection logic for MoE operations, and ensure proper data shape management, leading to a more stable and efficient FP8 MoE system.

Highlights

  • Enhanced FP8 MoE Support: Introduced comprehensive FP8 support for Mixture-of-Experts (MoE) models, including handling sharded safetensors for GLM4 FP8 scale patching.
  • Improved Dequantization Logic: Added a 'quant_kind' parameter to dequantization functions to correctly apply reciprocal scaling for 'weight_scale_inv'.
  • Robust FP8 Linear Fallback: Implemented a 'try/except' guard for 'fp8_linear' import, providing a dequantization and 'F.linear' fallback for broader compatibility.
  • Corrected MoE Input Handling: Ensured proper flattening of 3D 'top_k_index' and 'top_k_weights' alongside 'hidden_states' to prevent shape mismatches in expert masks.
  • Fixed Grouped MM Transpose: Corrected '_make_grouped_mm_rhs_column_major' to use 'weight.mT.contiguous()' for accurate column-major transformation, resolving a no-op double transpose.
  • Flexible Activation Function: Added 'F.silu' as a fallback for 'act_fn' when the module attribute is missing, improving robustness.
  • Optimized MoE Backend Selection: Prioritized the generic MoE backend over unconditional FP8, allowing the generic backend to auto-detect and dispatch to FP8 if applicable, and removed a dead FP8 function.
  • Refined LoRA Separation Logic: Ensured 'use_separated_lora' correctly calls '_should_use_separated_lora()' to respect the 'UNSLOTH_MOE_LORA_MERGED=1' setting.
  • Added 3D Input Reshaping: Implemented necessary 3D input reshaping at the start and end of 'forward_native_moe_loop' for consistent processing.
Changelog
  • unsloth_zoo/patching_utils.py
    • Imported 'maybe_patch_stacked_moe_expert_fp8_scales'.
    • Called 'maybe_patch_stacked_moe_expert_fp8_scales' on the model.
  • unsloth_zoo/temporary_patches/glm4_moe.py
    • Updated 'patch_function' calls for 'Glm4MoeLiteNaiveMoe' and 'Glm4MoeLiteMoE' to include 'force = True'.
    • Added comments explaining the need for 'force = True' due to recent transformer changes.
  • unsloth_zoo/temporary_patches/misc.py
    • Modified 'patch_transformers_masks' to ensure the '_compile' flag is explicitly set to 'False' for 'create_block_mask' when Unsloth compiles 'create_causal_mask'.
    • Added patching for 'torch.nn.attention.flex_attention.create_block_mask'.
  • unsloth_zoo/temporary_patches/moe_utils.py
    • Added logic to install 'moe_utils_fp8.py' to the cache if it exists.
    • Introduced '_CACHED_MOE_UTILS_FP8_MODULE' and '_load_cached_moe_utils_fp8_module' for FP8 module caching.
    • Updated module names for cached imports from 'unsloth_cached_moe_utils' to 'unsloth_zoo.temporary_patches._cached_moe_utils'.
    • Set 'module.package' for dynamically loaded modules.
    • Modified 'get_forward_moe_backend' to prefer the generic backend and then check for the FP8 backend as a fallback.
    • Added '_try_attach_block_size' helper function.
    • Added '_get_base_weight_and_quant_state' and '_get_moe_weight_and_quant_state' for retrieving weight and quantization state.
    • Added '_get_grouped_lora' and '_apply_grouped_lora' for handling grouped LoRA.
    • Added '_expand_grouped_bias' for expanding grouped biases.
    • Removed extensive comments and simplified logic within '_patched_param_wrapper_forward'.
    • Wrapped routing calculations in 'forward_native_grouped_mm' with 'torch.no_grad()'.
    • Removed a no-op 'try/except RuntimeError' block in 'forward_native_grouped_mm'.
    • Removed a duplicate comment 'Grouped GEMM 2: down projection' in 'forward_triton_grouped_gemm'.
    • Implemented 3D input reshaping and un-reshaping at the beginning and end of 'forward_native_moe_loop'.
  • unsloth_zoo/temporary_patches/moe_utils_fp8.py
    • Added a new file implementing comprehensive FP8 Mixture-of-Experts support.
    • Included functions for GLM4 FP8 scale patching with sharded safetensors ('_maybe_patch_glm4_stacked_moe_fp8_scales', 'maybe_patch_stacked_moe_expert_fp8_scales').
    • Defined utility functions like '_is_float8_tensor', '_get_fp8_dequant_target_dtype', '_log_moe_fp8_backend_once', '_check_torch_scaled_grouped_mm_supported'.
    • Implemented FP8 dequantization logic ('_slice_fp8_quant_state', '_ceil_div', '_dequantize_expert_slice', '_dequantize_full_expert_weights').
    • Provided weight preparation for grouped MM ('_make_grouped_mm_rhs_column_major', '_get_moe_weight_and_quant_info', '_extract_scaled_grouped_mm_weight_scale', '_prepare_scaled_grouped_mm_weight').
    • Included manual FP8 row-wise quantization ('_manual_fp8_rowwise_quantize').
    • Introduced '_forward_scaled_grouped_mm_fp8' for fast FP8 grouped MM on supported hardware.
    • Added '_moe_uses_fp8_expert_weights' to detect FP8 usage.
    • Provided '_call_with_temporary_moe_weights' for temporary weight swapping.
    • Implemented '_slice_fp8_linear_quant_state' and '_forward_native_fp8_expert_loop' as a fallback for per-expert FP8 linear operations.
    • Defined 'forward_moe_backend_fp8' as the main entry point for FP8 MoE, orchestrating different backend paths.
  • unsloth_zoo/temporary_patches/qwen3_moe.py
    • Captured 'input_dtype' at the start of 'sparse_moe_block_forward'.
    • Ensured 'router_input' is cast to 'gate_weight.dtype' before being passed to the gate.
    • Casted 'routing_weights' to 'router_input.dtype'.
    • Casted 'final_hidden_states' back to 'input_dtype' before returning.
Activity
  • The pull request was created to fix multiple bugs identified during the review of FP8 MoE PRs (#4312 + [MoE] FP8 support for MoE, specifically GLM 4.7 flash #548).
  • The author provided a detailed summary of fixes in 'moe_utils_fp8.py' and 'moe_utils.py'.
  • A test plan was outlined, including syntax checks, import verification, running FP8 MoE models (GLM4) with sharded checkpoints, verifying no regression for non-FP8 MoE, and checking 'UNSLOTH_MOE_LORA_MERGED=1' behavior.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a comprehensive set of fixes and improvements for FP8 MoE support, enhancing correctness, safety, and compatibility. Key changes include robust handling of sharded safetensors, correct dequantization logic, safer patching mechanisms, and improved backend selection. The code is well-structured and the fixes align with the PR's objectives. I've identified one critical issue in the dequantization logic that could lead to silent correctness problems, which I've detailed in a review comment. Addressing this will make the PR solid.

Comment on lines +328 to +401
def _dequantize_expert_slice(
expert_weight: torch.Tensor,
expert_quant_state,
target_dtype: torch.dtype,
quant_kind=None,
) -> Optional[torch.Tensor]:
"""Dequantize one expert's FP8 weight to target_dtype using pure PyTorch."""
from .moe_utils import _try_attach_block_size

if expert_weight.dtype != torch.float8_e4m3fn:
return expert_weight.to(target_dtype)

if expert_quant_state is None:
return expert_weight.to(target_dtype)

s = expert_quant_state
if not isinstance(s, torch.Tensor):
return expert_weight.to(target_dtype)

if quant_kind == "weight_scale_inv":
s = s.reciprocal()

w = expert_weight.to(target_dtype)

# Per-tensor scale
if s.numel() == 1:
return w * s.view(1, 1).to(target_dtype)

# Reshape 1D to column vector for per-row handling
if s.ndim == 1:
s = s.view(-1, 1)

# Per-row scale: (m, 1)
if s.ndim == 2 and s.shape[1] == 1:
# Per-sub-projection scalar scales (e.g. 2 scales for gate+up stacked weight).
if (
s.shape[0] > 1
and s.shape[0] < w.shape[0]
and w.shape[0] % s.shape[0] == 0
):
repeat_factor = w.shape[0] // s.shape[0]
s = s.repeat_interleave(repeat_factor, dim=0)

if w.shape[0] == s.shape[0]:
return w * s.to(target_dtype)
elif w.shape[1] == s.shape[0]:
return (w.t() * s.to(target_dtype)).t()
return w * s.to(target_dtype)

# Block scale: (ceil(m/bm), ceil(n/bn)) — expand to weight shape
if s.ndim == 2:
block_size = getattr(expert_weight, "block_size", None) or getattr(s, "block_size", None)
M, N = w.shape
p, q = s.shape

if block_size is not None and len(block_size) == 2:
bm, bn = block_size
# Check if scale is transposed
if _ceil_div(M, bm) != p or _ceil_div(N, bn) != q:
if _ceil_div(M, bm) == q and _ceil_div(N, bn) == p:
s = s.T.contiguous()
p, q = s.shape
else:
return expert_weight.to(target_dtype)
else:
# Infer block size from scale grid
bm = _ceil_div(M, p)
bn = _ceil_div(N, q)

s_expanded = s.to(target_dtype).repeat_interleave(bm, dim=0)[:M].repeat_interleave(bn, dim=1)[:, :N]
return w * s_expanded

return expert_weight.to(target_dtype)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The _dequantize_expert_slice function has multiple code paths that can lead to silent failures. When dequantization is not possible (e.g., missing quantization state, mismatched dimensions), the function currently returns expert_weight.to(target_dtype). This provides a tensor with the correct dtype but with incorrect, still-quantized values, which will cause silent correctness issues in downstream computations.

To prevent this, the function should return None in these failure scenarios. The calling function, _dequantize_full_expert_weights, is already designed to handle a None return by propagating the failure, ensuring that the issue is not ignored.

The specific failure paths that should be updated are:

  • Line 340: expert_quant_state is None.
  • Line 344: expert_quant_state is not a torch.Tensor.
  • Line 391: Block quantization scale dimensions do not match.
  • Line 400: The final fallback return statement.
def _dequantize_expert_slice(
    expert_weight: torch.Tensor,
    expert_quant_state,
    target_dtype: torch.dtype,
    quant_kind=None,
) -> Optional[torch.Tensor]:
    """Dequantize one expert's FP8 weight to target_dtype using pure PyTorch."""
    from .moe_utils import _try_attach_block_size

    if expert_weight.dtype != torch.float8_e4m3fn:
        return expert_weight.to(target_dtype)

    if expert_quant_state is None:
        return None

    s = expert_quant_state
    if not isinstance(s, torch.Tensor):
        return None

    if quant_kind == "weight_scale_inv":
        s = s.reciprocal()

    w = expert_weight.to(target_dtype)

    # Per-tensor scale
    if s.numel() == 1:
        return w * s.view(1, 1).to(target_dtype)

    # Reshape 1D to column vector for per-row handling
    if s.ndim == 1:
        s = s.view(-1, 1)

    # Per-row scale: (m, 1)
    if s.ndim == 2 and s.shape[1] == 1:
        # Per-sub-projection scalar scales (e.g. 2 scales for gate+up stacked weight).
        if (
            s.shape[0] > 1
            and s.shape[0] < w.shape[0]
            and w.shape[0] % s.shape[0] == 0
        ):
            repeat_factor = w.shape[0] // s.shape[0]
            s = s.repeat_interleave(repeat_factor, dim=0)

        if w.shape[0] == s.shape[0]:
            return w * s.to(target_dtype)
        elif w.shape[1] == s.shape[0]:
            return (w.t() * s.to(target_dtype)).t()
        return w * s.to(target_dtype)

    # Block scale: (ceil(m/bm), ceil(n/bn)) — expand to weight shape
    if s.ndim == 2:
        block_size = getattr(expert_weight, "block_size", None) or getattr(s, "block_size", None)
        M, N = w.shape
        p, q = s.shape

        if block_size is not None and len(block_size) == 2:
            bm, bn = block_size
            # Check if scale is transposed
            if _ceil_div(M, bm) != p or _ceil_div(N, bn) != q:
                if _ceil_div(M, bm) == q and _ceil_div(N, bn) == p:
                    s = s.T.contiguous()
                    p, q = s.shape
                else:
                    return None
        else:
            # Infer block size from scale grid
            bm = _ceil_div(M, p)
            bn = _ceil_div(N, q)

        s_expanded = s.to(target_dtype).repeat_interleave(bm, dim=0)[:M].repeat_interleave(bn, dim=1)[:, :N]
        return w * s_expanded

    return None


_CACHED_FORWARD_MOE_BACKEND = None
_CACHED_MOE_UTILS_MODULE = None
_CACHED_MOE_UTILS_FP8_MODULE = None
module_name = "unsloth_zoo.temporary_patches._cached_moe_utils_fp8"
module = sys.modules.get(module_name, None)
if module is not None and os.path.abspath(getattr(module, "__file__", "")) == cache_file:
_CACHED_MOE_UTILS_FP8_MODULE = module
module.__package__ = "unsloth_zoo.temporary_patches"
sys.modules[module_name] = module
spec.loader.exec_module(module)
_CACHED_MOE_UTILS_FP8_MODULE = module
return _TORCH_SCALED_GROUPED_MM_SUPPORTED

if not _TORCH_SCALED_GROUPED_MM_AVAILABLE:
_TORCH_SCALED_GROUPED_MM_SUPPORTED = False
_TORCH_SCALED_GROUPED_MM_SUPPORTED = False
return False
if not torch.cuda.is_available():
_TORCH_SCALED_GROUPED_MM_SUPPORTED = False
# FP8 scaled_grouped_mm path to Hopper (SM 9.x) only for now.
major, _minor = torch.cuda.get_device_capability(torch.cuda.current_device())
if major != 9:
_TORCH_SCALED_GROUPED_MM_SUPPORTED = False
if _ceil_div(M, bm) != p or _ceil_div(N, bn) != q:
if _ceil_div(M, bm) == q and _ceil_div(N, bn) == p:
s = s.T.contiguous()
p, q = s.shape
if _ceil_div(M, bm) != p or _ceil_div(N, bn) != q:
if _ceil_div(M, bm) == q and _ceil_div(N, bn) == p:
s = s.T.contiguous()
p, q = s.shape
return
try:
tensor_like.block_size = block_size
except (AttributeError, RuntimeError):
return forward_moe_backend_fp8(
self, hidden_states, top_k_index, top_k_weights
)
except ImportError:
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants