Respect GC for GRPO#69
Conversation
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request updates the grpo_trainer__generate_and_score_completions function in unsloth/models/rl_replacements.py to ensure that the use_gradient_checkpointing parameter is correctly passed to the model's training state based on the trainer's configuration. There are no review comments to address, and I have no additional feedback to provide.
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request ensures that gradient checkpointing settings are correctly preserved when switching models between inference and training modes during generation. It implements logic to snapshot the gradient checkpointing state across modules before inference and restores that state when returning to training mode in Llama and RL model implementations. I have no feedback to provide.
…stores Two sibling generation paths put the model into inference mode and then unconditionally restored training with the for_training default, which re-enabled gradient checkpointing even when the caller had it disabled: - unsloth/models/rl.py: unsloth_unwrap_model_for_generation, installed onto every TRL *_trainer module that exposes unwrap_model_for_generation. - unsloth/models/llama.py: unsloth_fast_generate, bound onto model.generate. Snapshot the active gradient_checkpointing state from the model modules before for_inference clears it, then thread the snapshot through the matching for_training call. Same one-line restore semantics already used by prepare_for_training_mode and the GRPO replacement at rl_replacements.py. The for_training(...) call on each line is preserved; only the kwarg is added. The pre-existing post-generate guards (the conditional restore in unsloth_fast_generate and the finally restore in unsloth_unwrap_model_for_generation) continue to run unchanged.
…n restores
Two follow-ups to the post-generate gradient_checkpointing restore:
1. unsloth/models/rl.py: TRL's _unwrap_model_for_generation calls
unwrapped_model.gradient_checkpointing_disable() before yielding
(trl/models/utils.py:124-127 in 0.22.2, 0.27.1, and 1.3.0). The
previous snapshot was taken inside the with-block and therefore read
the post-disable state, restoring for_training with
use_gradient_checkpointing=False even when the caller had it on. Move
the snapshot above the with-block so it observes the caller's
pre-disable configuration.
2. unsloth/models/{rl.py,llama.py}: any(getattr(m, "gradient_checkpointing"))
collapses Unsloth's smart-GC mode value "unsloth" (a documented loader
default at unsloth/models/_utils.py:212 and unsloth/models/llama.py
2824/3314, loader.py:248/854) into a plain True. After generation, the
restore would silently downgrade "unsloth" smart GC to standard HF GC.
Replace any() with a value-preserving next((v for ... if v), False) so
the actual mode value survives the round-trip.
The for_training(...) calls on each line are preserved; only the snapshot
expression and its position change. The pre-existing post-generate restore
guards continue to run unchanged.
324e30e to
89cedfe
Compare
e128c6f to
1555c15
Compare
Staging mirror of unslothai#5269
Original PR: unslothai#5269
Author: Datta0
This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.
Original description