
Fix model training state restoration in GRPO trainer #3754

Merged
danielhanchen merged 7 commits into unslothai:main from numb3r33:main
Jan 5, 2026

Conversation

@numb3r33
Contributor

numb3r33 commented Dec 20, 2025

Fixes #3744

@danielhanchen @shimmyshimmer

I did some debugging and found that during evaluation, Trainer.evaluate() correctly calls model.eval() to set model.training=False. However, Unsloth's patched _generate_and_score_completions function calls self.model.for_training() (rl_replacements.py#L267, L273) but never restores the original mode afterward. This leaves model.training=True even during evaluation.

When TRL's log() method runs, it checks self.model.training to determine which metrics bucket to use. Since model.training=True, eval metrics get logged to the "train" bucket instead of "eval".
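
For reference, the bucket selection behaves roughly like the sketch below (simplified and illustrative; not the exact TRL implementation):

# Minimal sketch of the mode-dependent metric bucketing described above
# (illustrative only; names and details differ from TRL's actual log()).
def pick_metrics_bucket(model_training: bool, metrics: dict) -> dict:
    mode = "train" if model_training else "eval"
    bucket = {k: sum(v) / len(v) for k, v in metrics[mode].items()}
    if mode == "eval":
        bucket = {f"eval_{k}": v for k, v in bucket.items()}
    return bucket

metrics = {"train": {"reward": [0.1, 0.2]}, "eval": {"reward": [0.9]}}
print(pick_metrics_bucket(False, metrics))  # {'eval_reward': 0.9}
# With model.training stuck at True during evaluation, the "train" bucket is
# used instead, so eval numbers never show up under the eval panel.
print(pick_metrics_bucket(True, metrics))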

Proposed Fix:

1. Save the original training state (_was_training = self.model.training) before calling for_training() (around L260).
2. Restore it at the end of _generate_and_score_completions, before return output.

Example:

Modify the existing replacement_lines to save the mode:

replacement_lines = """
    batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
    _was_training = self.model.training
    try:
        # TRL 0.23.1 and below path
        if not has_images:
            # Left pad prompt before calculating old and ref hidden states
            prompt_completion_ids = left_pack_padding(prompt_completion_ids, self.processing_class.pad_token_id)
        self.model.for_training()
    except:
        # TRL 0.24.0 and above path
        if images is None:
            # Left pad prompt before calculating old and ref hidden states
            prompt_completion_ids = left_pack_padding(prompt_completion_ids, self.processing_class.pad_token_id)
        self.model.for_training()"""

Add this new replacement near the end of the function (before the return output statement):

function = function.replace(
    "        return output",  # 8 spaces before 'return'
    """        if not _was_training:
            self.model.for_inference()
        return output"""
)

I've tested this fix locally and confirmed eval metrics now appear correctly in wandb's evaluation panel. Would you like me to open a PR?
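
As a quick way to reproduce the check locally, something like the following can be used (hypothetical snippet; assumes `trainer` is an already-configured Unsloth-patched GRPOTrainer with an eval dataset):

# Hypothetical sanity check (trainer/setup assumed, not part of this PR).
trainer.model.eval()
assert trainer.model.training is False

eval_metrics = trainer.evaluate()

# Before the fix, _generate_and_score_completions left model.training True,
# so eval numbers were routed into the "train" bucket; with the fix the mode
# is restored and the eval/* keys appear where expected.
assert trainer.model.training is False, "eval left the model in training mode"
print(eval_metrics)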

[Screenshot: wandb evaluation metrics panel]

@Datta0 I have raised a PR and am sharing a Colab notebook to validate the results. Attaching a screenshot of the wandb panel and a link to the workspace. Please take a look and let me know if any more info or changes are required.

[Screenshot: wandb panel from the Colab run]

Here is the screenshot of the trainer after replacement:

[Screenshot: generated UnslothGRPOTrainer.py code after the replacement]

@gemini-code-assist
Contributor

Summary of Changes

Hello @numb3r33, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves an issue within the GRPO trainer where the model's training state was not consistently restored after certain operations. The changes ensure that the model's original training or inference mode is preserved and correctly reinstated, preventing potential side effects on subsequent training or evaluation steps and maintaining predictable model behavior.

Highlights

  • Model State Preservation: Introduced a mechanism to capture the model's initial training state (self.model.training) before executing the _generate_and_score_completions function.
  • Inference Mode Restoration: Ensured that if the model was not initially in training mode, it is explicitly set back to inference mode (self.model.for_inference()) after the completion generation and scoring process, preventing unintended state changes.


Contributor

gemini-code-assist bot left a comment


Code Review

This pull request addresses a bug in the GRPO trainer where the model's training state was not correctly restored. The changes involve saving the model's training state before it's modified and then restoring it before the function returns. This ensures the model is in the correct mode (training or inference) after completions are generated and scored. The fix is logical and well-implemented. I have one suggestion to improve the maintainability of the code patching by using a more robust method.

Comment on lines 392 to 397
function = function.replace(
    "        return output",
    """        if not _was_training:
            self.model.for_inference()
        return output""",
)
Contributor


medium

Using function.replace with a hardcoded indentation for return output is brittle. If the indentation of this line changes in a future version of trl, this patch will silently fail. It's more robust to use re.sub to capture the existing indentation and apply it to the replacement code. This approach is more maintainable and consistent with other patching logic in this file.

Suggested change
function = function.replace(
    "        return output",
    """        if not _was_training:
            self.model.for_inference()
        return output""",
)
restoration_code = r'''if not _was_training:
    self.model.for_inference()
return output'''
function = re.sub(
    r'^(\s*)return output$',
    lambda m: textwrap.indent(restoration_code, m.group(1)),
    function,
    flags=re.MULTILINE,
)


chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 392 to 396
function = function.replace(
    "        return output",
    """        if not _was_training:
            self.model.for_inference()
        return output""",


P1: Guard restoring inference when _was_training is absent

The new replacement injects a if not _was_training: self.model.for_inference() block but it is applied unconditionally to every _generate_and_score_completions body. _was_training is only introduced when the earlier line_to_replace substitution succeeds (lines ~260-274); if the upstream TRL implementation changes and that exact batch_size line is not matched, the return statement will still be rewritten to reference _was_training, leading to a NameError when the trainer exits instead of simply skipping the restoration. Consider checking that the batch_size replacement succeeded or guarding the use of _was_training so future TRL versions don’t crash.
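
One way to add such a guard (an illustrative sketch, not necessarily how the maintainers would structure it) is to only rewrite the return statement when the earlier substitution actually introduced _was_training:

import re

def patch_return_statement(function: str) -> str:
    """Rewrite 'return output' to restore inference mode first, but only if
    the earlier replacement introduced `_was_training` (illustrative sketch,
    not the merged implementation)."""
    if "_was_training = self.model.training" not in function:
        return function  # earlier substitution did not run; leave untouched
    match = re.search(r"^(\s*)return output", function, re.MULTILINE)
    if not match:
        return function
    indent = match.group(1)
    new_code = (
        indent + "if not _was_training:\n"
        + indent + "    self.model.for_inference()\n"
        + indent + "return output"
    )
    return function.replace(f"{indent}return output", new_code, 1)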



function = function.replace(
    "        return output",
    """        if not _was_training:
Collaborator

@Datta0 Dec 23, 2025


Is the indentation here appropriate? Can you please show the final generated code in UnslothGRPOTrainer.py? A screenshot would perhaps help; add it to the PR description.

Also, please add your investigation and the summary from the other comment to the PR description.

Contributor Author


Thanks @Datta0, I have updated the PR description with additional information and a screenshot of the generated code. Please have a look and let me know if any more info is needed. Thanks again for reviewing.

Collaborator


Yeah, thanks for that.
Also, can we dynamically infer the indent length instead of hardcoding it? If you can do that, it'd be great.

Contributor Author


Sure @Datta0, that makes sense. I have modified the logic to use re:

match = re.search(r"^(\s*)return output", function, re.MULTILINE)

if match:
    indent = match.group(1)
    new_code = indent + "if not _was_training:\n" + indent + "    self.model.for_inference()\n" + indent + "return output"
    function = function.replace(f"{indent}return output", new_code)

Please have a look.
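
To show how the indent detection behaves, here is a tiny standalone example run against a stand-in source string (not the real TRL function):

import re

# Stand-in source, only to demonstrate the indent-preserving patch above.
function = (
    "def _generate_and_score_completions(self, inputs):\n"
    "    _was_training = self.model.training\n"
    "    output = compute(inputs)\n"
    "    return output\n"
)

match = re.search(r"^(\s*)return output", function, re.MULTILINE)
if match:
    indent = match.group(1)
    new_code = (
        indent + "if not _was_training:\n"
        + indent + "    self.model.for_inference()\n"
        + indent + "return output"
    )
    function = function.replace(f"{indent}return output", new_code)

# The printed function now restores inference mode before returning, with the
# original four-space indentation preserved.
print(function)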

numb3r303 and others added 6 commits December 24, 2025 01:34
  • Store the model's training state before generation and restore inference mode after completion if the model wasn't originally in training mode. This ensures the model returns to the correct state after generate and score operations.
  • Use regex to dynamically detect and preserve the original indentation when replacing the 'return output' statement, instead of hardcoding spaces. This ensures the patched code maintains consistent indentation regardless of the original formatting.
  • Replace the f-string triple-quoted approach with explicit newline characters for clearer string construction in the grpo_trainer patch.
@Datta0
Collaborator

Datta0 commented Dec 24, 2025

/gemini review

Contributor

gemini-code-assist bot left a comment


Code Review

This pull request addresses a bug in the GRPO trainer where the model's training state was not correctly restored after evaluation, leading to incorrect metric logging. The fix involves saving the model's training state at the beginning of _generate_and_score_completions and restoring it before the function returns. The implementation is correct and effectively solves the issue. I've added one suggestion to improve the robustness of the code patching by using re.sub instead of str.replace to prevent potential unintended side effects.

    + indent
    + "return output"
)
function = function.replace(f"{indent}return output", new_code)
Contributor


medium

Using str.replace() can be risky here as it will replace all occurrences of f"{indent}return output". If this string appears multiple times in the function's source code, it could lead to unintended modifications. It's safer to use re.sub() with count=1 to ensure only the first match (which is the final return statement you're targeting) is replaced.

Suggested change
function = function.replace(f"{indent}return output", new_code)
function = re.sub(r"^\s*return output", new_code, function, count=1, flags=re.MULTILINE)
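
A lighter-weight alternative with the same single-substitution guarantee, if one prefers to stay with str.replace, would be its optional count argument (a sketch; not what was merged):

# str.replace's third argument caps the number of substitutions, so only the
# first matching "return output" is rewritten.
function = function.replace(f"{indent}return output", new_code, 1)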

Collaborator

@Datta0 left a comment


LGTM

@numb3r33
Contributor Author

numb3r33 commented Jan 5, 2026

@Datta0 Happy new year! I wanted to check if this is good to merge, and whether anything is required from my side.

@Datta0
Collaborator

Datta0 commented Jan 5, 2026

Hey @numb3r303 I think this is good. Just waiting for @danielhanchen to get some time to merge :)
Might happen in a day or two

danielhanchen added a commit that referenced this pull request Jan 5, 2026
GRPO: restore model mode after generate (stacked on #3754)
danielhanchen merged commit ef3e2b3 into unslothai:main Jan 5, 2026
1 check passed


Development

Successfully merging this pull request may close these issues.

[Bug] Evaluation metrics logged to training in wandb
