Conversation

joecummings (Member) commented on Feb 24, 2025

What?
This expands the use of StatefulDataLoader to the following four recipes:

  • Full distributed
  • LoRA single device
  • LoRA distributed
  • GRPO

What did you change?
I copied essentially the same changes I made in #2410 into the above four recipes (a simplified sketch of the pattern follows below). I modified the tests for the single-device recipes because we are no longer relying on a different state. For the distributed recipes, I did not modify the tests at all. I also had to change the hack we used to ensure the iterator finishes even if we cut the epoch short, because manually modifying the dataloader state dict was not robust and looked ugly.
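
In broad strokes, the propagated change looks roughly like the sketch below. This is a simplified illustration, not the literal recipe code; the function signature and argument plumbing through the configs are assumptions.

```python
from torch.utils.data import Dataset, DistributedSampler
from torchdata.stateful_dataloader import StatefulDataLoader


# Simplified sketch of the propagated change: build a StatefulDataLoader instead of a
# plain DataLoader so the loader's position can be checkpointed and restored.
def setup_data(ds: Dataset, batch_size: int, world_size: int, rank: int):
    sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank, shuffle=True)
    dataloader = StatefulDataLoader(
        ds,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=True,
    )
    return sampler, dataloader


# At checkpoint time the recipe saves dataloader.state_dict(); on resume it calls
# dataloader.load_state_dict(...) so iteration picks up where the previous run left off.
```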

Why these recipes?
These include our most stable recipes plus our newest one, GRPO. By looking at the changes made here, users will be able to easily propagate them to everything else in the library. I will create an issue to track this for the rest of the recipes.

How did you test GRPO?
Good question, because GRPO has no standardized tests in the torchtune library. It's also an interesting case because it does not follow the same format as our other recipes. For instance, there is no option to cut an epoch short based on the number of steps in an individual epoch (a strange variable anyway). However, there is an option to stop after a total number of training steps is reached. But that counter is not included in the saved state dict, so there is no way to reasonably resume training for GRPO without actually implementing step-based checkpointing. That part is a wash for now.
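
For illustration only (none of this is in the PR), here is a minimal sketch of what step-based resume could look like, assuming a StatefulDataLoader and hypothetical key and function names:

```python
from torchdata.stateful_dataloader import StatefulDataLoader


# Hypothetical sketch: persist enough state to resume a step-capped run like GRPO.
def save_training_state(dataloader: StatefulDataLoader, global_step: int) -> dict:
    return {
        "global_step": global_step,                   # total training steps completed so far
        "dataloader_state": dataloader.state_dict(),  # position within the current epoch
    }


def load_training_state(dataloader: StatefulDataLoader, state: dict) -> int:
    dataloader.load_state_dict(state["dataloader_state"])
    # The recipe would then keep training until its configured total number of steps,
    # starting from the restored count instead of zero.
    return state["global_step"]
```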


pytorch-bot bot commented Feb 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2431

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit 634be31 with merge base 7b654ea:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label on Feb 24, 2025
joecummings marked this pull request as ready for review on February 26, 2025 at 16:52
joecummings changed the title from "Add StatefulDataLoader to rest of recipes" to "Add StatefulDataLoader to select other recipes" on Feb 26, 2025
```python
if dataloader_state_dict is not None:
    dataloader.load_state_dict(dataloader_state_dict)
    # Hack to force dataloader to finish last iteration
    list(dataloader)

return sampler, dataloader
```

Thanks for including StatefulDataLoader in these recipes.
I have a couple of comments.

  1. Let's add a more descriptive comment, something like: "Since we break early without completing the last epoch, and we want to start a new epoch when restarting training, we need to yield the remaining batches from the last epoch before restarting." It doesn't need to be this long, but something along these lines, since someone who doesn't work with StatefulDataLoader might get confused here.
  2. Do you think we need a flag (like finish_current_epoch, or some other name) that we set when we break early from an epoch, so that only when that flag is True do we complete the epoch before restarting training? (A rough sketch of this idea follows below.)
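
A rough sketch of that flag idea, with hypothetical names (nothing here is committed in this PR):

```python
from typing import Optional

from torchdata.stateful_dataloader import StatefulDataLoader


# Hypothetical sketch of suggestion 2; finish_current_epoch would be saved by the recipe
# whenever it breaks out of an epoch early (e.g. max_steps_per_epoch was hit).
def restore_dataloader(
    dataloader: StatefulDataLoader,
    dataloader_state_dict: Optional[dict],
    finish_current_epoch: bool,
) -> None:
    if dataloader_state_dict is None:
        return
    dataloader.load_state_dict(dataloader_state_dict)
    if finish_current_epoch:
        # Drain the leftover batches only if the previous run stopped mid-epoch,
        # so the resumed run starts at a clean epoch boundary.
        list(dataloader)
```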

Reply from joecummings (Member, PR author):

Re 1: Yeah, that definitely makes sense; let me add it.

Re 2: Because right now we always save at epoch boundaries, there will never be a time when we don't want to complete the epoch before restarting training. This will change very soon as we start mid-epoch checkpointing.

ramanishsingh left a comment:

Thanks for adding the comment on finishing the DL epoch upon a restart.
LGTM!

codecov-commenter commented Feb 26, 2025

Codecov Report

Attention: Patch coverage is 0% with 41 lines in your changes missing coverage. Please review.

Project coverage is 65.37%. Comparing base (cf0142b) to head (634be31).
Report is 178 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| recipes/full_finetune_distributed.py | 0.00% | 13 Missing ⚠️ |
| recipes/lora_finetune_distributed.py | 0.00% | 13 Missing ⚠️ |
| recipes/lora_finetune_single_device.py | 0.00% | 10 Missing ⚠️ |
| tests/recipes/test_lora_finetune_single_device.py | 0.00% | 2 Missing ⚠️ |
| ...htune/training/checkpointing/_checkpoint_client.py | 0.00% | 2 Missing ⚠️ |
| recipes/full_finetune_single_device.py | 0.00% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2431      +/-   ##
==========================================
- Coverage   65.38%   65.37%   -0.02%     
==========================================
  Files         374      374              
  Lines       22172    22189      +17     
==========================================
+ Hits        14498    14505       +7     
- Misses       7674     7684      +10     

☔ View full report in Codecov by Sentry.

joecummings merged commit 4d9840c into meta-pytorch:main on Feb 26, 2025
17 checks passed
joecummings deleted the scatter-stateful-dl branch on February 26, 2025 at 19:39
joecummings added a commit to joecummings/torchtune that referenced this pull request Feb 27, 2025
joecummings added a commit to joecummings/torchtune that referenced this pull request Feb 27, 2025
pbontrager pushed a commit to pbontrager/torchtune that referenced this pull request Mar 17, 2025