[tests] Review tests for PR #5269 #72
Open
danielhanchen wants to merge 6 commits into
…stores
Two sibling generation paths put the model into inference mode and then
unconditionally restored training with the for_training default, which
re-enabled gradient checkpointing even when the caller had it disabled:
- unsloth/models/rl.py: unsloth_unwrap_model_for_generation, installed onto
every TRL *_trainer module that exposes unwrap_model_for_generation.
- unsloth/models/llama.py: unsloth_fast_generate, bound onto model.generate.
Snapshot the active gradient_checkpointing state from the model modules
before for_inference clears it, then thread the snapshot through the
matching for_training call. Same one-line restore semantics already used by
prepare_for_training_mode and the GRPO replacement at rl_replacements.py.
The for_training(...) call on each line is preserved; only the kwarg is
added. The pre-existing post-generate guards (the conditional restore in
unsloth_fast_generate and the finally restore in
unsloth_unwrap_model_for_generation) continue to run unchanged.
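The snapshot-then-restore pattern described above can be sketched with toy stand-ins. This is a minimal illustration, not Unsloth's code: ToyModule, ToyModel, generate_with_restore, and the bodies of for_inference and for_training are hypothetical, and the only behaviour borrowed from the commit message is that for_inference clears the flag and for_training defaults to turning it back on.

```python
class ToyModule:
    """Hypothetical stand-in for a submodule carrying a gradient_checkpointing flag."""

    def __init__(self, gradient_checkpointing=False):
        self.gradient_checkpointing = gradient_checkpointing


class ToyModel:
    """Hypothetical stand-in for a model; modules() yields its submodules."""

    def __init__(self, layers):
        self.layers = layers
        self.training = True

    def modules(self):
        return iter(self.layers)


def for_inference(model):
    # Assumed behaviour: switching to inference clears the flag on every module.
    model.training = False
    for m in model.modules():
        m.gradient_checkpointing = False


def for_training(model, use_gradient_checkpointing=True):
    # Assumed behaviour: the default re-enables gradient checkpointing.
    model.training = True
    for m in model.modules():
        m.gradient_checkpointing = use_gradient_checkpointing


def generate_with_restore(model):
    # Snapshot the caller's setting BEFORE for_inference wipes it.
    snapshot = any(
        getattr(m, "gradient_checkpointing", False) for m in model.modules()
    )
    for_inference(model)
    try:
        output = "generated"  # placeholder for the real generate call
    finally:
        # Thread the snapshot through instead of relying on the default.
        for_training(model, use_gradient_checkpointing=snapshot)
    return output
```

With this shape, a caller that had gradient checkpointing disabled gets it back disabled after generation, rather than silently enabled by the default.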
…n restores
Two follow-ups to the post-generate gradient_checkpointing restore:
1. unsloth/models/rl.py: TRL's _unwrap_model_for_generation calls
unwrapped_model.gradient_checkpointing_disable() before yielding
(trl/models/utils.py:124-127 in 0.22.2, 0.27.1, and 1.3.0). The
previous snapshot was taken inside the with-block and therefore read
the post-disable state, restoring for_training with
use_gradient_checkpointing=False even when the caller had it on. Move
the snapshot above the with-block so it observes the caller's
pre-disable configuration.
2. unsloth/models/{rl.py,llama.py}: any(getattr(m, "gradient_checkpointing"))
collapses Unsloth's smart-GC mode value "unsloth" (a documented loader
default at unsloth/models/_utils.py:212 and unsloth/models/llama.py:2824/3314,
loader.py:248/854) into a plain True. After generation, the
restore would silently downgrade "unsloth" smart GC to standard HF GC.
Replace any() with a value-preserving next((v for ... if v), False) so
the actual mode value survives the round-trip.
The for_training(...) calls on each line are preserved; only the snapshot
expression and its position change. The pre-existing post-generate restore
guards continue to run unchanged.
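Both follow-ups can be reproduced in isolation. The sketch below is illustrative only: ToyModule, snapshot_gc, and toy_unwrap_for_generation are hypothetical names, and the context manager merely imitates the disable-before-yield behaviour that the message attributes to TRL's _unwrap_model_for_generation.

```python
from contextlib import contextmanager


class ToyModule:
    """Hypothetical stand-in for a submodule carrying a gradient_checkpointing flag."""

    def __init__(self, gradient_checkpointing=False):
        self.gradient_checkpointing = gradient_checkpointing


def snapshot_gc(modules):
    # Return the first truthy flag value itself (e.g. "unsloth"), not a bool.
    values = (getattr(m, "gradient_checkpointing", False) for m in modules)
    return next((v for v in values if v), False)


@contextmanager
def toy_unwrap_for_generation(modules):
    # Imitates the described behaviour: checkpointing is disabled before yielding.
    for m in modules:
        m.gradient_checkpointing = False
    yield modules


modules = [ToyModule(False), ToyModule("unsloth")]

# Follow-up 2: any() collapses the mode string into a plain True.
collapsed = any(m.gradient_checkpointing for m in modules)
preserved = snapshot_gc(modules)

# Follow-up 1: a snapshot taken inside the with-block sees the post-disable state.
early = snapshot_gc(modules)
with toy_unwrap_for_generation(modules) as unwrapped:
    late = snapshot_gc(unwrapped)
```

Here collapsed is True while preserved and early are both "unsloth", and late is False, which is why the snapshot has to be taken above the with-block and has to keep the value rather than its truthiness.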
324e30e to 89cedfe
e128c6f to 1555c15
Automated test files from review process