Add StatefulDataLoader to select other recipes
#2431
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2431. Note: links to docs will display an error until the docs builds have completed.
⏳ No Failures, 1 Pending as of commit 634be31 with merge base 7b654ea. This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
if dataloader_state_dict is not None:
    dataloader.load_state_dict(dataloader_state_dict)
    list(dataloader)  # Hack to force dataloader to finish last iteration
return sampler, dataloader
```
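For context, here is a minimal, self-contained sketch of the restore-and-drain pattern the diff relies on (not the recipe code itself); it assumes a torchdata build that ships `torchdata.stateful_dataloader`:

```python
import torch
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = torch.arange(10)
dataloader = StatefulDataLoader(dataset, batch_size=2)

# Simulate breaking out of an epoch early (e.g. a max-steps-per-epoch limit)
# and capturing the loader state at that point.
it = iter(dataloader)
for _ in range(2):
    next(it)
state = dataloader.state_dict()

# On resume: rebuild the loader, restore its state, then drain the leftover
# batches so the next `for batch in dataloader` starts a fresh epoch instead
# of the tail of the interrupted one.
resumed = StatefulDataLoader(dataset, batch_size=2)
resumed.load_state_dict(state)
leftover = list(resumed)
print(leftover)  # remaining batches, e.g. [tensor([4, 5]), tensor([6, 7]), tensor([8, 9])]
```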
Thanks for including StatefulDataloader in these recipes.
I have a couple of comments.
- Let's add a more descriptive comment, something like "Since we break early without completing the last epoch, and we want to start a new epoch when restarting training, we need to yield the remaining batches from the last epoch before training resumes." Doesn't need to be this long, but something like this, as someone who doesn't work with StatefulDataLoader might get confused here.
- Do you think we need a flag (like `finish_current_epoch` or some other name) that we set when we break early from an epoch, and only if that flag is `True` do we complete the epoch before restarting training? (See the sketch after this list.)
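A rough illustration of the second suggestion; `finish_current_epoch` and the helper below are hypothetical names, not existing torchtune or torchdata APIs:

```python
def restore_dataloader(dataloader, dataloader_state_dict, finish_current_epoch):
    """Hypothetical sketch: restore the loader state, and only drain the
    leftover batches when the previous run broke out of its epoch early."""
    if dataloader_state_dict is not None:
        dataloader.load_state_dict(dataloader_state_dict)
        if finish_current_epoch:
            # Yield and discard the batches that were never consumed so the
            # next epoch starts clean on resume.
            list(dataloader)
    return dataloader
```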
Re 1 - yeah, definitely makes sense, let me add it.
Re 2 - b/c right now we always save at epoch boundaries, there will never be a time when we don't want to complete the epoch before restarting training. This will change very very soon as we start mid-epoch checkpointing.
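For illustration, a mid-epoch checkpoint could bundle the loader state with the rest of the training state; the key names and helper here are hypothetical, not torchtune's checkpoint format:

```python
import torch

def save_mid_epoch_checkpoint(path, model, optimizer, dataloader, epochs_run, global_step):
    # Hypothetical sketch. StatefulDataLoader's state_dict records how far
    # into the current epoch we are, so a resumed run can continue from the
    # next batch instead of draining or replaying the epoch.
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "dataloader": dataloader.state_dict(),
            "epochs_run": epochs_run,
            "global_step": global_step,
        },
        path,
    )
```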
ramanishsingh left a comment
Thanks for adding the comment on finishing the DL epoch upon a restart.
LGTM!
Codecov Report

```
@@            Coverage Diff             @@
##             main    #2431      +/-   ##
==========================================
- Coverage   65.38%   65.37%   -0.02%
==========================================
  Files         374      374
  Lines       22172    22189      +17
==========================================
+ Hits        14498    14505       +7
- Misses       7674     7684      +10
```

☔ View full report in Codecov by Sentry.
What?
This expands the use of `StatefulDataloader` to the following four recipes:

What did you change?
I copied pretty much the exact same changes I made in #2410 into the above four recipes. I modified the tests for the single-device recipes b/c we are no longer relying on a different state. For the distributed recipes, I did not modify the tests at all. I also had to change the hack we used to ensure the iterator finishes even if we cut the epoch short, b/c it was not robust and manually modifying the dataloader state dict looked ugly.
Why these recipes?
These include our most stable recipes + our newest one, GRPO. By looking at the changes here, users will be able to easily propagate them to the rest of the library. I will be creating an issue to track this for all the remaining recipes.
How did you test GRPO?
Good question, b/c GRPO has no standardized tests in the torchtune library. It's also an interesting case b/c it does not follow the same format as our other recipes. For instance, there is NO option to cut an epoch short based on the number of steps in an individual epoch (it's a weird variable anyway). However, there IS an option to stop after a total number of training steps. But that counter wasn't included in the saved state dict, so there is no way to reasonably resume training for GRPO without actually implementing step-based checkpointing. This one is a wash right now :/
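As a sketch of what step-based resume could look like once the step counter is persisted (none of these names are GRPO's actual recipe code; `train_step` is a placeholder for one training iteration):

```python
def resume_step_based_training(checkpoint, dataloader, total_training_steps, train_step):
    # Hypothetical sketch: restore the loader and the persisted step counter,
    # then keep training until the total step budget is reached.
    global_step = checkpoint.get("global_step", 0)
    if checkpoint.get("dataloader") is not None:
        dataloader.load_state_dict(checkpoint["dataloader"])
    while global_step < total_training_steps:
        for batch in dataloader:
            if global_step >= total_training_steps:
                break
            train_step(batch)
            global_step += 1
    return global_step
```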