
Conversation

@akash5474
Contributor

@akash5474 akash5474 commented Oct 31, 2025

Hello and happy Halloween! I think I've set a new record.

Adam Gradient Sync in Backward Hooks

This PR improves the overall training time and average training step time by moving the DistAdam gradient-sync reduce-scatter collectives for each model parameter out of the step method and into a backward hook. The step method is then modified to iterate through the param groups and parameters in reverse order, so parameters in later layers are stepped first. Because later layers finish backward first, their hooks fire sooner, which should result in their gradient syncs being launched earlier and completing sooner.

Timing and Validation

This PR improves the final training time by ~0.7 seconds.

This PR:

import scipy.stats
import torch

losses = [3.2775, 3.2776, 3.2777, 3.2780, 3.2781, 3.2775, 3.2786, 3.2774, 3.2751, 3.2739]
times = [140.909, 140.872, 140.743, 140.743, 140.747, 140.809, 140.728, 140.784, 140.862, 140.934]

print("p=%.4f" % scipy.stats.ttest_1samp(losses, 3.28, alternative="less").pvalue)
# p=0.0001

print("losses:", torch.std_mean(torch.tensor(losses)))
# losses: (std=0.0015, mean=3.2771)

print("time:", torch.std_mean(torch.tensor(times)))
# time: (std=0.0760, mean=140.8131)

Previous PR, timed on the same machine:

import scipy.stats
import torch

times = [141.654, 141.413, 141.467, 141.516]

print("time:", torch.std_mean(torch.tensor(times)))
# time: (std=0.1033, mean=141.5125)

Changes

Reduce Scatter in DistAdam Backward Hook

Even though the collective operations are async, they all occur at the end of each training step. The previous implementation looped through each parameter in order, launched the reduce-scatter operation, and then immediately waited for it to complete.
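
For contrast, here is a minimal sketch of that serialized pattern; the function name, the AVG reduction, and the flat sharding are assumptions for illustration, not the previous train.py code:

import torch
import torch.distributed as dist

# Illustrative only: each reduce-scatter is launched asynchronously but
# waited on immediately, so all of the communication happens back-to-back
# at the end of the step instead of overlapping with backward.
def sync_gradients_in_step(params, world_size):
    for p in params:
        grad = p.grad.flatten()  # assumes numel is divisible by world_size
        shard = torch.empty(grad.numel() // world_size,
                            dtype=grad.dtype, device=grad.device)
        work = dist.reduce_scatter_tensor(shard, grad,
                                          op=dist.ReduceOp.AVG, async_op=True)
        work.wait()  # waits right away, defeating the async launch
        # ... Adam update on the local shard would follow here ...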

In this PR I moved the reduce-scatter launch out of the step method and into a backward hook registered with register_post_accumulate_grad_hook, which guarantees each gradient is fully accumulated before its hook runs. Since the backward hooks fire first for the later model layers, their reduce-scatters start and complete first.
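
A minimal sketch of the hook-based launch (assumed names, bookkeeping, and AVG reduction, not this PR's exact code; it also assumes each gradient's size divides evenly by the world size):

import torch
import torch.distributed as dist

class DistAdamSketch:
    # Illustrative sketch only: registers a post-accumulate-grad hook per
    # parameter so its reduce-scatter launches during backward.
    def __init__(self, params):
        self.world_size = dist.get_world_size()
        # Group parameters by tensor shape; the first shape encountered goes
        # into the first group, so later layers tend to land in later groups.
        groups = {}
        for p in params:
            groups.setdefault(tuple(p.shape), []).append(p)
        self.param_groups = [{"params": ps} for ps in groups.values()]
        self.reduce_scatter_works = {}  # param -> async work handle
        self.grad_slices = {}           # param -> local gradient shard
        for group in self.param_groups:
            for p in group["params"]:
                # Fires once p.grad is fully accumulated during backward,
                # which happens earliest for the last layers of the model.
                p.register_post_accumulate_grad_hook(self._sync_gradient)

    @torch.no_grad()
    def _sync_gradient(self, param):
        grad = param.grad.flatten()
        shard = torch.empty(grad.numel() // self.world_size,
                            dtype=grad.dtype, device=grad.device)
        # Async launch: backward keeps running on the main stream while NCCL
        # reduces this gradient on a communication stream.
        self.reduce_scatter_works[param] = dist.reduce_scatter_tensor(
            shard, grad, op=dist.ReduceOp.AVG, async_op=True)
        self.grad_slices[param] = shard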

Step params and param groups in reverse order

To take advantage of this, I modified the step method of DistAdam to iterate through the param_groups and parameters in reverse order. In our init function we define the param groups by grouping parameters by tensor shape. The parameters of later layers sit at the end of these parameter lists, and since their reduce-scatters are launched earlier, iterating in reverse means the step loop waits first on the futures that are most likely to have already completed.

Similarly, we iterate through param_groups in reverse because the first group corresponds to the first shape encountered, so parameters for later layers should be contained in later param_groups.
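
Continuing the sketch above, the step loop then walks everything in reverse so it waits first on the collectives that were launched first:

import torch

@torch.no_grad()
def step(self):
    # Reverse order: later layers' hooks fired first during backward, so
    # their reduce-scatters are the ones most likely to be finished already.
    for group in reversed(self.param_groups):
        for param in reversed(group["params"]):
            self.reduce_scatter_works[param].wait()  # usually already done
            g_slice = self.grad_slices[param]
            # ... Adam moment updates on the local shard (g_slice), then an
            #     async all-gather of the updated shard back into param ...

DistAdamSketch.step = step  # attach to the sketch class defined above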

Profiler Trace Analysis

The trace files are checked in and can be explored using the Perfetto trace viewer.

Current Implementation

To start, here is the profiler trace for the current implementation. We can see that the first reduce-scatter operation begins at the start of the DistAdam step.

First Reduce-Scatter

profiler-trace-current-first-rs

Overview

profiler-trace-current-overview

Overlap

Looking at the GPU streams, we can see that the initial reduce-scatter does not overlap with the main GPU stream.

profiler-trace-current-comm-overlap

Hook Implementation

Similarly in the new implementation we can see that the first reduce-scatter is launched by the first hook.

First Reduce-Scatter

profiler-trace-hook-first-rs

Overview

profiler-trace-hook-overview

Overlap

This time we can see that the reduce-scatter overlaps with the computation on the main GPU stream.
profiler-trace-hook-comm-overlap

Comment on lines +721 to +723
@torch.compile
@torch.no_grad()
def _sync_gradient(self, param):
Contributor Author

@akash5474 akash5474 Oct 31, 2025

I ran my tests with this torch.compile decorator in place, but realized afterwards that backward hooks are not supported by torch.compile, so I think the decorator has no effect here.

@akash5474 akash5474 marked this pull request as ready for review October 31, 2025 23:27
@ClassicLarry
Collaborator

This looks good. There are 2 PRs open ahead of this one that will be merged first and may create a dependency on torch nightly 0926, so this will be retimed along with those. I would expect the same impact, but cannot guarantee it. Since this does not modify the ML, it does not require a p-value check.

Independent of this PR:
I suspect that GPU profiling is under-utilized in this challenge, and under-utilized in general. There are probably other low-hanging fruit that use the same technique of overlapping activities. I see benefit in having a guide/article for GPU Profiling 101: setup and application to NanoGPT on 8xH100. If you have an interest in writing something in that direction, I can add the link to the main readme on the row for this record.

@akash5474
Contributor Author

akash5474 commented Nov 1, 2025

Thanks for taking a look and for the reply! I originally tested this on a different nightly version and saw similar results before doing my "official" testing, so hopefully it still holds.

I would love to write up an article covering GPU profiling 101 on NanoGPT. I am actually working on a series of articles covering simplified implementations of distributed training algorithms and using the PyTorch profiler to analyze and improve them. I currently have only one post, on DDP (https://blog.underfit.ai/ddp-profiled), but to be honest I want to rewrite and improve it. I'm wrapping up a post on ZeRO-1 and have a draft on ZeRO-2 next. If this record holds, I was planning to write a similar post about the experience, the many mistakes I made along the way, and the other things I tried that did not work.

The learning and exploring I did to write those posts is actually what got me here.

@akash5474
Contributor Author

Hi @ClassicLarry, I wrote an intro-to-profiling post: https://blog.underfit.ai/profiling-101-nanogpt. If you or anyone else has any feedback, or if you think there's anything more I should add, feel free to send me a message on X @underfitai (or reply here, though I'm not sure this PR is the right place for that conversation).

@ClassicLarry
Collaborator

Great, hoping to get this one merged next week.

reduce_scatter_futures: list[torch.Future] = []
world_size = dist.get_world_size()
all_gather_futures: list[torch.Future] = []
grad_slices = []
Collaborator

I am getting an error on line 715 because it's expecting grad_slices to exist: g_slice = grad_slices[idx]. It may be simplest if, along with updating this param, you resync this PR on top of the now-merged train.py.

Collaborator

I think the PR is likely missing a commit, as it doesn't quite match the logs.

Collaborator

Looking at it further, the issue is that the previously merged PR #146 modified some of the structure of DistAdam(). This PR will need to be updated to account for that new structure.

Contributor Author

@akash5474 akash5474 Nov 15, 2025

Sorry! I made a mistake while resolving merge conflicts after #146 was merged. g_slice is defined on line 709, so line 715 just needed to be removed.

I did my test runs before #146 was merged so DistAdam changes were not reflected in the logs. I've re-run everything with the latest changes and updated the log files.

@ClassicLarry ClassicLarry merged commit 80d68af into KellerJordan:master Nov 16, 2025
@ClassicLarry
Collaborator

Merged at 136.122s, a 0.9s speedup over the last record.
