Fix transformers 5.x drift: FP8Linear constructor signature and Gemma3 double-BOS strip under padding #766
danielhanchen wants to merge 1 commit into
…rocessor patch

Two independent breakages surfaced by running the GRPO notebooks on transformers 5.11:

1. convert_vllm_to_huggingface still passed dtype= to FP8Linear, whose 5.x signature dropped it (weight dtype is forced to the FP8 dtype internally). Replace the version branch with signature filtering so 4.x (bias/dtype/device) and 5.x (has_bias) both get exactly the kwargs they accept, and future drift degrades gracefully.
2. patch_Gemma3Processor stripped the double BOS after the tokenizer padded. Under left padding only the longest row still starts with [bos, bos], so that row alone got shorter, re-ragging the batch (ValueError: expected sequence of length N at dim 1) and leaving attention_mask out of sync. Tokenize unpadded, strip with attention_mask kept in line, then pad once via tokenizer.pad with only the kwargs its installed signature accepts.

Verified: FP8Linear constructs through the filter on 4.57.6 and 5.11; Gemma3 4B vision GRPO goes from failing in 3 min to passing 4 GRPO steps in 12 min; both FP8 GRPO notebooks pass with this on top of the size-node fix from #695.
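The signature filtering described in item 1 can be sketched roughly as follows. This is a minimal illustration, not the code from the PR: filter_kwargs_for is a hypothetical helper, and FakeFP8Linear4x / FakeFP8Linear5x are stand-in classes that only mimic the two constructor layouts named in this PR; they are not the transformers classes.

```python
import inspect


def filter_kwargs_for(callable_obj, candidate_kwargs):
    """Return only the kwargs that callable_obj's signature accepts (hypothetical helper)."""
    params = inspect.signature(callable_obj).parameters
    # A **kwargs parameter accepts any name, so pass the superset through.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(candidate_kwargs)
    return {k: v for k, v in candidate_kwargs.items() if k in params}


# Stand-ins that only mimic the 4.x and 5.x constructor layouts (assumption).
class FakeFP8Linear4x:
    def __init__(self, in_features, out_features, bias=False, dtype=None,
                 block_size=None, device=None, activation_scheme="dynamic"):
        self.has_bias = bias


class FakeFP8Linear5x:
    def __init__(self, in_features, out_features, block_size=None,
                 activation_scheme="dynamic", scale_fmt=None, has_bias=False):
        self.has_bias = has_bias


# Superset of every keyword either layout might want.
superset = dict(in_features=0, out_features=0, bias=True, has_bias=True,
                dtype="bfloat16", block_size=(128, 128),
                activation_scheme="dynamic", device="cpu")

# Each class receives only the names its own __init__ declares.
layers = [cls(**filter_kwargs_for(cls.__init__, superset))
          for cls in (FakeFP8Linear4x, FakeFP8Linear5x)]
```

Both stand-ins construct from the same superset dict, which is the property the PR relies on to avoid a version branch.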
Code Review
This pull request refactors tokenization padding logic in gemma.py to prevent attention mask desynchronization when stripping double BOS tokens, and updates vllm_utils.py to dynamically filter initialization arguments for FP8Linear based on its signature. Feedback on these changes highlights two issues: first, popping padding_side from text_kwargs causes it to be ignored during tokenizer.pad, which can be resolved by temporarily setting self.tokenizer.padding_side; second, passing both bias and has_bias (along with obsolete arguments like dtype or device) to FP8Linear can cause a TypeError if **kwargs is present, which should be resolved by dynamically selecting parameters based on the active transformers version and signature.
    if pad_kwargs.get("padding", False) not in (False, None, "do_not_pad"):
        pad_params = inspect.signature(self.tokenizer.pad).parameters
        text_inputs = self.tokenizer.pad(text_inputs, **{k: v for k, v in pad_kwargs.items() if k in pad_params})
padding_side is popped from text_kwargs into pad_kwargs, but the signature filter then drops it before the tokenizer.pad call because padding_side is not a parameter of the pad method. Any user-specified padding_side is therefore silently ignored, and padding falls back to the tokenizer's default side. To fix this, temporarily set self.tokenizer.padding_side to the requested value during the padding operation and restore it afterward.
    - if pad_kwargs.get("padding", False) not in (False, None, "do_not_pad"):
    -     pad_params = inspect.signature(self.tokenizer.pad).parameters
    -     text_inputs = self.tokenizer.pad(text_inputs, **{k: v for k, v in pad_kwargs.items() if k in pad_params})
    + if pad_kwargs.get("padding", False) not in (False, None, "do_not_pad"):
    +     pad_params = inspect.signature(self.tokenizer.pad).parameters
    +     old_padding_side = self.tokenizer.padding_side
    +     if "padding_side" in pad_kwargs:
    +         self.tokenizer.padding_side = pad_kwargs["padding_side"]
    +     try:
    +         text_inputs = self.tokenizer.pad(text_inputs, **{k: v for k, v in pad_kwargs.items() if k in pad_params})
    +     finally:
    +         self.tokenizer.padding_side = old_padding_side
    fp8_kwargs = dict(in_features=0, out_features=0, bias=has_bias, has_bias=has_bias, dtype=dtype, block_size=kwargs['block_size'], activation_scheme=kwargs['activation_scheme'], device=get_target_device())
    fp8_params = inspect.signature(FP8Linear.__init__).parameters
    if not any(p.kind is p.VAR_KEYWORD for p in fp8_params.values()):
        fp8_kwargs = {k: v for k, v in fp8_kwargs.items() if k in fp8_params}
If FP8Linear.__init__ has a **kwargs parameter (i.e., VAR_KEYWORD), passing both bias and has_bias simultaneously can cause a TypeError when they are forwarded to super().__init__() (e.g., nn.Linear or nn.Module), as these base classes do not accept both arguments. Similarly, passing obsolete arguments like dtype or device on transformers 5.x can lead to failures. It is safer to dynamically select only the correct parameters based on the signature of FP8Linear.__init__ and the active transformers version.
    - fp8_kwargs = dict(in_features=0, out_features=0, bias=has_bias, has_bias=has_bias, dtype=dtype, block_size=kwargs['block_size'], activation_scheme=kwargs['activation_scheme'], device=get_target_device())
    - fp8_params = inspect.signature(FP8Linear.__init__).parameters
    - if not any(p.kind is p.VAR_KEYWORD for p in fp8_params.values()):
    -     fp8_kwargs = {k: v for k, v in fp8_kwargs.items() if k in fp8_params}
    + fp8_params = inspect.signature(FP8Linear.__init__).parameters
    + has_var_keyword = any(p.kind is p.VAR_KEYWORD for p in fp8_params.values())
    + fp8_kwargs = dict(
    +     in_features=0,
    +     out_features=0,
    +     block_size=kwargs['block_size'],
    +     activation_scheme=kwargs['activation_scheme'],
    + )
    + if "has_bias" in fp8_params or (has_var_keyword and Version("transformers") >= Version("5.0.0")):
    +     fp8_kwargs["has_bias"] = has_bias
    + else:
    +     fp8_kwargs["bias"] = has_bias
    + if "dtype" in fp8_params or (has_var_keyword and Version("transformers") < Version("5.0.0")):
    +     fp8_kwargs["dtype"] = dtype
    + if "device" in fp8_params or (has_var_keyword and Version("transformers") < Version("5.0.0")):
    +     fp8_kwargs["device"] = get_target_device()
    + if not has_var_keyword:
    +     fp8_kwargs = {k: v for k, v in fp8_kwargs.items() if k in fp8_params}
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: eba86b826b
    pad_kwargs = {k: text_kwargs.pop(k) for k in ("padding", "max_length", "pad_to_multiple_of", "padding_side") if k in text_kwargs}
    text_inputs = self.tokenizer(text=text, **text_kwargs)
Preserve max_length during tokenization
When callers pass max_length for truncation, this comprehension removes it before self.tokenizer(...) runs and only forwards it later to tokenizer.pad. Padding does not truncate already-tokenized sequences, and when padding is disabled the value is dropped entirely, so Gemma3 prompts that previously obeyed truncation=True, max_length=N can now exceed the requested/model length and break batching or run past context limits.
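One possible way to address this concern, offered only as a sketch: leave max_length with the tokenizer call whenever truncation is requested, and forward a copy to the pad step. split_text_kwargs and PAD_ONLY_KEYS are hypothetical names, not code from this PR, and the truncation check is a simplifying assumption about how callers pass that flag.

```python
# Keys that only affect padding and can safely be deferred to the pad step.
PAD_ONLY_KEYS = ("padding", "pad_to_multiple_of", "padding_side")


def split_text_kwargs(text_kwargs):
    """Split processor kwargs into (tokenize_kwargs, pad_kwargs).

    Hypothetical helper: max_length stays with tokenization when truncation
    is requested, and a copy always goes to the pad step.
    """
    tokenize_kwargs = dict(text_kwargs)  # do not mutate the caller's dict
    pad_kwargs = {k: tokenize_kwargs.pop(k) for k in PAD_ONLY_KEYS if k in tokenize_kwargs}
    if "max_length" in tokenize_kwargs:
        pad_kwargs["max_length"] = tokenize_kwargs["max_length"]
        if not tokenize_kwargs.get("truncation", False):
            # Without truncation, max_length is only meaningful for padding.
            tokenize_kwargs.pop("max_length")
    return tokenize_kwargs, pad_kwargs


# truncation=True keeps max_length on the tokenize side as well.
tok_kwargs, pad_kwargs = split_text_kwargs(
    {"padding": True, "truncation": True, "max_length": 8, "padding_side": "left"}
)
# padding="max_length" without truncation only needs it on the pad side.
tok_kwargs_2, pad_kwargs_2 = split_text_kwargs({"padding": "max_length", "max_length": 8})
```

With this split, a truncation=True, max_length=N call still truncates during tokenization, while padding="max_length" still has the value available at the pad step.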
Re-validated on current main via an integration branch and a sim battery in isolated uv venvs (transformers 4.57.6 and 5.x).

FP8Linear signature filter (the unique part of this PR): 16/16 simulation cases pass. The before-repro confirms the bug: the pre-fix positional/keyword construction raises a TypeError.

Note on the Gemma3 double-BOS hunk: the maintainer's open zoo#695 independently fixes the same bug with a different approach (tokenize-with-padding then re-pad trimmed rows, vs this PR's tokenize-unpadded-strip-pad-once). Both are correct at the token-id level (verified across padding x side x batch combinations, 56/56 sim cases). Since #695 is likely to land first, the Gemma3 hunk here is redundant with it; happy to drop it and keep this PR to just the FP8Linear fix if that is cleaner for review.
Two transformers 5.x drift fixes, both backwards compatible with 4.57.6.
FP8Linear constructor signature drift

transformers 4.57.6 builds FP8Linear(in_features, out_features, bias, dtype, block_size, device, activation_scheme). transformers 5.x changed the signature to (in_features, out_features, block_size, activation_scheme, scale_fmt, has_bias): bias, dtype, device are gone and has_bias arrived. vllm_utils.py passed the 4.x layout positionally, so FP8 models crashed on 5.x with a TypeError.

The fix builds a superset kwargs dict (bias, has_bias, dtype, device, block_size, activation_scheme) and filters it against inspect.signature(FP8Linear.__init__) before calling, with a VAR_KEYWORD guard so a future **kwargs signature receives everything. Works unchanged on 4.57.6, 5.x, and signatures in between.

Gemma3 processor double-BOS under batched padding

The existing Gemma3 patch stripped a duplicated BOS by checking input_ids[:2] after tokenization with padding. Under left padding only the longest row starts with [bos, bos], so the strip re-ragged the batch and desynced attention_mask.

The fix tokenizes unpadded, strips the duplicate BOS per row while keeping the attention mask in sync, then pads once via tokenizer.pad with the pad kwargs filtered against its signature (transformers 5 changed which kwargs pad accepts).

Validation
Unit tests for both paths pass on transformers 4.57.6 and 5.11 (same scripts, both versions).
End to end inside the unsloth/unsloth Docker validation image (B200, torch 2.10 cu128, vllm 0.19.1, transformers 5.11), notebooks from unslothai/notebooks run with these fixes applied:
Both fixes are runtime monkey patches scoped to the existing patch sites; no behavior change when the installed transformers matches the old signatures.
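The ordering issue behind the Gemma3 fix can be illustrated with plain Python lists. This is a toy sketch, not the patched processor: BOS and PAD are made-up token ids, and strip_double_bos and left_pad are hypothetical helpers standing in for the strip step and for tokenizer.pad.

```python
BOS, PAD = 2, 0  # illustrative token ids, not read from a real tokenizer


def strip_double_bos(input_ids, attention_mask):
    """Drop one leading BOS from a row, keeping the mask the same length."""
    if input_ids[:2] == [BOS, BOS]:
        return input_ids[1:], attention_mask[1:]
    return input_ids, attention_mask


def left_pad(rows, masks):
    """Left-pad rows to the longest length; the mask is 0 over padding."""
    width = max(len(r) for r in rows)
    padded_rows = [[PAD] * (width - len(r)) + r for r in rows]
    padded_masks = [[0] * (width - len(m)) + m for m in masks]
    return padded_rows, padded_masks


# Two unpadded rows that both start with a duplicated BOS.
rows = [[BOS, BOS, 11, 12, 13], [BOS, BOS, 21]]
masks = [[1] * len(r) for r in rows]

# Strip first, pad once: every row loses exactly one token before padding.
stripped = [strip_double_bos(r, m) for r, m in zip(rows, masks)]
input_ids, attention_mask = left_pad([s[0] for s in stripped], [s[1] for s in stripped])

# Pad first, strip second: only the longest row still begins with [BOS, BOS],
# so it alone gets shorter and the batch is no longer rectangular.
padded_first, _ = left_pad(rows, masks)
ragged = [strip_double_bos(r, [1] * len(r))[0] for r in padded_first]
```

In the strip-then-pad order both rows end up the same length with masks aligned to the ids, while in the pad-then-strip order the rows come out with different lengths, which is the re-ragging described above.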