[tests] Review tests for PR #5269 #72

Open
danielhanchen wants to merge 6 commits into main from pr-5269-tests

@danielhanchen

Automated test files from review process

Datta0 and others added 6 commits May 4, 2026 09:46
…stores

Two sibling generation paths put the model into inference mode and then
unconditionally restored training with the for_training default, which
re-enabled gradient checkpointing even when the caller had it disabled:

- unsloth/models/rl.py: unsloth_unwrap_model_for_generation, installed
  onto every TRL *_trainer module that exposes unwrap_model_for_generation.
- unsloth/models/llama.py: unsloth_fast_generate, bound onto model.generate.

Snapshot the active gradient_checkpointing state from the model modules
before for_inference clears it, then thread the snapshot through the
matching for_training call. This mirrors the one-line restore semantics
already used by prepare_for_training_mode and the GRPO replacement in
rl_replacements.py.

The for_training(...) call on each line is preserved; only the kwarg is
added. The pre-existing post-generate guards (the conditional restore in
unsloth_fast_generate and the finally restore in
unsloth_unwrap_model_for_generation) continue to run unchanged.
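
The snapshot-and-restore idea described in this commit can be sketched with toy objects. This is a minimal illustration only: ToyModule, ToyModel, snapshot_gc, and generate_with_restore are hypothetical names, and the for_inference / for_training functions below are simplified stand-ins that borrow the names from the commit message, not the Unsloth implementations.

```python
class ToyModule:
    """Hypothetical stand-in for a model submodule carrying a GC flag."""

    def __init__(self, gradient_checkpointing=False):
        self.gradient_checkpointing = gradient_checkpointing


class ToyModel:
    """Hypothetical model exposing modules(), similar in spirit to nn.Module."""

    def __init__(self, gradient_checkpointing):
        self._modules = [ToyModule(gradient_checkpointing) for _ in range(3)]

    def modules(self):
        return iter(self._modules)


def for_inference(model):
    # Stand-in: inference mode clears gradient checkpointing on every module.
    for m in model.modules():
        m.gradient_checkpointing = False


def for_training(model, use_gradient_checkpointing=True):
    # Stand-in: training mode applies the requested setting; the default enables it.
    for m in model.modules():
        m.gradient_checkpointing = use_gradient_checkpointing


def snapshot_gc(model):
    # Read the currently active state from the modules themselves.
    return any(getattr(m, "gradient_checkpointing", False) for m in model.modules())


def generate_with_restore(model):
    saved = snapshot_gc(model)  # taken before for_inference clears the flag
    for_inference(model)
    try:
        return "generated"  # placeholder for the real generation call
    finally:
        # Thread the snapshot through instead of relying on the default.
        for_training(model, use_gradient_checkpointing=saved)
```

With this sketch, a model that had gradient checkpointing disabled stays disabled after generate_with_restore, whereas calling for_training(model) with its default would switch it back on.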
…n restores

Two follow-ups to the post-generate gradient_checkpointing restore:

1. unsloth/models/rl.py: TRL's _unwrap_model_for_generation calls
   unwrapped_model.gradient_checkpointing_disable() before yielding
   (trl/models/utils.py:124-127 in 0.22.2, 0.27.1, and 1.3.0). The
   previous snapshot was taken inside the with-block and therefore read
   the post-disable state, restoring for_training with
   use_gradient_checkpointing=False even when the caller had it on. Move
   the snapshot above the with-block so it observes the caller's
   pre-disable configuration.

2. unsloth/models/{rl.py,llama.py}: any(getattr(m, "gradient_checkpointing"))
   collapses Unsloth's smart-GC mode value "unsloth" (a documented loader
   default at unsloth/models/_utils.py:212, unsloth/models/llama.py:2824/3314,
   and loader.py:248/854) into a plain True. After generation, the
   restore would silently downgrade "unsloth" smart GC to standard HF GC.
   Replace any() with a value-preserving next((v for ... if v), False) so
   the actual mode value survives the round-trip.
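
The ordering problem in item 1 can be reproduced with a toy context manager. This is a sketch under stated assumptions: toy_unwrap is a hypothetical stand-in that only imitates the "disable before yielding" behavior described above, and ToyModel, snapshot_inside, and snapshot_before are illustrative names, not TRL or Unsloth code.

```python
from contextlib import contextmanager


class ToyModel:
    """Hypothetical model holding a single gradient_checkpointing value."""

    def __init__(self, gradient_checkpointing):
        self.gradient_checkpointing = gradient_checkpointing


@contextmanager
def toy_unwrap(model):
    # Imitates a context manager that disables GC before handing the model out.
    model.gradient_checkpointing = False
    yield model


def snapshot_inside(model):
    # Snapshot taken inside the with-block: it reads the post-disable state.
    with toy_unwrap(model) as unwrapped:
        return unwrapped.gradient_checkpointing


def snapshot_before(model):
    # Snapshot taken above the with-block: it reads the caller's configuration.
    saved = model.gradient_checkpointing
    with toy_unwrap(model):
        pass  # generation would happen here
    return saved
```

In this toy setup, snapshot_inside always observes False regardless of what the caller configured, while snapshot_before returns the caller's original value.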

The for_training(...) calls on each line are preserved; only the snapshot
expression and its position change. The pre-existing post-generate restore
guards continue to run unchanged.
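
The difference between the two snapshot expressions from item 2 can be shown with plain Python objects. This is an illustrative sketch: the SimpleNamespace objects stand in for model modules, and the helper names collapsed_snapshot and preserved_snapshot are hypothetical.

```python
from types import SimpleNamespace


def collapsed_snapshot(modules):
    # any() reduces every truthy value, including "unsloth", to a plain True.
    return any(getattr(m, "gradient_checkpointing", False) for m in modules)


def preserved_snapshot(modules):
    # next() returns the first truthy value itself, or False if none is set.
    values = (getattr(m, "gradient_checkpointing", False) for m in modules)
    return next((v for v in values if v), False)


# Toy modules: one without the attribute, one disabled, one in "unsloth" mode.
toy_modules = [
    SimpleNamespace(),
    SimpleNamespace(gradient_checkpointing=False),
    SimpleNamespace(gradient_checkpointing="unsloth"),
]
```

Here collapsed_snapshot(toy_modules) yields True, while preserved_snapshot(toy_modules) yields the string "unsloth", so a later restore can receive the original mode value.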