Fix model training state restoration in GRPO trainer #3754
danielhanchen merged 7 commits into unslothai:main
Conversation
Summary of Changes

Hello @numb3r33, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request resolves an issue within the GRPO trainer where the model's training state was not consistently restored after certain operations. The changes ensure that the model's original training or inference mode is preserved and correctly reinstated, preventing potential side effects on subsequent training or evaluation steps and maintaining predictable model behavior.
Code Review
This pull request addresses a bug in the GRPO trainer where the model's training state was not correctly restored. The changes involve saving the model's training state before it's modified and then restoring it before the function returns. This ensures the model is in the correct mode (training or inference) after completions are generated and scored. The fix is logical and well-implemented. I have one suggestion to improve the maintainability of the code patching by using a more robust method.
unsloth/models/rl_replacements.py
Outdated
```python
function = function.replace(
    "        return output",
    """        if not _was_training:
            self.model.for_inference()
        return output""",
)
```
Using function.replace with a hardcoded indentation for return output is brittle. If the indentation of this line changes in a future version of trl, this patch will silently fail. It's more robust to use re.sub to capture the existing indentation and apply it to the replacement code. This approach is more maintainable and consistent with other patching logic in this file.
Suggested change:

```diff
-function = function.replace(
-    "        return output",
-    """        if not _was_training:
-            self.model.for_inference()
-        return output""",
-)
+restoration_code = r'''if not _was_training:
+    self.model.for_inference()
+return output'''
+function = re.sub(
+    r'^(\s*)return output$',
+    lambda m: textwrap.indent(restoration_code, m.group(1)),
+    function,
+    flags=re.MULTILINE,
+)
```
💡 Codex Review
Here are some automated review suggestions for this pull request.
unsloth/models/rl_replacements.py
Outdated
```python
function = function.replace(
    "        return output",
    """        if not _was_training:
            self.model.for_inference()
        return output""",
```
Guard restoring inference when _was_training is absent
The new replacement injects an if not _was_training: self.model.for_inference() block, but it is applied unconditionally to every _generate_and_score_completions body. _was_training is only introduced when the earlier line_to_replace substitution succeeds (lines ~260-274); if the upstream TRL implementation changes and that exact batch_size line is not matched, the return statement will still be rewritten to reference _was_training, leading to a NameError when the trainer exits instead of simply skipping the restoration. Consider checking that the batch_size replacement succeeded, or guarding the use of _was_training, so future TRL versions don't crash.
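For illustration, a minimal sketch of one way to add that guard (the helper name is hypothetical, function is assumed to hold the source string being patched, and the exact code in rl_replacements.py may differ):

```python
import re

def inject_mode_restore(function: str) -> str:
    # Skip the rewrite entirely unless the earlier substitution that
    # introduced _was_training actually landed in the source.
    if "_was_training = self.model.training" not in function:
        return function
    match = re.search(r"^(\s*)return output", function, re.MULTILINE)
    if match is None:
        return function
    indent = match.group(1)
    new_code = (
        indent + "if not _was_training:\n"
        + indent + "    self.model.for_inference()\n"
        + indent + "return output"
    )
    # Replace only the first (final-return) occurrence.
    return function.replace(f"{indent}return output", new_code, 1)
```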
unsloth/models/rl_replacements.py
Outdated
```python
function = function.replace(
    "        return output",
    """        if not _was_training:
```
Is the indentation here appropriate? Can you please show the final generated code in UnslothGRPOTrainer.py? (A screenshot would perhaps help; add it to the PR description.)
Also, please add your investigation and the summary from the other comment to the PR description.
Thanks @Datta0, I have updated the PR description with additional information and a screenshot of the generated code. Please have a look and let me know if any more info is needed. Thanks again for reviewing.
Yeah, thanks for that.
Also, can we dynamically infer the indent length instead of hardcoding it? If you can do that, it'd be great.
Sure @Datta0, that makes sense. I have modified the logic to use re:
```python
match = re.search(r"^(\s*)return output", function, re.MULTILINE)
if match:
    indent = match.group(1)
    new_code = (
        indent + "if not _was_training:\n"
        + indent + "    self.model.for_inference()\n"
        + indent + "return output"
    )
    function = function.replace(f"{indent}return output", new_code)
```
Please have a look.
Store the model's training state before generation and restore inference mode after completion if the model wasn't originally in training mode. This ensures the model returns to the correct state after generate and score operations.
Use regex to dynamically detect and preserve the original indentation when replacing the 'return output' statement, instead of hardcoding spaces. This ensures the patched code maintains consistent indentation regardless of the original formatting.
Replace f-string triple-quoted approach with explicit newline characters for clearer string construction in the grpo_trainer patch.
/gemini review
Code Review
This pull request addresses a bug in the GRPO trainer where the model's training state was not correctly restored after evaluation, leading to incorrect metric logging. The fix involves saving the model's training state at the beginning of _generate_and_score_completions and restoring it before the function returns. The implementation is correct and effectively solves the issue. I've added one suggestion to improve the robustness of the code patching by using re.sub instead of str.replace to prevent potential unintended side effects.
```python
    + indent
    + "return output"
)
function = function.replace(f"{indent}return output", new_code)
```
Using str.replace() can be risky here as it will replace all occurrences of f"{indent}return output". If this string appears multiple times in the function's source code, it could lead to unintended modifications. It's safer to use re.sub() with count=1 to ensure only the first match (which is the final return statement you're targeting) is replaced.
Suggested change:

```diff
-function = function.replace(f"{indent}return output", new_code)
+function = re.sub(r"^\s*return output", new_code, function, count=1, flags=re.MULTILINE)
```
@Datta0 happy new year! I want to check if this is good to merge, and is there anything required from my side?
Hey @numb3r33, I think this is good. Just waiting for @danielhanchen to get some time to merge :)
GRPO: restore model mode after generate (stacked on #3754)
Fixes #3744
@danielhanchen @shimmyshimmer
I did some debugging and found that during evaluation, Trainer.evaluate() correctly calls model.eval() to set model.training=False. However, Unsloth's patched _generate_and_score_completions function calls self.model.for_training() (rl_replacements.py#L267, L273) but never restores the original mode afterward. This leaves model.training=True even during evaluation.

When TRL's log() method runs, it checks self.model.training to determine which metrics bucket to use. Since model.training=True, eval metrics get logged to the "train" bucket instead of "eval".
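As a toy illustration of that behaviour (paraphrased; not TRL's actual log() implementation):

```python
# Paraphrased illustration only; the real TRL log() method is more involved.
def metrics_bucket(model_training: bool) -> str:
    return "train" if model_training else "eval"

# Stale flag left by for_training(): eval metrics land in the "train" bucket.
assert metrics_bucket(True) == "train"
# Restored flag: metrics land in "eval" as expected.
assert metrics_bucket(False) == "eval"
```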
Proposed Fix:
- Save the original training state (_was_training = self.model.training) before calling for_training() (around L260).
- Restore it at the end of _generate_and_score_completions before return output.

Example:
- Modify the existing replacement_lines to save the mode.
- Add this new replacement near the end of the function (before return function), as sketched below:
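For illustration, a minimal, self-contained sketch of both steps (the toy function string, do_generation, and the exact replacement text are assumptions made for this example; the actual patch in rl_replacements.py operates on TRL's real source and may differ):

```python
import re

# Hypothetical, much-shortened stand-in for the TRL source that gets patched.
function = '''\
    def _generate_and_score_completions(self, inputs):
        self.model.for_training()
        output = do_generation(self, inputs)
        return output
'''

# 1. Save the original mode just before for_training() is called.
function = function.replace(
    "self.model.for_training()",
    "_was_training = self.model.training\n        self.model.for_training()",
)

# 2. Restore inference mode before `return output`, reusing the detected indentation.
match = re.search(r"^(\s*)return output", function, re.MULTILINE)
if match:
    indent = match.group(1)
    new_code = (
        indent + "if not _was_training:\n"
        + indent + "    self.model.for_inference()\n"
        + indent + "return output"
    )
    function = function.replace(f"{indent}return output", new_code, 1)

print(function)
```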
I've tested this fix locally and confirmed eval metrics now appear correctly in wandb's evaluation panel. Would you like me to open a PR?
@Datta0 I have raised a PR and am sharing a Colab notebook to validate the results. Attaching a screenshot of the wandb panel and a link to the workspace. Please take a look and let me know if any more info or changes are required.
Here is the screenshot of the trainer after replacement
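(The screenshot itself is not reproduced here; reconstructed from the patch logic above, the tail of the generated function is expected to look roughly like the following, not copied verbatim from UnslothGRPOTrainer.py.)

```python
        # Expected tail of the patched _generate_and_score_completions:
        if not _was_training:
            self.model.for_inference()
        return output
```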