
Conversation

@shenberg
Contributor

My changes:

  1. Changed the SA lambda logic to pre-multiply the QKV weights by sa_lambda[0] instead of multiplying v directly, since the QKV weight matrix is much smaller than the activations (by a factor of 10 to 30 depending on batch size), and the extra scale on q and k doesn't matter because they are normed anyhow (see the sketch below this list).
  2. Fixed the warmup phase, which didn't set should_sync correctly on DistAdam.
  3. Added a gc.collect() before the start of training; I was getting unreliable training behavior without it. On a single-H100 run I enabled GC stats printouts (gc.set_debug(gc.DEBUG_STATS)) and saw a gen-2 collection with ~0.5 seconds of stop-the-world (I also noticed that memory is leaking, but that's a problem for another time).
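
Here's a minimal sketch of the folding in change 1 (illustrative shapes and names, not the actual train_gpt.py code), assuming q and k are RMS-normed right after the projection so the extra factor on them washes out:

```python
import torch
import torch.nn.functional as F

B, T, D = 8, 1024, 768                     # illustrative sizes only
x = torch.randn(B, T, D)
qkv_w = torch.randn(3 * D, D)              # fused QKV projection weight
sa_lambdas = torch.tensor([0.5, 0.5])

# Original formulation: project, then scale the (large) v activation tensor.
q, k, v = F.linear(x, qkv_w).chunk(3, dim=-1)
v_scaled = sa_lambdas[0] * v

# Folded formulation: scale the (much smaller) weight once per forward instead.
# q and k pick up the same factor, which is harmless if they are normed afterwards.
q2, k2, v2 = F.linear(x, sa_lambdas[0] * qkv_w).chunk(3, dim=-1)

print(torch.allclose(v_scaled, v2, atol=1e-5))   # numerically (near-)identical
```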

I tried a couple of other placements that seemed like they should be even better for the pre-multiply:

  1. multiplying the O matrix weights (with appropriate init change of sa_lambda[1] to 1.0 instead of 0.5)
  2. multiplying the attention output gates by the same rescale (for a factor-of-128 reduction in multiplications).

Unfortunately, both of these options hurt the final loss. I don't know why...

@shenberg
Contributor Author

shenberg commented Dec 10, 2025

I've attached the logs of the record. I can get them into a shape similar to the existing records once I get an LGTM from a maintainer.


logs_record.tar.gz

@varunneal
Contributor

Premultiplication is such a cool trick. Regarding the O matrix: it is initialized to zero, so that probably has some serious impact on how the mult interacts with it. Nonzero inits have also been quite unstable (higher losses), so I'm not sure we can get away with premul here. Same with all the gates.

@shenberg
Contributor Author

shenberg commented Dec 10, 2025

I also tried changing the gate into a SiLU with sa_lambda[0] as a bias (it should be initialized at ~0.25 to start out at the same spot as the current init), without luck (I didn't expect it to work, since high sa_lambda values would lose the gating sparsity, but maybe), and also using a regular init for O with sa_lambda[0] starting at 0. None of the variations really worked.

Edit: Thank you for your efforts on this project and your blog on them, it's inspiring!

Edit 2: My reasoning was that because the V operations are linear with respect to multiplication by a scalar (we just take weighted sums and multiply by gates), none of the versions except the SiLU gate should make a difference (e.g. in the backward pass sa_lambda gets its gradient through a multiplication by O^T anyhow). The only reason there should be a difference is that floating-point arithmetic doesn't match the real numbers. I suspect flash-attention v3 is causing the issues. I have an idea for how to test that.
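
To make that concrete, here's a toy check (my own snippet with made-up shapes, not the model code) that scaling v before the output projection and scaling the projected output are identical over the reals but can differ slightly in low-precision floats:

```python
import torch

torch.manual_seed(0)
v = torch.randn(64, 128, dtype=torch.bfloat16)      # stand-in for the value activations
w_o = torch.randn(128, 128, dtype=torch.bfloat16)   # stand-in for the output projection
sa_lambda = torch.tensor(0.7, dtype=torch.bfloat16) # deliberately not a power of two

pre = (sa_lambda * v) @ w_o     # scale v, then project
post = sa_lambda * (v @ w_o)    # project, then scale the output

# Same result in exact arithmetic; in bf16 the rounding happens in a different
# order, so the two variants need not be bit-identical.
print(torch.equal(pre, post), (pre - post).abs().max().item())
```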

@ClassicLarry
Collaborator

Glad you were able to get this working! Runs look good from my end. I'll merge at 131.6s (-0.6s current record) once logs are added.

@shenberg
Contributor Author

So I did an experiment on the FAv3 hypothesis. I ran two versions:

  1. Based off of current master: changed line 866 from `v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v)` to `v = sa_lambdas[0] * (v + sa_lambdas[1] * ve.view_as(v))` and the sa_lambda init from (0.5, 0.5) to (0.5, 1.0). Loss stayed in the same regime: 3.2777.
  2. Applied the sa_lambdas[0] multiplication to y instead of v: changed line 866 to `v = v + sa_lambdas[1] * ve.view_as(v)` and line 874 to `y = sa_lambdas[0] * y.view(B, T, self.num_heads, self.head_dim)`. Loss was 3.2774 😅

I ran one run each and both were easily below the loss threshold. There goes my theory...

Naturally, I did what any self-respecting "just-one-more-run-yolo" person does and moved the sa_lambda[0] multiply to the gate again. Loss was somewhat worse (3.2794)...

@shenberg
Contributor Author

> Glad you were able to get this working! Runs look good from my end. I'll merge at 131.6s (-0.6s current record) once logs are added.

Will do this evening, thanks a bunch for being this responsive!

@chrisjmccormick
Contributor

@shenberg Have a look at the submission I'm preparing here; I think it explains / addresses the trouble you were having with pre-multiplication on the output projection.

NorMuon currently has a bug; it's intended to normalize variance along a single dimension, but the attention output heads are stored transposed relative to the QKV heads, so it's been getting misapplied.
The smear gate and attention gates have also been affected.

With that fixed, your intended pre-multiplication with the output matrix is giving me a speed up without reducing loss.
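
For context, here's a rough sketch of what I mean by neuron-wise normalization and why the reduction dimension has to follow the weight's storage orientation (my own illustration with made-up shapes, not the repo's NorMuon implementation; I'm assuming the update's overall norm is restored after the per-neuron RMS step):

```python
import torch

def neuronwise_normalize(update: torch.Tensor, reduce_dim: int, eps: float = 1e-8) -> torch.Tensor:
    """RMS-normalize each neuron's update by reducing over its fan-in dimension,
    then restore the update's overall norm."""
    rms = update.pow(2).mean(dim=reduce_dim, keepdim=True).add(eps).sqrt()
    normalized = update / rms
    return normalized * (update.norm() / normalized.norm().clamp_min(eps))

u = torch.randn(3 * 768, 768)   # QKV-style update: output neurons along dim 0, fan-in along dim 1
u_t = u.T.contiguous()          # the same update stored transposed (out-proj style)

a = neuronwise_normalize(u, reduce_dim=1)        # correct for this orientation
b = neuronwise_normalize(u_t, reduce_dim=1).T    # dimension heuristic gone wrong
c = neuronwise_normalize(u_t, reduce_dim=0).T    # orientation-aware choice
print(torch.allclose(a, c, atol=1e-5), torch.allclose(a, b, atol=1e-5))  # True False
```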

@chrisjmccormick
Contributor

I just finished my runs and got the logs added. You can find my code inside them, or I've also updated the top-level train_gpt.py. I included your changes as part of my runs / used your code as the baseline.

I'm now curious whether there are other parts of the model that have been indirectly affected by the bug and may need some re-tuning.

@ClassicLarry
Collaborator

> Naturally, I did what any self-respecting "just-one-more-run-yolo" person does and moved the sa_lambda[0] multiply to the gate again. Loss was somewhat worse (3.2794)

I'd be cautious about interpreting small loss jumps like that; they're well within the natural variation of random noise.

> @shenberg Have a look at the submission I'm preparing here; I think it explains / addresses the trouble you were having with pre-multiplication on the output projection.

A lot of cool stuff in there. My hunch is NorMuon should be dropped entirely for the gates and attn, and we can use a simple `if 'mlp' in label then normuon()`. Since the attn matrix is square, Muon is already giving us normalization coverage along both dimensions, and it operates on the attention out projection independently of qkv_projs, so I don't think the orientation is a concern. For the smear and attention gate, if you swap the dimension, then NorMuon is fixing the same dimension that Muon already normalized, so I doubt it's worth including at all at that point. For instance, the smear update reduces to just multiplying by 1, since `(shape=([1, 12])).mean(dim=-1)` reduces to a single constant. In other words, you have found that NorMuon was actively harmful for the smear_gate, and replacing it with a dummy op of *1 is better.
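
As a toy illustration of that cancellation (my own snippet, not the repo code; assuming the final step rescales the update back to its pre-normalization norm), for a single-row weight like the smear gate the normalize-then-rescale pair is exactly a no-op:

```python
import torch

g = torch.randn(1, 12)                                   # smear-gate-shaped update
rms = g.pow(2).mean(dim=-1, keepdim=True).sqrt()         # a single scalar for a (1, 12) tensor
normalized = g / rms                                     # divide by s ...
rescaled = normalized * (g.norm() / normalized.norm())   # ... then multiply by s again
print(torch.allclose(g, rescaled))                       # True: a no-op
```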

@chrisjmccormick
Contributor

To clarify, the orientation issue is not with Muon, but with NorMuon, which is "Neuron-wise Normalized Muon".

I agree that the orientation of $W^O$ does not matter for Muon, because it produces orthogonal matrix updates.

For NorMuon, the authors are addressing the underlying structure of the weight matrices, and scaling the weight updates on a per-neuron level with RMSNorm. From their paper:

> while Muon’s orthogonalization effectively improves matrix-level conditioning, the per-neuron update norms exhibit high variance, with some neurons receiving disproportionately large updates relative to others.

So NorMuon operates along a specific matrix dimension (see also the code comment here), chosen so that we are applying RMSNorm to the neurons.

Even for a (1, 12) gate weight, the direction of the RMSNorm matters. Taking the sum of squares across the 12 values will produce a different scaling factor than treating each component individually (which is what the current implementation does, because of the over-simplified heuristic for choosing the dimension, here).

@chrisjmccormick
Contributor

@shenberg @ClassicLarry, what do you think of the following for resolving the two record runs?

My submission is pretty "heavy" / cluttered; I was originally focused on a growing list of small optimizations to the NorMuon code until I uncovered the normalization issue, which is an ML update rather than a GPU one, and which ties into @shenberg's improvements.

What if, for the 12/10 record, we include:

  • From @shenberg:
    • The garbage collection and warmup fixes
    • The pre-multiplication technique applied to both QKV and O, now that it's working.
  • From myself:
    • Similar to the warmup fix, the use of torch.compile on DistAdam._sync_gradients was causing misbehavior with the should_sync flag, leading to some Adam weights being transferred on non-Adam steps.
    • The NorMuon fix, since it enables @shenberg's optimization to O (which provides a significant speed up!).

I could hold off on my GPU-related optimizations for a separate record soon. This will give me the opportunity to finish merging the two optimizers, which is the natural conclusion of the changes I was making.

What do you think?

@chrisjmccormick
Contributor

chrisjmccormick commented Dec 11, 2025

I included this in my README, but I'll add it here since it's particularly relevant.

Edit: The first row is @shenberg's code run on my instance, the second includes their changes and mine, and the third includes all of that plus pre-multiplication of sa_lambdas[1] with $W^O$.

| | Runs | Time μ | Time σ | Time +/- | Loss μ | Loss σ | Loss +/- | p |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 12-10 Baseline from @shenberg | 4 | 132.7217 | 0.1176 | 0.0000 | 3.2792 | 0.0008 | 0.0000 | 0.0632 |
| NorMuon-Fix and Optims (mine) | 13 | 131.4852 | 0.0680 | -1.2365 | 3.2777 | 0.0016 | -0.0015 | 0.0001 |
| Mine + PreMul with W_O | 10 | 131.2106 | 0.0660 | -1.5111 | 3.2777 | 0.0020 | -0.0015 | 0.0024 |

@ClassicLarry
Collaborator

ClassicLarry commented Dec 11, 2025

> @shenberg @ClassicLarry, what do you think of the following for resolving the two record runs?
>
> My submission is pretty "heavy" / cluttered; I was originally focused on a growing list of small optimizations to the NorMuon code until I uncovered the normalization issue, which is an ML update rather than a GPU one, and which ties into @shenberg's improvements.

My preference is to keep the PRs separate by author to maximize individual credit. This one has also already been validated and is ready to merge once the logs get added.

> Even for a (1, 12) gate weight, the direction of the RMSNorm matters. Taking the sum of squares across the 12 values will produce a different scaling factor than treating each component individually (which is what the current implementation does, because of the over-simplified heuristic for choosing the dimension, here).

I agree that the dimension produces a different scaling factor. My point was that if your scaling factor is a single scalar after taking the norm, it looks redundant, because you reverse out that scalar here. You might as well drop NorMuon entirely at that point, because otherwise the code is putting in effort to multiply by s and then divide by s a couple of lines later. I'll check this more definitively by adding print statements before and after NorMuon once the PR is raised.

Edit: I'll add that I'm fine merging any record that improves runtime, but just a call-out that someone may claim another simple 200ms by dropping the extra NorMuon ops.

@shenberg
Contributor Author

> @shenberg Have a look at the submission I'm preparing here; I think it explains / addresses the trouble you were having with pre-multiplication on the output projection.
>
> NorMuon currently has a bug; it's intended to normalize variance along a single dimension, but the attention output heads are stored transposed relative to the QKV heads, so it's been getting misapplied. The smear gate and attention gates have also been affected.
>
> With that fixed, your intended pre-multiplication with the output matrix is giving me a speed up without reducing loss.

Cool stuff there! I was looking at a few of the same things, e.g. why there is a memcpy in the FP8 matrix multiply (turns out the smart YouJiacheng knew it's faster to reorder the matrix in a way that fits well with the tensor cores). Also, regarding the merged optimizer etc., I was thinking of maybe trying to implement Collage, since Adam is essentially bandwidth-bound until compiled autograd works properly (even with an optimal reduce-scatter/all-gather sequence), so we're free to add more GPU-local compute to get better-quality weight updates.

@shenberg
Contributor Author

@varunneal I noticed while looking at your most recent record that you missed one file in the scipy code (3e3c0f33-dc80-4efb-bb21-07af9b5161c9, loss 3.2768) and that you mistyped the first loss (it's actually 3.2777, not 3.2772). Not that it matters, since the omitted file probably more than makes up for the reduced significance.

@chrisjmccormick
Contributor

@ClassicLarry Ah, I see it now. When there's only one row vector, there's no population to normalize over, so it's just a no-op. And as you said earlier, the current implementation is harmful: it treats each position independently and, I think, ends up forcing the smear update to be either -1 or +1 for each weight? Bizarre.
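
A quick sanity check of that (a toy snippet of my own, assuming the per-component treatment divides each weight's update by its own magnitude before any global rescale):

```python
import torch

g = torch.randn(1, 12)
# "RMS" over a size-1 dimension is just the absolute value of each entry,
# so the normalized update collapses to the sign pattern of g.
per_component = g / g.pow(2).mean(dim=0, keepdim=True).sqrt()
print(torch.allclose(per_component, g.sign()))  # True: every entry is -1.0 or +1.0
```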

I went to look at some runs I did a while back where I plotted visualizations of all of the scalars, and even the gate weights, but apparently not the smear gate weights! Now I'm curious to see what those look like.

I think there's still some principled motivation for using Muon on the smear gate over Adam, but it's certainly expensive. I'll try switching it and see if it hurts loss.

I can also test whether skipping the variance normalization for the attention matrices has a negligible effect since they're square; skipping that step would be a small savings.

@chrisjmccormick
Contributor

@shenberg That's a cool idea! Yeah, there's definitely plenty of compute window available for Adam, even in the current implementation:

Profiler trace of combined optimizer implementation

And a side note on the merged implementation: I just noticed that's how the authors implemented it originally, here:
https://github.com/zichongli5/NorMuon/blob/main/normuon.py

@chrisjmccormick
Contributor

I think it's time for me to wrap this up for now, so I will go ahead and raise a pull request (dated today) later this afternoon.
I'll incorporate some insights from this discussion and be sure to credit @shenberg for the output pre-mul improvement.
Thanks!

ClassicLarry merged commit 960ad17 into KellerJordan:master on Dec 11, 2025
@akash5474
Contributor

akash5474 commented Dec 11, 2025

Nice catch on the DistAdam warmup, @shenberg, and congrats on the new record!

@chrisjmccormick You can try removing torch.compile from _sync_gradient, but I believe those reduce-scatters come from NorMuon (L529), which syncs on every step, not from DistAdam. They aren't visually aligned in the profiler trace, but if you click one of the reduce-scatter blocks it will show an arrow to the CPU op that launched the GPU kernel, and that op comes from NorMuon. Alternatively, in the details pane that pops up when you click one of the blocks, you can follow the "slice name" link in the "preceding flows" section to get there.


The alignment threw me off quite a bit when I was working on this. You can also count the number of reduce-scatters to confirm it: in the step where DistAdam syncs, there is another set of reduce-scatters for NorMuon following DistAdam's all-gathers.
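
One rough way to count them, if that's useful (a generic sketch, not the record's code; train_step here is a placeholder for one full training step):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# train_step() stands in for one forward/backward pass plus the optimizer step.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    train_step()

# Count the reduce-scatter / all-gather launches recorded during that step.
for evt in prof.key_averages():
    if "reduce_scatter" in evt.key or "all_gather" in evt.key:
        print(evt.key, evt.count)
```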

I don't think there is any difference in the profiler trace, but I would be happy to be shown otherwise and learn from this. I tried a lot of different variations myself, so I'm really curious.

From what I found, removing torch.compile made no difference, and compiling hooks in the backward pass requires different settings entirely (i.e. enabling compiled autograd). It can probably be safely removed anyway; I didn't remove it because I wasn't sure if it was OK for my code to differ from what was in the logs (and I wanted to avoid spending money rerunning it 😅).

P.S. I really like your profiler trace images!

@shenberg
Contributor Author

> Nice catch on the DistAdam warmup, @shenberg, and congrats on the new record!

Thanks for your posts - I actually used your Perfetto trace early on to look for optimization opportunities. I think the reason compiled autograd behaves poorly with the hooks might be that there's gradient accumulation happening? (I was puzzled for a bit about why torch.backward() was doing more work on odd steps; I think that's the cause.)

shenberg deleted the lambda_weights branch on December 12, 2025, 09:34
@varunneal
Contributor

Can you go over why exactly we need to warm up the validation?

@shenberg
Contributor Author

Frankly, I don't think we have to. I was mainly trying to remove variance between runs: I was getting wacky stalling-step behavior, and one of my suspicions was that TorchInductor was messing with the heap or the CUDA memory pools, so I front-loaded everything to reduce the odds of the validation steps influencing the training run. I don't think it's necessary, but it also probably adds very little to the total runtime, though admittedly it's a worse experience to front-load all of the waiting.

I still don't understand those stalling steps. The behavior I saw with gc.set_debug(gc.DEBUG_STATS) was that every gen-0 GC collection (around once every step or two) would add around 50 objects to the gen-1 pool, which would eventually graduate to gen-2 (so basically leaking memory). I don't really understand how that's possible given how the code looks; I guess it's possible to dump the heap twice and diff the objects. A task for another day.
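
If anyone wants to reproduce the observation, here's a minimal sketch (generic code, not the training script) of watching the GC generation counters between steps:

```python
import gc

gc.set_debug(gc.DEBUG_STATS)   # print per-collection stats to stderr

for step in range(10):
    # ... one training step would go here ...
    gen_counts = gc.get_count()   # live object counters for (gen0, gen1, gen2)
    gen_stats = gc.get_stats()    # cumulative collections/collected per generation
    print(f"step {step}: counts={gen_counts}, gen2 collections={gen_stats[2]['collections']}")
```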
