Respect GC for GRPO #69

Open
danielhanchen wants to merge 6 commits into main from pr-5269-head
@danielhanchen

Member

Staging mirror of unslothai#5269

Original PR: unslothai#5269
Author: Datta0

This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.


Original description

@danielhanchen

Member Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request updates the grpo_trainer__generate_and_score_completions function in unsloth/models/rl_replacements.py to ensure that the use_gradient_checkpointing parameter is correctly passed to the model's training state based on the trainer's configuration. There are no review comments to address, and I have no additional feedback to provide.

@danielhanchen

Member Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request ensures that gradient checkpointing settings are correctly preserved when switching models between inference and training modes during generation. It implements logic to snapshot the gradient checkpointing state across modules before inference and restores that state when returning to training mode in Llama and RL model implementations. I have no feedback to provide.

…stores

Two sibling generation paths put the model into inference mode and then
unconditionally restored training with the for_training default, which
re-enabled gradient checkpointing even when the caller had it disabled:

- unsloth/models/rl.py: unsloth_unwrap_model_for_generation, installed
  onto every TRL *_trainer module that exposes unwrap_model_for_generation.
- unsloth/models/llama.py: unsloth_fast_generate, bound onto model.generate.

Snapshot the active gradient_checkpointing state from the model modules
before for_inference clears it, then thread the snapshot through the
matching for_training call. Same one-line restore semantics already used
by prepare_for_training_mode and the GRPO replacement at rl_replacements.py.

The for_training(...) call on each line is preserved; only the kwarg is
added. The pre-existing post-generate guards (the conditional restore in
unsloth_fast_generate and the finally restore in
unsloth_unwrap_model_for_generation) continue to run unchanged.
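
The snapshot-and-restore pattern described in this commit can be sketched roughly as follows. This is a minimal, self-contained illustration and not the code in the PR: `ToyLayer`, `ToyModel`, `snapshot_gc`, `generate_with_restore`, and the toy `for_inference` / `for_training` functions are hypothetical stand-ins, and the toy `for_training` default of `True` only mirrors what the commit message says about the real default.

```python
class ToyLayer:
    """Hypothetical stand-in for a model submodule carrying a GC flag."""

    def __init__(self, gradient_checkpointing=False):
        self.gradient_checkpointing = gradient_checkpointing


class ToyModel:
    """Hypothetical stand-in for a model; only modules() is needed here."""

    def __init__(self, layers):
        self.layers = layers

    def modules(self):
        return iter(self.layers)


def for_inference(model):
    # Toy version: entering inference mode clears GC on every module.
    for m in model.modules():
        m.gradient_checkpointing = False


def for_training(model, use_gradient_checkpointing=True):
    # Toy version: the default re-enables GC, which is the behaviour the
    # commit message identifies as the problem.
    for m in model.modules():
        m.gradient_checkpointing = use_gradient_checkpointing


def snapshot_gc(model):
    # Read the active GC state from the modules before it is cleared.
    return any(getattr(m, "gradient_checkpointing", False) for m in model.modules())


def generate_with_restore(model):
    gc_state = snapshot_gc(model)  # 1. snapshot before for_inference
    for_inference(model)           # 2. generation would run after this call
    for_training(model, use_gradient_checkpointing=gc_state)  # 3. restore


# A caller that had GC disabled keeps it disabled after the round-trip.
model = ToyModel([ToyLayer(False), ToyLayer(False)])
generate_with_restore(model)
gc_after = [m.gradient_checkpointing for m in model.modules()]
```

Without the snapshot, step 3 would fall back to the `for_training` default and turn gradient checkpointing back on for a caller that never asked for it.
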
…n restores

Two follow-ups to the post-generate gradient_checkpointing restore:

1. unsloth/models/rl.py: TRL's _unwrap_model_for_generation calls
   unwrapped_model.gradient_checkpointing_disable() before yielding
   (trl/models/utils.py:124-127 in 0.22.2, 0.27.1, and 1.3.0). The
   previous snapshot was taken inside the with-block and therefore read
   the post-disable state, restoring for_training with
   use_gradient_checkpointing=False even when the caller had it on. Move
   the snapshot above the with-block so it observes the caller's
   pre-disable configuration.

2. unsloth/models/{rl.py,llama.py}: any(getattr(m, "gradient_checkpointing"))
   collapses Unsloth's smart-GC mode value "unsloth" (a documented loader
   default at unsloth/models/_utils.py:212, unsloth/models/llama.py:2824/3314,
   and loader.py:248/854) into a plain True. After generation, the
   restore would silently downgrade "unsloth" smart GC to standard HF GC.
   Replace any() with a value-preserving next((v for ... if v), False) so
   the actual mode value survives the round-trip.
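
The difference between the two expressions in item 2 can be seen with plain Python objects. This is a small sketch under stated assumptions: `ToyLayer` is a hypothetical stand-in for a model module, and the variable names are illustrative only.

```python
class ToyLayer:
    """Hypothetical stand-in for a module with a gradient_checkpointing attribute."""

    def __init__(self, gradient_checkpointing=False):
        self.gradient_checkpointing = gradient_checkpointing


layers = [ToyLayer(False), ToyLayer("unsloth"), ToyLayer("unsloth")]
values = [getattr(m, "gradient_checkpointing", False) for m in layers]

# any() only reports truthiness, so the mode string is reduced to a bool.
collapsed = any(values)

# next() over a filtered generator returns the first truthy value unchanged,
# falling back to False when no module has GC enabled.
preserved = next((v for v in values if v), False)

# With GC off everywhere, the fallback is returned.
no_gc = next((v for v in [False, False] if v), False)
```

Here `collapsed` is the boolean `True`, while `preserved` keeps the string `"unsloth"`, so a later restore can pass the original mode through.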

The for_training(...) calls on each line are preserved; only the snapshot
expression and its position change. The pre-existing post-generate restore
guards continue to run unchanged.
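
The ordering problem from item 1 can be reproduced with a toy context manager. This is a hedged sketch, not TRL's or Unsloth's implementation: `toy_unwrap_for_generation` is a hypothetical stand-in that disables gradient checkpointing before yielding, as the commit message describes for TRL's helper, and `ToyModel` is likewise hypothetical.

```python
from contextlib import contextmanager


class ToyModel:
    """Hypothetical stand-in for a model with a single GC attribute."""

    def __init__(self, gradient_checkpointing):
        self.gradient_checkpointing = gradient_checkpointing


@contextmanager
def toy_unwrap_for_generation(model):
    # Stand-in behaviour: GC is disabled before the body of the with-block runs.
    model.gradient_checkpointing = False
    yield model


def snapshot_inside(model):
    # Snapshot taken inside the with-block: it reads the post-disable state.
    with toy_unwrap_for_generation(model) as unwrapped:
        return unwrapped.gradient_checkpointing


def snapshot_outside(model):
    # Snapshot taken above the with-block: it reads the caller's configuration.
    gc_state = model.gradient_checkpointing
    with toy_unwrap_for_generation(model):
        pass  # generation would happen here
    return gc_state


inside = snapshot_inside(ToyModel("unsloth"))
outside = snapshot_outside(ToyModel("unsloth"))
```

In this sketch `inside` is `False` and `outside` is `"unsloth"`, which is why the snapshot has to be taken before the with-block is entered.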