Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions unsloth_zoo/temporary_patches/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,23 @@ def _gemma3_call_impl(
# text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", True)

text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
# Fix double BOS tokens
# Tokenize WITHOUT padding: stripping a double BOS after padding only
# shortens rows that still start with [bos, bos] (under left padding,
# just the longest row), re-ragging the batch and desyncing
# attention_mask. Strip first, then pad once.
text_kwargs = dict(output_kwargs["text_kwargs"])
pad_kwargs = {k: text_kwargs.pop(k) for k in ("padding", "max_length", "pad_to_multiple_of", "padding_side") if k in text_kwargs}
text_inputs = self.tokenizer(text=text, **text_kwargs)
Comment on lines +197 to +198

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve max_length during tokenization

When callers pass max_length for truncation, this comprehension removes it before self.tokenizer(...) runs and only forwards it later to tokenizer.pad. Padding does not truncate already-tokenized sequences, and when padding is disabled the value is dropped entirely, so Gemma3 prompts that previously obeyed truncation=True, max_length=N can now exceed the requested/model length and break batching or run past context limits.

Useful? React with 👍 / 👎.

# Fix double BOS tokens, keeping attention_mask in sync
double_bos_token_id = [self.tokenizer.bos_token_id]*2
input_ids = text_inputs["input_ids"]
text_inputs["input_ids"] = [x[1:] if x[:2] == double_bos_token_id else x for x in input_ids]
stripped = [x[1:] if x[:2] == double_bos_token_id else x for x in input_ids]
if "attention_mask" in text_inputs:
text_inputs["attention_mask"] = [m[1:] if len(k) != len(x) else m for x, k, m in zip(input_ids, stripped, text_inputs["attention_mask"])]
text_inputs["input_ids"] = stripped
if pad_kwargs.get("padding", False) not in (False, None, "do_not_pad"):
pad_params = inspect.signature(self.tokenizer.pad).parameters
text_inputs = self.tokenizer.pad(text_inputs, **{k: v for k, v in pad_kwargs.items() if k in pad_params})
Comment on lines +206 to +208

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

padding_side is popped from text_kwargs into pad_kwargs, but the signature filter applied just before calling tokenizer.pad drops it again whenever padding_side is not a parameter of the pad method. Any user-specified padding_side is then silently ignored, and padding reverts to the tokenizer's default padding side. To fix this, temporarily set self.tokenizer.padding_side to the requested value during the padding operation and restore it afterward.

Suggested change
if pad_kwargs.get("padding", False) not in (False, None, "do_not_pad"):
pad_params = inspect.signature(self.tokenizer.pad).parameters
text_inputs = self.tokenizer.pad(text_inputs, **{k: v for k, v in pad_kwargs.items() if k in pad_params})
if pad_kwargs.get("padding", False) not in (False, None, "do_not_pad"):
pad_params = inspect.signature(self.tokenizer.pad).parameters
old_padding_side = self.tokenizer.padding_side
if "padding_side" in pad_kwargs:
self.tokenizer.padding_side = pad_kwargs["padding_side"]
try:
text_inputs = self.tokenizer.pad(text_inputs, **{k: v for k, v in pad_kwargs.items() if k in pad_params})
finally:
self.tokenizer.padding_side = old_padding_side


# Add token type ids manually, as tokenizer can't do arbitrary position token types
# [TODO] FAILS for batched tokens since text_inputs["input_ids"] is a list of lists, so np.array creates an object!
Expand Down
13 changes: 7 additions & 6 deletions unsloth_zoo/vllm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1426,12 +1426,13 @@ def _override_to(self, *args, **kwargs):
layer.weight.input_scale_ub = kwargs['input_scale_ub']
layer.quant_method = "fbgemm_fp8"
elif fp8_weight_scale.ndim == 2:
# FP8 dynamic quantized. transformers 5.0+ renamed
# bias -> has_bias and removed device.
if Version("transformers") < Version("5.0.0"):
fp8_kwargs = dict(in_features=0, out_features=0, bias=has_bias, dtype=dtype, block_size=kwargs['block_size'], activation_scheme=kwargs['activation_scheme'], device=get_target_device())
else:
fp8_kwargs = dict(in_features=0, out_features=0, has_bias=has_bias, dtype=dtype, block_size=kwargs['block_size'], activation_scheme=kwargs['activation_scheme'])
# FP8 dynamic quantized. FP8Linear's signature drifts across
# transformers versions (4.x: bias/dtype/device; 5.x:
# has_bias, no dtype/device), so keep only accepted kwargs.
fp8_kwargs = dict(in_features=0, out_features=0, bias=has_bias, has_bias=has_bias, dtype=dtype, block_size=kwargs['block_size'], activation_scheme=kwargs['activation_scheme'], device=get_target_device())
fp8_params = inspect.signature(FP8Linear.__init__).parameters
if not any(p.kind is p.VAR_KEYWORD for p in fp8_params.values()):
fp8_kwargs = {k: v for k, v in fp8_kwargs.items() if k in fp8_params}
Comment on lines +1432 to +1435

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If FP8Linear.__init__ has a **kwargs parameter (i.e., VAR_KEYWORD), passing both bias and has_bias simultaneously can cause a TypeError when they are forwarded to super().__init__() (e.g., nn.Linear or nn.Module), as these base classes do not accept both arguments. Similarly, passing obsolete arguments like dtype or device on transformers 5.x can lead to failures. It is safer to dynamically select only the correct parameters based on the signature of FP8Linear.__init__ and the active transformers version.

Suggested change
fp8_kwargs = dict(in_features=0, out_features=0, bias=has_bias, has_bias=has_bias, dtype=dtype, block_size=kwargs['block_size'], activation_scheme=kwargs['activation_scheme'], device=get_target_device())
fp8_params = inspect.signature(FP8Linear.__init__).parameters
if not any(p.kind is p.VAR_KEYWORD for p in fp8_params.values()):
fp8_kwargs = {k: v for k, v in fp8_kwargs.items() if k in fp8_params}
fp8_params = inspect.signature(FP8Linear.__init__).parameters
has_var_keyword = any(p.kind is p.VAR_KEYWORD for p in fp8_params.values())
fp8_kwargs = dict(
in_features=0,
out_features=0,
block_size=kwargs['block_size'],
activation_scheme=kwargs['activation_scheme'],
)
if "has_bias" in fp8_params or (has_var_keyword and Version("transformers") >= Version("5.0.0")):
fp8_kwargs["has_bias"] = has_bias
else:
fp8_kwargs["bias"] = has_bias
if "dtype" in fp8_params or (has_var_keyword and Version("transformers") < Version("5.0.0")):
fp8_kwargs["dtype"] = dtype
if "device" in fp8_params or (has_var_keyword and Version("transformers") < Version("5.0.0")):
fp8_kwargs["device"] = get_target_device()
if not has_var_keyword:
fp8_kwargs = {k: v for k, v in fp8_kwargs.items() if k in fp8_params}

layer = FP8Linear(**fp8_kwargs)
layer.in_features = weight.shape[1]
layer.out_features = weight.shape[0]
Expand Down
Loading