Fix bugs in FP8 MoE support#554
Conversation
- B7: Support sharded safetensors (multi-shard) in GLM4 FP8 scale patching - B8: Add quant_kind param to _dequantize_expert_slice for weight_scale_inv reciprocal handling - B18: Guard fp8_linear import with try/except and dequant fallback - B19: Flatten 3D top_k_index/top_k_weights alongside hidden_states - B23: Fix _make_grouped_mm_rhs_column_major (was no-op double transpose, now weight.mT.contiguous()) - B25: Add act_fn fallback to F.silu when attribute missing - B26: Remove dead _forward_native_moe_loop_fp8 function - B10: Add 3D input reshape in forward_native_moe_loop - B12: Prefer generic backend over unconditional FP8 in get_forward_moe_backend; use forward_moe_backend as final fallback - Fix use_separated_lora to respect _should_use_separated_lora() instead of hardcoded True - Remove no-op try/except RuntimeError and duplicate comment
|
You have reached your Codex usage limits for code reviews. You can see your limits in the Codex usage dashboard. |
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly refines and expands the FP8 Mixture-of-Experts (MoE) support within the Unsloth framework. It addresses several critical correctness, safety, and compatibility issues identified in previous FP8 MoE implementations. The changes introduce robust mechanisms for handling sharded FP8 weights, improve dequantization accuracy, enhance backend selection logic for MoE operations, and ensure proper data shape management, leading to a more stable and efficient FP8 MoE system. Highlights
Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request introduces a comprehensive set of fixes and improvements for FP8 MoE support, enhancing correctness, safety, and compatibility. Key changes include robust handling of sharded safetensors, correct dequantization logic, safer patching mechanisms, and improved backend selection. The code is well-structured and the fixes align with the PR's objectives. I've identified one critical issue in the dequantization logic that could lead to silent correctness problems, which I've detailed in a review comment. Addressing this will make the PR solid.
| def _dequantize_expert_slice( | ||
| expert_weight: torch.Tensor, | ||
| expert_quant_state, | ||
| target_dtype: torch.dtype, | ||
| quant_kind=None, | ||
| ) -> Optional[torch.Tensor]: | ||
| """Dequantize one expert's FP8 weight to target_dtype using pure PyTorch.""" | ||
| from .moe_utils import _try_attach_block_size | ||
|
|
||
| if expert_weight.dtype != torch.float8_e4m3fn: | ||
| return expert_weight.to(target_dtype) | ||
|
|
||
| if expert_quant_state is None: | ||
| return expert_weight.to(target_dtype) | ||
|
|
||
| s = expert_quant_state | ||
| if not isinstance(s, torch.Tensor): | ||
| return expert_weight.to(target_dtype) | ||
|
|
||
| if quant_kind == "weight_scale_inv": | ||
| s = s.reciprocal() | ||
|
|
||
| w = expert_weight.to(target_dtype) | ||
|
|
||
| # Per-tensor scale | ||
| if s.numel() == 1: | ||
| return w * s.view(1, 1).to(target_dtype) | ||
|
|
||
| # Reshape 1D to column vector for per-row handling | ||
| if s.ndim == 1: | ||
| s = s.view(-1, 1) | ||
|
|
||
| # Per-row scale: (m, 1) | ||
| if s.ndim == 2 and s.shape[1] == 1: | ||
| # Per-sub-projection scalar scales (e.g. 2 scales for gate+up stacked weight). | ||
| if ( | ||
| s.shape[0] > 1 | ||
| and s.shape[0] < w.shape[0] | ||
| and w.shape[0] % s.shape[0] == 0 | ||
| ): | ||
| repeat_factor = w.shape[0] // s.shape[0] | ||
| s = s.repeat_interleave(repeat_factor, dim=0) | ||
|
|
||
| if w.shape[0] == s.shape[0]: | ||
| return w * s.to(target_dtype) | ||
| elif w.shape[1] == s.shape[0]: | ||
| return (w.t() * s.to(target_dtype)).t() | ||
| return w * s.to(target_dtype) | ||
|
|
||
| # Block scale: (ceil(m/bm), ceil(n/bn)) — expand to weight shape | ||
| if s.ndim == 2: | ||
| block_size = getattr(expert_weight, "block_size", None) or getattr(s, "block_size", None) | ||
| M, N = w.shape | ||
| p, q = s.shape | ||
|
|
||
| if block_size is not None and len(block_size) == 2: | ||
| bm, bn = block_size | ||
| # Check if scale is transposed | ||
| if _ceil_div(M, bm) != p or _ceil_div(N, bn) != q: | ||
| if _ceil_div(M, bm) == q and _ceil_div(N, bn) == p: | ||
| s = s.T.contiguous() | ||
| p, q = s.shape | ||
| else: | ||
| return expert_weight.to(target_dtype) | ||
| else: | ||
| # Infer block size from scale grid | ||
| bm = _ceil_div(M, p) | ||
| bn = _ceil_div(N, q) | ||
|
|
||
| s_expanded = s.to(target_dtype).repeat_interleave(bm, dim=0)[:M].repeat_interleave(bn, dim=1)[:, :N] | ||
| return w * s_expanded | ||
|
|
||
| return expert_weight.to(target_dtype) | ||
|
|
There was a problem hiding this comment.
The _dequantize_expert_slice function has multiple code paths that can lead to silent failures. When dequantization is not possible (e.g., missing quantization state, mismatched dimensions), the function currently returns expert_weight.to(target_dtype). This provides a tensor with the correct dtype but with incorrect, still-quantized values, which will cause silent correctness issues in downstream computations.
To prevent this, the function should return None in these failure scenarios. The calling function, _dequantize_full_expert_weights, is already designed to handle a None return by propagating the failure, ensuring that the issue is not ignored.
The specific failure paths that should be updated are:
- Line 340:
expert_quant_stateisNone. - Line 344:
expert_quant_stateis not atorch.Tensor. - Line 391: Block quantization scale dimensions do not match.
- Line 400: The final fallback
returnstatement.
def _dequantize_expert_slice(
expert_weight: torch.Tensor,
expert_quant_state,
target_dtype: torch.dtype,
quant_kind=None,
) -> Optional[torch.Tensor]:
"""Dequantize one expert's FP8 weight to target_dtype using pure PyTorch."""
from .moe_utils import _try_attach_block_size
if expert_weight.dtype != torch.float8_e4m3fn:
return expert_weight.to(target_dtype)
if expert_quant_state is None:
return None
s = expert_quant_state
if not isinstance(s, torch.Tensor):
return None
if quant_kind == "weight_scale_inv":
s = s.reciprocal()
w = expert_weight.to(target_dtype)
# Per-tensor scale
if s.numel() == 1:
return w * s.view(1, 1).to(target_dtype)
# Reshape 1D to column vector for per-row handling
if s.ndim == 1:
s = s.view(-1, 1)
# Per-row scale: (m, 1)
if s.ndim == 2 and s.shape[1] == 1:
# Per-sub-projection scalar scales (e.g. 2 scales for gate+up stacked weight).
if (
s.shape[0] > 1
and s.shape[0] < w.shape[0]
and w.shape[0] % s.shape[0] == 0
):
repeat_factor = w.shape[0] // s.shape[0]
s = s.repeat_interleave(repeat_factor, dim=0)
if w.shape[0] == s.shape[0]:
return w * s.to(target_dtype)
elif w.shape[1] == s.shape[0]:
return (w.t() * s.to(target_dtype)).t()
return w * s.to(target_dtype)
# Block scale: (ceil(m/bm), ceil(n/bn)) — expand to weight shape
if s.ndim == 2:
block_size = getattr(expert_weight, "block_size", None) or getattr(s, "block_size", None)
M, N = w.shape
p, q = s.shape
if block_size is not None and len(block_size) == 2:
bm, bn = block_size
# Check if scale is transposed
if _ceil_div(M, bm) != p or _ceil_div(N, bn) != q:
if _ceil_div(M, bm) == q and _ceil_div(N, bn) == p:
s = s.T.contiguous()
p, q = s.shape
else:
return None
else:
# Infer block size from scale grid
bm = _ceil_div(M, p)
bn = _ceil_div(N, q)
s_expanded = s.to(target_dtype).repeat_interleave(bm, dim=0)[:M].repeat_interleave(bn, dim=1)[:, :N]
return w * s_expanded
return None|
|
||
| _CACHED_FORWARD_MOE_BACKEND = None | ||
| _CACHED_MOE_UTILS_MODULE = None | ||
| _CACHED_MOE_UTILS_FP8_MODULE = None |
| module_name = "unsloth_zoo.temporary_patches._cached_moe_utils_fp8" | ||
| module = sys.modules.get(module_name, None) | ||
| if module is not None and os.path.abspath(getattr(module, "__file__", "")) == cache_file: | ||
| _CACHED_MOE_UTILS_FP8_MODULE = module |
| module.__package__ = "unsloth_zoo.temporary_patches" | ||
| sys.modules[module_name] = module | ||
| spec.loader.exec_module(module) | ||
| _CACHED_MOE_UTILS_FP8_MODULE = module |
| return _TORCH_SCALED_GROUPED_MM_SUPPORTED | ||
|
|
||
| if not _TORCH_SCALED_GROUPED_MM_AVAILABLE: | ||
| _TORCH_SCALED_GROUPED_MM_SUPPORTED = False |
| _TORCH_SCALED_GROUPED_MM_SUPPORTED = False | ||
| return False | ||
| if not torch.cuda.is_available(): | ||
| _TORCH_SCALED_GROUPED_MM_SUPPORTED = False |
| # FP8 scaled_grouped_mm path to Hopper (SM 9.x) only for now. | ||
| major, _minor = torch.cuda.get_device_capability(torch.cuda.current_device()) | ||
| if major != 9: | ||
| _TORCH_SCALED_GROUPED_MM_SUPPORTED = False |
| if _ceil_div(M, bm) != p or _ceil_div(N, bn) != q: | ||
| if _ceil_div(M, bm) == q and _ceil_div(N, bn) == p: | ||
| s = s.T.contiguous() | ||
| p, q = s.shape |
| if _ceil_div(M, bm) != p or _ceil_div(N, bn) != q: | ||
| if _ceil_div(M, bm) == q and _ceil_div(N, bn) == p: | ||
| s = s.T.contiguous() | ||
| p, q = s.shape |
| return | ||
| try: | ||
| tensor_like.block_size = block_size | ||
| except (AttributeError, RuntimeError): |
| return forward_moe_backend_fp8( | ||
| self, hidden_states, top_k_index, top_k_weights | ||
| ) | ||
| except ImportError: |
Summary
Fixes multiple bugs identified during review of the FP8 MoE PRs (unsloth #4312 + #548). These changes sit on top of the PR branch and address correctness, safety, and compatibility issues.
Companion PR: unslothai/unsloth#4348
Fixes in
moe_utils_fp8.pyquant_kindparameter to_dequantize_expert_sliceand_dequantize_full_expert_weightssoweight_scale_invscales getreciprocal()applied before dequantizationfp8_linearimport withtry/except ImportErrorand add dequant+F.linear fallbacktop_k_index/top_k_weightsalongsidehidden_statesto prevent shape mismatch in expert mask_make_grouped_mm_rhs_column_major-- was a no-op double transpose (transpose(-2,-1).contiguous().transpose(-2,-1)), now correctly usesweight.mT.contiguous()act_fnfallback toF.siluwhen module attribute is missing_forward_native_moe_loop_fp8function (never called)use_separated_lorato call_should_use_separated_lora()instead of hardcodedTrue(respectsUNSLOTH_MOE_LORA_MERGED=1)Fixes in
moe_utils.pyforward_native_moe_loopget_forward_moe_backend; use localforward_moe_backend(which auto-detects FP8) as final fallbacktry/except RuntimeError as e: raise eTest plan
UNSLOTH_MOE_LORA_MERGED=1is respected in scaled_grouped_mm path