Skip to content

Fix transformers 5.x drift: FP8Linear constructor signature and Gemma3 double-BOS strip under padding#766

Open
danielhanchen wants to merge 1 commit into
mainfrom
fix-transformers5-fp8-gemma
Open

Fix transformers 5.x drift: FP8Linear constructor signature and Gemma3 double-BOS strip under padding#766
danielhanchen wants to merge 1 commit into
mainfrom
fix-transformers5-fp8-gemma

Conversation

@danielhanchen

Copy link
Copy Markdown
Member

Two transformers 5.x drift fixes, both backwards compatible with 4.57.6.

FP8Linear constructor signature drift

transformers 4.57.6 builds FP8Linear(in_features, out_features, bias, dtype, block_size, device, activation_scheme). transformers 5.x changed the signature to (in_features, out_features, block_size, activation_scheme, scale_fmt, has_bias): bias, dtype, device are gone and has_bias arrived. vllm_utils.py passed the 4.x layout positionally, so FP8 models crashed on 5.x with a TypeError.

The fix builds a superset kwargs dict (bias, has_bias, dtype, device, block_size, activation_scheme) and filters it against inspect.signature(FP8Linear.__init__) before calling, with a VAR_KEYWORD guard so a future **kwargs signature receives everything. Works unchanged on 4.57.6, 5.x, and signatures in between.

Gemma3 processor double-BOS under batched padding

The existing Gemma3 patch stripped a duplicated BOS by checking input_ids[:2] after tokenization with padding. Under left padding only the longest row starts with [bos, bos], so the strip re-ragged the batch and desynced attention_mask.

The fix tokenizes unpadded, strips the duplicate BOS per row while keeping the attention mask in sync, then pads once via tokenizer.pad with the pad kwargs filtered against its signature (transformers 5 changed which kwargs pad accepts).

Validation

Unit tests for both paths pass on transformers 4.57.6 and 5.11 (same scripts, both versions).

End to end inside the unsloth/unsloth Docker validation image (B200, torch 2.10 cu128, vllm 0.19.1, transformers 5.11), notebooks from unslothai/notebooks run with these fixes applied:

Notebook Result
Gemma3_(4B)-Vision-GRPO pass, 13.6 min
Llama_FP8_GRPO pass, 8.6 min
Qwen3_8B_FP8_GRPO pass, 12.1 min
Qwen3_(4B)-GRPO (control, no FP8/Gemma path) pass, 11.9 min

Both fixes are runtime monkey patches scoped to the existing patch sites; no behavior change when the installed transformers matches the old signatures.

…rocessor patch

Two independent breakages surfaced by running the GRPO notebooks on
transformers 5.11:

1. convert_vllm_to_huggingface still passed dtype= to FP8Linear, whose
   5.x signature dropped it (weight dtype is forced to the FP8 dtype
   internally). Replace the version branch with signature filtering so
   4.x (bias/dtype/device) and 5.x (has_bias) both get exactly the
   kwargs they accept, and future drift degrades gracefully.

2. patch_Gemma3Processor stripped the double BOS after the tokenizer
   padded. Under left padding only the longest row still starts with
   [bos, bos], so that row alone got shorter, re-ragging the batch
   (ValueError: expected sequence of length N at dim 1) and leaving
   attention_mask out of sync. Tokenize unpadded, strip with
   attention_mask kept in line, then pad once via tokenizer.pad with
   only the kwargs its installed signature accepts.

Verified: FP8Linear constructs through the filter on 4.57.6 and 5.11;
Gemma3 4B vision GRPO goes from failing in 3 min to passing 4 GRPO
steps in 12 min; both FP8 GRPO notebooks pass with this on top of the
size-node fix from #695.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors tokenization padding logic in gemma.py to prevent attention mask desynchronization when stripping double BOS tokens, and updates vllm_utils.py to dynamically filter initialization arguments for FP8Linear based on its signature. Feedback on these changes highlights two issues: first, popping padding_side from text_kwargs causes it to be ignored during tokenizer.pad, which can be resolved by temporarily setting self.tokenizer.padding_side; second, passing both bias and has_bias (along with obsolete arguments like dtype or device) to FP8Linear can cause a TypeError if **kwargs is present, which should be resolved by dynamically selecting parameters based on the active transformers version and signature.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +206 to +208
if pad_kwargs.get("padding", False) not in (False, None, "do_not_pad"):
pad_params = inspect.signature(self.tokenizer.pad).parameters
text_inputs = self.tokenizer.pad(text_inputs, **{k: v for k, v in pad_kwargs.items() if k in pad_params})

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

When popping padding_side from text_kwargs, it is excluded from pad_kwargs passed to tokenizer.pad because padding_side is not a parameter of the pad method. This causes any user-specified padding_side to be silently ignored, reverting to the tokenizer's default padding side. To fix this, temporarily set self.tokenizer.padding_side to the requested value during the padding operation and restore it afterward.

Suggested change
if pad_kwargs.get("padding", False) not in (False, None, "do_not_pad"):
pad_params = inspect.signature(self.tokenizer.pad).parameters
text_inputs = self.tokenizer.pad(text_inputs, **{k: v for k, v in pad_kwargs.items() if k in pad_params})
if pad_kwargs.get("padding", False) not in (False, None, "do_not_pad"):
pad_params = inspect.signature(self.tokenizer.pad).parameters
old_padding_side = self.tokenizer.padding_side
if "padding_side" in pad_kwargs:
self.tokenizer.padding_side = pad_kwargs["padding_side"]
try:
text_inputs = self.tokenizer.pad(text_inputs, **{k: v for k, v in pad_kwargs.items() if k in pad_params})
finally:
self.tokenizer.padding_side = old_padding_side

Comment thread unsloth_zoo/vllm_utils.py
Comment on lines +1432 to +1435
fp8_kwargs = dict(in_features=0, out_features=0, bias=has_bias, has_bias=has_bias, dtype=dtype, block_size=kwargs['block_size'], activation_scheme=kwargs['activation_scheme'], device=get_target_device())
fp8_params = inspect.signature(FP8Linear.__init__).parameters
if not any(p.kind is p.VAR_KEYWORD for p in fp8_params.values()):
fp8_kwargs = {k: v for k, v in fp8_kwargs.items() if k in fp8_params}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If FP8Linear.__init__ has a **kwargs parameter (i.e., VAR_KEYWORD), passing both bias and has_bias simultaneously can cause a TypeError when they are forwarded to super().__init__() (e.g., nn.Linear or nn.Module), as these base classes do not accept both arguments. Similarly, passing obsolete arguments like dtype or device on transformers 5.x can lead to failures. It is safer to dynamically select only the correct parameters based on the signature of FP8Linear.__init__ and the active transformers version.

Suggested change
fp8_kwargs = dict(in_features=0, out_features=0, bias=has_bias, has_bias=has_bias, dtype=dtype, block_size=kwargs['block_size'], activation_scheme=kwargs['activation_scheme'], device=get_target_device())
fp8_params = inspect.signature(FP8Linear.__init__).parameters
if not any(p.kind is p.VAR_KEYWORD for p in fp8_params.values()):
fp8_kwargs = {k: v for k, v in fp8_kwargs.items() if k in fp8_params}
fp8_params = inspect.signature(FP8Linear.__init__).parameters
has_var_keyword = any(p.kind is p.VAR_KEYWORD for p in fp8_params.values())
fp8_kwargs = dict(
in_features=0,
out_features=0,
block_size=kwargs['block_size'],
activation_scheme=kwargs['activation_scheme'],
)
if "has_bias" in fp8_params or (has_var_keyword and Version("transformers") >= Version("5.0.0")):
fp8_kwargs["has_bias"] = has_bias
else:
fp8_kwargs["bias"] = has_bias
if "dtype" in fp8_params or (has_var_keyword & Version("transformers") < Version("5.0.0")):
fp8_kwargs["dtype"] = dtype
if "device" in fp8_params or (has_var_keyword and Version("transformers") < Version("5.0.0")):
fp8_kwargs["device"] = get_target_device()
if not has_var_keyword:
fp8_kwargs = {k: v for k, v in fp8_kwargs.items() if k in fp8_params}

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: eba86b826b

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +197 to +198
pad_kwargs = {k: text_kwargs.pop(k) for k in ("padding", "max_length", "pad_to_multiple_of", "padding_side") if k in text_kwargs}
text_inputs = self.tokenizer(text=text, **text_kwargs)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve max_length during tokenization

When callers pass max_length for truncation, this comprehension removes it before self.tokenizer(...) runs and only forwards it later to tokenizer.pad. Padding does not truncate already-tokenized sequences, and when padding is disabled the value is dropped entirely, so Gemma3 prompts that previously obeyed truncation=True, max_length=N can now exceed the requested/model length and break batching or run past context limits.

Useful? React with 👍 / 👎.

@danielhanchen

Copy link
Copy Markdown
Member Author

Re-validated on current main via an integration branch and a sim battery in isolated uv venvs (transformers 4.57.6 and 5.x).

FP8Linear signature filter (the unique part of this PR): 16/16 simulation cases pass. The before-repro confirms the bug: the pre-fix positional/keyword construction raises TypeError: unexpected keyword argument 'dtype' against the real transformers 5.x FP8Linear signature; the filter constructs correctly against the 4.57.6 layout, the 5.x layout, a **kwargs signature, a keyword-only signature, and a future-extra-arg signature, with has_bias True and False. End to end, the FP8 GRPO notebooks (Llama_FP8_GRPO, Qwen3_8B_FP8_GRPO) train to completion on a B200 image carrying this fix.

Note on the Gemma3 double-BOS hunk: the maintainer's open zoo#695 independently fixes the same bug with a different approach (tokenize-with-padding then re-pad trimmed rows, vs this PR's tokenize-unpadded-strip-pad-once). Both are correct at the token-id level (verified across padding x side x batch combinations, 56/56 sim cases). Since 695 is likely to land first, the Gemma3 hunk here is redundant with it; happy to drop it and keep this PR to just the FP8Linear fix if that is cleaner for review.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant