New WR: Adam Gradient Sync in Backward Hooks #149
Conversation
Refactor gradient synchronization to use futures for reduce scatter.
```python
@torch.compile
@torch.no_grad()
def _sync_gradient(self, param):
```
I ran my tests with this torch.compile decorator in place but realized afterwards that backward hooks are not supported, so I think this doesn't have any effect
This looks good. There are 2 PRs open before this that will get merged first and may create a dependency on torch nightly 0926, so this will be retimed with those. I would expect the same impact but cannot guarantee it. Since this does not modify the ML, it does not require a p-value check. Independent of this PR:
Thanks for taking a look and for the reply! I originally tested this on a different nightly version and saw similar results before doing my "official" testing, so hopefully it still holds. I would love to write up an article covering GPU profiling 101 on NanoGPT. I am actually working on a series of articles covering simplified implementations of distributed training algos and using the PyTorch profiler to analyze and improve them. I only have one post currently, on DDP (https://blog.underfit.ai/ddp-profiled), but tbh I want to rewrite and improve it. I'm wrapping up a post on ZeRO-1 and then have a draft on ZeRO-2 next. If this record holds I was planning to write a similar post about the experience, the many mistakes I made along the way, and the other things I tried changing which did not work. The learning and exploring I did to write those posts is actually what got me here.
Hi @ClassicLarry, I wrote an intro-to-profiling post: https://blog.underfit.ai/profiling-101-nanogpt. If you or anyone else has any feedback, or if you think there's anything more I should add, feel free to send me a message on X @underfitai (or reply here, though I'm not sure if this PR is the right place for that conversation).
Great, hoping to get this one merged next week.
```python
reduce_scatter_futures: list[torch.Future] = []
world_size = dist.get_world_size()
all_gather_futures: list[torch.Future] = []
grad_slices = []
```
I am getting an error on line 715 because it's expecting grad_slices to exist: g_slice = grad_slices[idx]. It may be simplest if, along with updating this param, you resync this PR on top of the now-merged train.py.
I think the PR is likely missing a commit, as it doesn't quite match the logs.
Looking at it further, the issue is that the previously merged PR #146 modified some of the structure of DistAdam(). This PR will need to be updated to account for that new structure.
Sorry! I made a mistake while resolving merge conflicts after #146 was merged. g_slice is defined on line 709, so line 715 just needed to be removed.
I did my test runs before #146 was merged, so the DistAdam changes were not reflected in the logs. I've re-run everything with the latest changes and updated the log files.
Merged at 136.122, getting a 0.9s speedup over the last record.
Hello and happy Halloween! I think I've set a new record.
Adam Gradient Sync in Backward Hooks
This PR improves the overall training time and average training step time by moving the `DistAdam` gradient sync reduce-scatter collectives for each model parameter out of the `step` method and into a backward hook. The `step` method is then modified to iterate through the param groups and parameters in reverse order to benefit from this change by stepping parameters in later layers first. The parameters of later layers will have their backward hooks called sooner, which should result in their gradient syncs being triggered earlier and completed sooner.

Timing and Validation
This PR improves the final training time by ~0.7 seconds
This PR:
Previous PR timed on same machine:
Changes
Reduce Scatter in `DistAdam` Backward Hook

Even though the collective operations are async, they all occur at the end of each training step. The previous implementation looped through each parameter in order, launched the reduce-scatter operation, and then immediately waited for it to complete.
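As a rough, hypothetical sketch of that pattern (not the repository's exact code; the function name, buffer shapes, and reduce op are placeholders), launching and immediately waiting serializes the collectives at the start of the optimizer step:

```python
# Hypothetical sketch of the previous pattern: each reduce-scatter is launched
# and then immediately waited on, so the collectives run back-to-back at the
# start of the optimizer step instead of overlapping with other work.
import torch
import torch.distributed as dist

def step_old(params: list[torch.nn.Parameter], world_size: int):
    for p in params:
        # each rank receives a 1/world_size slice of the averaged gradient
        # (assumes p.numel() is divisible by world_size)
        g_slice = torch.empty(p.numel() // world_size,
                              dtype=p.grad.dtype, device=p.grad.device)
        work = dist.reduce_scatter_tensor(g_slice, p.grad.flatten(),
                                          op=dist.ReduceOp.AVG, async_op=True)
        work.wait()  # blocks right away, one collective at a time
        # ... Adam update on g_slice, then all-gather the updated slice ...
```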
In this PR I moved the reduce-scatter operation launch out of the `step` method and into a backward hook, registered using `register_post_accumulate_grad_hook` to ensure the gradients are ready. Since the backward hooks will be executed first for the later model layers, their reduce-scatters will start and complete first.
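A minimal sketch of the hook-based launch, assuming each rank keeps a flat 1/world_size gradient slice per parameter. The names reduce_scatter_futures, grad_slices, and _sync_gradient mirror names visible in the diff above, but the bodies here are illustrative rather than the exact train.py code:

```python
# Illustrative sketch: register a post-accumulate-grad hook per parameter so
# its reduce-scatter is launched as soon as that gradient is ready, during the
# backward pass, rather than at the start of step().
import torch
import torch.distributed as dist

reduce_scatter_futures: list[torch.Future] = []
grad_slices: list[torch.Tensor] = []

def register_gradient_sync_hooks(params: list[torch.nn.Parameter]):
    world_size = dist.get_world_size()

    def _sync_gradient(param: torch.Tensor):
        # this rank's 1/world_size slice of the averaged gradient
        # (assumes param.numel() is divisible by world_size)
        g_slice = torch.empty(param.numel() // world_size,
                              dtype=param.grad.dtype, device=param.grad.device)
        grad_slices.append(g_slice)
        work = dist.reduce_scatter_tensor(g_slice, param.grad.flatten(),
                                          op=dist.ReduceOp.AVG, async_op=True)
        reduce_scatter_futures.append(work.get_future())

    for p in params:
        # fires once the gradient for p has finished accumulating
        p.register_post_accumulate_grad_hook(_sync_gradient)
```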
Step params and param groups in reverse order

To take advantage of this, I modified the `step` method of `DistAdam` to iterate through the param_groups and parameters in reverse order. In our init function we define param groups that group parameters by tensor shape. The parameters of later layers are at the end of the parameter lists, and since their reduce-scatters are launched earlier, we iterate over and wait on their reduce-scatter futures earlier in the `step` method's loop. Similarly, we iterate through param_groups in reverse because the first group corresponds to the first shape we encounter, so parameters for later layers should be contained in later param_groups.
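Continuing the sketch above (same module-level reduce_scatter_futures and grad_slices; the Adam math and the all-gather of updated slices are elided), the reverse-order consumption could look roughly like this:

```python
# Illustrative sketch: walk param_groups and their params in reverse so the
# futures are consumed in roughly the order the hooks launched them
# (later layers first).
import torch

@torch.no_grad()
def step_sketch(param_groups):
    idx = 0
    for group in reversed(param_groups):
        for param in reversed(group["params"]):
            # assumes hooks fired in reverse parameter order, so index idx
            # lines up with this param's future and gradient slice
            reduce_scatter_futures[idx].wait()
            g_slice = grad_slices[idx]
            # ... Adam update on this rank's g_slice, then an async all-gather
            #     writes the updated slice back into the full parameter ...
            idx += 1
    reduce_scatter_futures.clear()
    grad_slices.clear()
```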
Profiler Trace Analysis
The trace files are checked in and can be explored using the Perfetto trace viewer.
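For anyone who wants to capture a similar trace, here is a generic sketch using the PyTorch profiler (this is not necessarily how the checked-in trace files were produced; train_step is a placeholder for one training iteration):

```python
# Generic sketch: profile a few training steps and export a trace that the
# Perfetto UI (or chrome://tracing) can open.
from torch.profiler import profile, ProfilerActivity

def capture_trace(train_step, num_steps: int = 3, path: str = "trace.json"):
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(num_steps):
            train_step()
    prof.export_chrome_trace(path)
```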
Current Implementation
To start, here is the profiler trace for the current implementation. We can see that the first reduce-scatter operation begins at the start of the `DistAdam` step.

First Reduce-Scatter
Overview
Overlap
Looking at the GPU streams, we can see that the initial reduce-scatter does not overlap with the main GPU stream.
Hook Implementation
Similarly, in the new implementation we can see that the first reduce-scatter is launched by the first hook.
First Reduce-Scatter
Overview
Overlap
This time we can see that the reduce-scatter overlaps with the computation on the main GPU stream.
