
Conversation

@thib-s commented Nov 11, 2025

Description:

Building atop #154, this PR introduces Turbo-Muon, a preconditioned version of the Newton–Schulz orthogonalization used in Muon. The method applies an Almost-Orthogonal Layer (AOL) preconditioning step that improves the initial approximation of the polar factor, enabling faster convergence and the removal of one NS iteration without loss of numerical precision (as described here and here).

Timing and Validation

Unfortunately, I do not have access to nodes with 8x H100 GPUs, so my verification was done on nodes with 4x H100 GPUs.
I then scaled the runtimes to get a coarse estimate of the runtime on 8xH100 (likely too optimistic). I've added my measurements for PR154 on similar hardware (4xH100) for reference.

import scipy.stats
import torch

losses = [3.2779, 3.2782, 3.2782, 3.2788, 3.2814, 3.2807, 3.2780, 3.2780, 3.2795, 3.2802]
# I adjusted times since I have only 4xH100 instead of 8;
# times should not be compared directly, and more modest improvements should be expected
times = [266.623/2, 266.600/2, 266.504/2, 266.707/2, 266.779/2, 266.627/2, 266.532/2, 266.682/2, 266.816/2, 266.726/2]

print("p=%.4f" % scipy.stats.ttest_1samp(losses, 3.28, alternative="less").pvalue)
# p=0.0018

print("losses:", torch.std_mean(torch.tensor(losses)))
# losses: (std=0.0013, mean=3.2791)
# results of PR154: losses: (std=0.0014, mean=3.2774)


print("time:", torch.std_mean(torch.tensor(times)))
# time: (std=0.0506, mean=133.3298)
# results of PR154: time: (std=0.0734, mean=136.8784)

Two important notes:

  1. The p-value seems correct, but I've noted a very slight increase in mean loss (+0.002).
  2. My experiments ran on a Slurm cluster, and I've observed runtime variations of around 1.5s depending on the allocated node. I'm confident this PR improves runtime, but I do think the gain is overestimated (I expect something closer to 1.5s).

Changes

For this implementation we started from the dion implementation of Newton-Schulz, which has a great Triton implementation of the algorithm.

Triton kernel for ns_line_3:

We noticed that the ns_line_3 function was taking a lot of time, so we wrote a Triton kernel that avoids loading the same data multiple times. This gives a marginal speedup on small matrices, where loading data is the bottleneck.
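For context, one Newton-Schulz iteration of this form looks like the sketch below (plain PyTorch, not the Triton kernel). The ns_line_* labels are my reading of the dion naming: ns_line_1 computes the Gram matrix (consistent with its use for the AOL factors below), while labeling the final update as ns_line_3 is an assumption.

import torch

def ns_iteration(X, a, b, c):
    # One quintic Newton-Schulz step (minimal sketch)
    A = X @ X.mT             # ns_line_1: Gram matrix
    B = b * A + c * (A @ A)  # ns_line_2: polynomial in A
    return a * X + B @ X     # ns_line_3: mostly memory-bound on small
                             # matrices, which the fused kernel targets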

Fewer iterations:

We remove the previous normalization and switch to AOL rescaling, which is further explained in the paper: https://arxiv.org/pdf/2208.03160

This consists of computing W@W^T using ns_line_1 and then computing the scaling factors fast_inv_sqrt(reduce_sum(abs(WW^T), axis=-1)), which form a vector.
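For reference, the AOL bound from that paper guarantees that with $d_i = \big(\sum_j |(WW^\top)_{ij}|\big)^{-1/2}$ and $D = \mathrm{diag}(d)$, the rescaled matrix satisfies $\|DW\|_2 \le 1$, so all singular values lie in $(0, 1]$ before the first NS iterate, which is the regime the NS polynomials are designed for.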

[Figure: polar_error_filtered_pure_AOL]

Since the main operation needed to compute these factors corresponds to ns_line_1, we can fuse it with the first Newton-Schulz iterate. Furthermore, this gives a better starting point for the Newton-Schulz iterations, as the matrix is closer to orthogonal.
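Putting this together, here is a minimal sketch of the scheme in plain PyTorch (illustrative only, using the ns_line_* reading above; the PR implements this with Triton kernels, and coeffs is the usual (a, b, c) schedule with its first tuple dropped, as explained below):

def turbo_orthogonalize(X, coeffs):
    A = X @ X.mT                                 # ns_line_1, computed once
    d = A.abs().sum(dim=-1).clamp_min(1e-12).rsqrt()
    X = d[:, None] * X                           # AOL rescaling
    A = d[:, None] * A * d[None, :]              # (DX)(DX)^T = DAD: reuses A,
                                                 # fusing AOL with iterate 1
    for i, (a, b, c) in enumerate(coeffs):
        if i > 0:
            A = X @ X.mT                         # ns_line_1
        B = b * A + c * (A @ A)                  # ns_line_2
        X = a * X + B @ X                        # ns_line_3
    return X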

Thanks to this, we can save one iteration of Newton-Schulz:

[Figure: polar_error_filtered_annotated] (Muon+ refers to an implementation similar to the one used in this repository)

However, the usual polynomial coefficients are no longer optimal, since the preconditioning changes the spectrum of the matrix: it permits more aggressive coefficients, which would be suboptimal without preconditioning. The simplest way to obtain such coefficients is to simply drop the first iteration of the usual Newton-Schulz coefficients.
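Concretely (illustrative only; ns5_coeffs stands in for whichever 5-iteration (a, b, c) schedule the baseline uses):

coeffs = ns5_coeffs[1:]  # the AOL step stands in for the dropped first iteration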

Finally, this can be applied out of the box in the medium track, where more significant runtime improvements are to be expected.

thib-s marked this pull request as draft November 11, 2025 19:31
@varunneal (Contributor)

This is terrific! Though I am surprised you opted to remove the first set of coefficients instead of the last set. I'd imagine that removing the last iteration's coefficients is much more accurate for convergence. For reference, here are the polar_express coefficients for num_iters=4:

[
  (8.156554524902461, -22.48329292557795, 15.878769915207462),
  (4.042929935166739, -2.808917465908714, 0.5000178451051316),
  (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
  (3.351468730910768, -2.5130779633710483, 0.5128347598303034)
]

By the way, you might get further improvement by tuning the learning rate (and decreasing total iterations correspondingly). This result on the relationship between the strength of the preconditioning and the learning rate seems to be very important.

@thib-s (Author) commented Nov 12, 2025

That surprised me too! In a sense, preconditioning does a similar job to the first Newton-Schulz iteration:

  1. I initially tried to recompute factors for the polar express, but I could not find any clever way to incorporate the knowledge about my matrix: notably, I measured maximum singular values much closer to 1 (around 0.75) and minimum singular values further from 0 (around 0.02-0.05).
  2. I tried retuning with $l=2\times 10^{-2}$ (matching the measured minimum singular values) instead of the default value, but could not measure significant improvements over five runs, so I stuck with the truncation of parameters, which was surprisingly effective. (Maybe playing with the cushion and safety factor could help?)

filename         val_loss_mean  val_loss_std
train_gpt_PR149  3.27744        0.001555
coef trunc       3.27996        0.001880
polar retuning   3.27959        0.001620

(I was planning to update #149 but saw that #154 was more recent, so these early experiments were done against PR149.)

Interesting link! At first I removed one iteration in order to match a similar polar error and avoid re-tuning hyperparameters, but it might also be feasible to keep the fifth iteration and adjust the lr (which might save a few training steps instead of making each step faster).

[Figure] (Muon+ refers to Newton-Schulz with optimized polynomials)

thib-s marked this pull request as ready for review November 12, 2025 15:39
thib-s changed the title from "Preconditioned orthogonalization for faster Muon optimizer" to "New WR: Preconditioned orthogonalization for faster Muon optimizer" Nov 12, 2025
@varunneal (Contributor) commented Nov 12, 2025

@thib-s Do you have a plot of the polar error when removing the first versus the last iteration? It makes sense that preconditioning is doing the work of the first NS iteration step, but my understanding was that orthogonalization is independent wrt the conditioning of the underlying input (though I'm in over my head here a bit). Never mind, it's definitely not independent; this is sort of clear if you consider why Polar Express has different coeffs every iteration:

Moreover, after each step of the algorithm, the range of the singular values changes; therefore, we adapt the update rule at each iteration to match the new interval. When the range of the singular values is large, this approach ensures that the update rule shrinks it as quickly as possible. As the algorithm proceeds and the interval shrinks to a small neighborhood of 1, the update rule approaches that of a Pade method, maintaining the same high order of convergence as it has.

It makes sense why removing the first set of coeffs is better once we apply the AOL preconditioning.

@ClassicLarry (Collaborator)

Planning to get to this one after cautious weight decay and #149. The loss increase probably corresponds to >1s on its own, so depending on how the communication bottleneck plays out on 8 GPUs, I think there's a chance that only the Triton kernel for ns_line_3 will be worthwhile on the small track. But I've never run on 4xH100, so I don't have much intuition on how it translates.

tyler-romero added a commit to tyler-romero/modded-nanogpt that referenced this pull request Nov 13, 2025
@ClassicLarry (Collaborator)

I am getting an error on:
p.copy_(unstacked_params[i], non_blocking=True)
[rank1]: torch.AcceleratorError: CUDA error: an illegal memory access was encountered
Somewhat oddly, the error only shows up somewhere between steps 3 and 18.

There's also an issue from the extra indent on the first line.

PrimeIntellect often has 8H100 spot instances for $9/hr, which may be helpful for identifying the issue.

@thib-s (Author) commented Nov 17, 2025

I'm surprised to see this error in an apparently unrelated part of the code (multiple operations are applied to v_chunk between the call to the added Triton kernel at L773 and the error raised at L825).
Maybe this issue comes from a difference in torch.compile behavior? I made a back-to-back comparison with the logs in PR154 and noted slight version changes:

Logs in PR155:

Running Python 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:50:58) [GCC 12.3.0]
Running PyTorch 2.10.0.dev20251019+cu126 compiled for CUDA 12.6
Running Triton version 3.5.0
Wed Nov 12 09:47:44 2025     

Logs in PR154:

Running Python 3.10.12 (main, Feb  4 2025, 14:57:36) [GCC 11.4.0]
Running PyTorch 2.10.0.dev20250926+cu126 compiled for CUDA 12.6
Running Triton version 3.5.0
Mon Nov 10 21:53:41 2025   

PrimeIntellect often has 8H100 spot instances for $9/hr, which may be helpful for identifying the issue.

I'll take a look at those, and also check whether the increase from 3.2774 to 3.2791 also occurs on 8xH100 hardware.

@ClassicLarry (Collaborator)

I was testing on torch nightly 926. I have seen odd behavior before where, if the data exceeds the range of the data type, it bugs out only on the line that materializes the output of the torch call. Which makes me think that the Triton kernel is computing a gradient that is numerically out of range for the dtype, but I haven't actually looked at the kernel.

@ClassicLarry (Collaborator)

I did a bit of testing, and it looks like the current record has so much cushion in the loss below 3.28 that it's giving a misleading impression of what counts as an improvement. A lot of changes give faster run times while staying under 3.28, like dropping 20 steps (only checked 4 runs, so I can't 100% confirm). So I am in favor of reducing this cushion as long as the run count for the p-value stays around 10 or so. You can consume the cushion by dropping an NS iteration, reducing the step count, or in many other ways. No restrictions, but there is a possibility that a future PR will find a more efficient way of consuming that cushion.

@ClassicLarry (Collaborator) commented Nov 17, 2025

A lot of changes give faster run times while staying under 3.28, like dropping 20 steps (only checked 4 runs, so I can't 100% confirm).

After discussing with Varun, I think there's a chance my couple of tests here were atypically lucky. So take this with a grain of salt.

ClassicLarry marked this pull request as draft December 9, 2025 01:18
@ClassicLarry (Collaborator)

Marking as draft until we have a successful run on 8H100, to reduce confusion when people are trying to build on top of the latest record.

@thib-s (Author) commented Dec 10, 2025

Agreed, I'll take this opportunity to rebase the code onto the latest WR.

@ClassicLarry (Collaborator)

Agreed, I'll take this opportunity to rebase the code onto the latest WR.

Cool, there is an open PR that you may want to build on top of as well, which will get merged in first if its timing holds.

@sozforex

@thib-s This may be of some use to you for obtaining possibly better Polar Express coefficients for fewer iterations:

To obtain approximately the same coefficients as in v3 of the Polar Express paper, one can call this function https://github.com/NoahAmsel/PolarExpress/blob/5454910920ca8c65afda28820cdf9e49b9436ed0/polar_express.py#L41 with these parameters:

optimal_composition(l=1e-3, num_iters=8, safety_factor_eps=0, cushion=0.02407327424182761)

I'm not sure if different values for safety_factor_eps and cushion would be better.
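For example (untested sketch; I'm assuming polar_express.py from that repo is importable and that optimal_composition returns the list of (a, b, c) tuples; the l=2e-2 and num_iters=4 values just reflect the post-AOL settings discussed above):

from polar_express import optimal_composition

# l set near the post-AOL minimum singular value measured above (~2e-2),
# with one fewer iteration than the 5-step baseline schedule
coeffs = optimal_composition(l=2e-2, num_iters=4,
                             safety_factor_eps=0, cushion=0.02407327424182761)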
