New WR: -0.5 seconds using small change in multiplication order for self-attention lambda, and fixed warmup for DistAdam #166
Conversation
|
Premultiplication is such a cool trick. Regarding the O matrix: it is initialized to zero, so that probably has some serious impact on how the multiplication interacts with it. Nonzero inits have also been quite unstable (higher losses), so I'm not sure we can get away with premul here. Same with all the gates.
|
I also tried changing the gate into a SiLU with the

Edit: Thank you for your efforts on this project and your blog posts on them - it's inspiring!

Edit 2: My reasoning was that, because the V operations are linear with respect to multiplication by a scalar (we just take weighted sums and multiply by gates), I really thought none of the versions except the SiLU gate should make a difference (e.g. in the backward pass, sa_lambda gets its gradient through a multiplication by O^T anyhow). The only reason there should be a difference is that floating-point arithmetic doesn't exactly match the real numbers. I suspect it's flash-attention v3 that's causing the issues. I have an idea on how to test that out.
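Here is a minimal sketch of the linearity argument, using `torch.nn.functional.scaled_dot_product_attention` as a stand-in for the FA3 kernel; the shapes and the `sa_lambda` value below are made up purely for illustration:

```python
import torch
import torch.nn.functional as F

# Toy shapes: (batch, heads, seq, head_dim); values are illustrative only.
q, k, v = (torch.randn(1, 6, 128, 64) for _ in range(3))
sa_lambda = 0.5

# Attention is linear in V (the softmax weights depend only on Q and K), so
# scaling v before attention and scaling the output afterwards are identical
# in exact arithmetic; any measured difference comes from floating-point
# rounding in whichever kernel (e.g. FlashAttention-3) actually runs.
out_scaled_v = F.scaled_dot_product_attention(q, k, sa_lambda * v)
out_scaled_o = sa_lambda * F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(out_scaled_v, out_scaled_o, atol=1e-6))  # True up to rounding
print((out_scaled_v - out_scaled_o).abs().max())              # small but generally nonzero
```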
|
Glad you were able to get this working! Runs look good from my end. I'll merge at 131.6s (-0.6s current record) once logs are added. |
|
So I did an experiment on the FAv3 hypothesis. I ran two versions:
I did one run of each, and both came in easily below the loss threshold. There goes my theory... Naturally, I did what any self-respecting "just-one-more-run-yolo" person does and moved the sa_lambda[0] multiply to the gate again. Loss was somewhat worse (3.2794)...
Will do this evening, thanks a bunch for being this responsive! |
|
@shenberg Have a look at the submission I'm preparing here; I think it explains / addresses the trouble you were having with pre-multiplication on the output projection. NorMuon currently has a bug: it's intended to normalize variance along a single dimension, but the attention output heads are stored transposed relative to the QKV heads, so it's been getting misapplied. With that fixed, your intended pre-multiplication with the output matrix gives me a speed-up without reducing loss.
|
I just finished my runs and got the logs added. You can find my code inside them, or I've also updated the top-level train_gpt.py. I included your changes as part of my runs / used your code as the baseline. I'm curious now whether there are other parts of the model which have been indirectly affected by the bug and may need some re-tuning? |
I'd be cautious about reading too much into small loss jumps like that; they're well within the natural variation of random noise.
A lot of cool stuff in there. My hunch is that NorMuon should be dropped entirely for the gates and attn, and we can use a simple
|
To clarify, the orientation issue is not with Muon but with NorMuon, which is "Neuron-wise Normalized Muon". I agree that the orientation of

For NorMuon, the authors are addressing the underlying structure of the weight matrices and scaling the weight updates on a per-neuron level with an RMSNorm. From their paper:

So NorMuon operates along a specific matrix dimension (see also the code comment here), chosen so that we apply the RMSNorm to the neurons. Even for a (1, 12) gate weight, the direction of the RMSNorm matters: taking the sum of squares across the 12 values produces a different scaling factor than treating each component individually (which is what the current implementation does, because of the over-simplified heuristic for choosing the dimension, here).
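A toy numerical illustration of that distinction (this mimics only the RMS-scaling step being discussed, not the full NorMuon update rule):

```python
import torch

# Toy (1, 12) gate-weight update, as in the discussion above.
update = torch.randn(1, 12)

# Neuron-wise: one RMS over the 12 fan-in values -> a single shared scale.
row_rms = update.pow(2).mean(dim=1, keepdim=True).sqrt()
per_neuron = update / (row_rms + 1e-8)

# Per-component (what the dimension heuristic currently picks for this shape):
# the RMS of a single value is just its absolute value, so every entry
# collapses to roughly +1 or -1.
elem_rms = update.pow(2).mean(dim=0, keepdim=True).sqrt()
per_element = update / (elem_rms + 1e-8)

print(per_neuron)   # whole row rescaled by one shared factor
print(per_element)  # effectively torch.sign(update)
```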
|
@shenberg @ClassicLarry, what do you think of the following for resolving the two record runs? My submission is pretty "heavy" / cluttered; I was originally focused on a growing list of small optimizations to the NorMuon code until I uncovered the normalization issue, which is an ML update rather than a GPU one, and which ties in with @shenberg's improvements. What if, for the 12/10 record, we include:
I could hold off on my GPU-related optimizations for a separate record soon. This will give me the opportunity to finish merging the two optimizers, which is the natural conclusion of the changes I was making. What do you think? |
|
I included this in my README, but I'll add it here since it's particularly relevant. Edit: The first row is @shenberg's code run on my instance, the next includes their changes and mine, and the third includes all of that plus pre-multiplication of
|
My pref is to keep the PRs separate by author to maximize individual credit. This one has also already been validated and is ready to merge once logs get added.
I agree that the dimension produces a different scaling factor. My point was that if your scaling factor is a single scalar after taking the norm, it looks redundant, because you reverse that scalar back out here. You might as well drop NorMuon entirely at that point, because otherwise the code is putting in effort to multiply by s and then divide by s a couple of lines later. I'll check this more definitively by adding print statements before and after NorMuon once the PR is raised.

Edit: I'll add that I'm fine with merging any record that improves runtime, but just a call-out that someone may claim another simple 200ms by dropping the extra NorMuon ops.
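To spell out the multiply-by-s-then-divide-by-s point, a toy single-row example (this just mimics the normalize-then-rescale structure under discussion, not the actual optimizer code):

```python
import torch

# Hypothetical single-row update, e.g. a (1, 12) gate weight.
update = torch.randn(1, 12)

s = update.pow(2).mean(dim=1, keepdim=True).sqrt()  # one scalar for the single row
normalized = update / s                             # scale by 1/s ...
rescaled = normalized * (update.norm() / normalized.norm())  # ... then undo it when restoring the norm

print(torch.allclose(update, rescaled))  # True: with a single row the two steps cancel
```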
Cool stuff there! I was looking at a few of the same things, e.g. why there is a memcpy in the FP8 matrix multiply (it turns out that the very smart YouJiacheng knew it's faster to reorder the matrix in a way that fits well with the tensor cores). Also, regarding the merged optimizer etc., I was thinking of maybe trying to implement Collage, since Adam is essentially bandwidth-bound until compiled autograd works properly, even with an optimal reduce-scatter / all-gather sequence, so we're free to add more GPU-local compute to get better-quality weight updates.
|
@varunneal I noticed when looking at your most recent record that you missed one file in the scipy code (3e3c0f33-dc80-4efb-bb21-07af9b5161c9 (loss 3.2768)), and that you mistyped the first loss (it's actually 3.2777, not 3.2772). Not that it matters, since the omitted file probably more than makes up for the reduced significance.
|
@ClassicLarry Ah, I see it now. When there's only one row vector, there's no population to normalize over, so it's just a no-op. And as you said earlier, the current implementation is harmful. It treats each position independently, and I think results in forcing the smear update to be either -1 or +1 for each weight? Bizarre. I went to look at some runs I did a while back where I plotted visualizations of all of the scalars, and even the gate weights--but apparently not the smear gate weights! Now I'm curious to see what those look like. I think there's still some principled motivation for using Muon on the smear gate over Adam, but it's certainly expensive. I'll try switching it and see if it hurts loss. I can also test whether skipping the variance normalization for the attention matrices has negligible effect since they're square--skipping that step would be a small savings. |
|
@shenberg That's a cool idea! Yeah, there's definitely plenty of compute window available for Adam, even in the current implementation:

And a side note on the merged implementation: I just noticed that's how the authors implemented it originally, here:
(force-pushed from e89cb13 to 42f0a1f)
|
I think it's time for me to wrap this up for now, so I will go ahead and raise a pull request (dated today) later this afternoon. |
|
Nice catch on the DistAdam warmup @shenberg, and congrats on the new record! @chrisjmccormick you can try removing torch.compile from _sync_gradient, but I believe those reduce-scatters come from NorMuon (L529), which syncs on every step, not from DistAdam. They are not visually aligned in the profiler trace, but if you click one of the reduce-scatter blocks it will show an arrow to the CPU op that launched the GPU kernel, and that op comes from NorMuon. Alternatively, in the details pane that pops up when you click one of the blocks, you can click the "slice name" link in the "preceding flows" section to get there.
The alignment threw me off quite a bit when I was working on this. You can also count the number of reduce-scatters to confirm it: in the step where DistAdam syncs, there is another set of reduce-scatters for NorMuon following DistAdam's all-gathers. I don't think there is any difference in the profiler trace, but I would be happy to be shown otherwise and learn from this. I tried a lot of different variations myself, so I'm really curious. From what I found, removing

P.S. I really like your profiler trace images!
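One way to do that counting without eyeballing the trace is to sum the collective kernels over one step from a profiler run. A rough sketch: `train_step()` is a placeholder for one full step of the actual run, and the substring filters are assumptions, since NCCL kernel names vary across versions.

```python
from torch.profiler import profile, ProfilerActivity

# Count collective kernels in one training step instead of inspecting the
# trace visually. `train_step()` stands in for one real optimizer step.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    train_step()

for evt in prof.key_averages():
    name = evt.key.lower()
    if any(tag in name for tag in ("reducescatter", "reduce_scatter", "allgather", "all_gather")):
        print(f"{evt.key}: {evt.count} calls")
```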
Thanks for your posts - I actually used your Perfetto trace early on to look for optimization opportunities. I think the reason compiled autograd behaves poorly with the hooks might be that there's gradient accumulation happening? (I was puzzled for a bit as to why torch.backward() was doing more work on odd steps; I think that's the cause.)
|
Can you go over why exactly we need to warm up the validation?
|
Frankly, I don't think we have to. I was mainly trying to remove variance between runs: I was getting wacky stalling-step behavior, and one of my suspicions was that TorchInductor was messing with the heap or the CUDA memory pools, so I front-loaded everything in order to reduce the odds of validation steps influencing the training run. I don't think it's necessary, but it also probably adds very little to the total run-time, though admittedly it's a worse experience to front-load all of the waiting. I still don't understand those stalling steps - the behavior I saw with
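For reference, a rough sketch of what that front-loading amounts to; the tiny model and batch here are stand-ins, not the actual train_gpt.py objects.

```python
import torch
import torch.nn as nn

# Sketch of the "front-load the waiting" idea: run a throwaway validation pass
# before the timed loop so compilation and allocator growth happen up front.
model = torch.compile(nn.Linear(256, 256).cuda())
val_batch = torch.randn(8, 256, device="cuda")

model.eval()
with torch.no_grad():
    _ = model(val_batch)   # triggers compilation / warms the eval path
torch.cuda.synchronize()   # ensure warmup work finishes before the timed section
model.train()
```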


My changes:

- Apply `sa_lambda[0]` to the QKV matrix instead of to the `v` directly, as the QKV matrix is much smaller (between a factor of 10 and a factor of 30, depending on batch size), and we don't mind re: q, k as they are normed anyhow.
- Run `gc.collect()` before the start of training (sketched below) - I was getting unreliable training behavior without it. On a single-H100 run I enabled GC stats printouts (`gc.set_debug(gc.DEBUG_STATS)`) and saw that there was a gen-2 collection with ~0.5 seconds of stop-the-world (I also noticed that memory is leaking, but that's for another time to solve).

I tried a couple of places that seemed like they should be better for the premultiply:

- ... (setting `sa_lambda[1]` to 1.0 instead of 0.5)

Unfortunately, both these options hurt the final loss. I don't know why...
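The GC part of that, concretely (just the standard `gc` calls mentioned in the list above, placed before the timed training loop):

```python
import gc

# Log every collection (including the stop-the-world gen-2 passes mentioned
# above) and force a full collection up front so that pause doesn't land
# mid-run.
gc.set_debug(gc.DEBUG_STATS)
gc.collect()
```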