New WR: Preconditioned orthogonalization for faster Muon optimizer #155
Conversation
This is terrific! Though I am surprised you opted to remove the first set of coefficients instead of the last set. I'd imagine that removing the last iteration's coefficients is much more accurate for convergence. For reference, here are the polar_express coefficients for
By the way, you might get further improvement when tuning the learning rate (and decreasing total iterations correspondingly). This relationship between the strength of the preconditioning and the learning rate seems to be very important.
That surprised me too! In a sense, preconditioning does a similar job to the first Newton–Schulz iteration. I initially tried to recompute coefficients from Polar Express, but I could not find any clever way to incorporate the knowledge about my matrix: notably, I measured maximum singular values much closer to 1 (around 0.75) and minimum singular values further from 0 (around 0.02-0.05).
(I was planning to update #149 and saw that #154 was more recent, so those early experiments were done against PR 149.) Interesting link! At first I removed one iteration in order to match a similar polar error and avoid re-tuning hyperparameters, but it might also be feasible to keep the fifth iteration and adjust the lr (which might save a few training steps instead of making faster steps).
(Muon+ refers to Newton–Schulz with optimized polynomials.)
@thib-s Do you have a plot of the polar error when removing the first versus the last iteration? It makes sense that preconditioning is doing the work of the first NS iteration, but
It makes sense why removing the first set of coefficients is better once we apply the AOL preconditioning.
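(For reference, one way such a polar error could be measured offline is to compare the approximate orthogonalization against an SVD-based polar factor. The helper below is an illustrative sketch, not code from this PR or the repo.)

```python
import torch

def rel_polar_error(G: torch.Tensor, X: torch.Tensor) -> float:
    # Exact polar factor via SVD: G = U S V^T  =>  polar(G) = U V^T.
    U, _, Vh = torch.linalg.svd(G.float(), full_matrices=False)
    polar = U @ Vh
    # Relative Frobenius distance between the approximation X and polar(G).
    return ((X.float() - polar).norm() / polar.norm()).item()
```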
Planning to get to this one after cautious weight decay and #149. The loss increase probably corresponds to >1s on its own, so depending on how the communication bottleneck plays out on 8 GPUs, I think there's a chance that only the Triton kernel for ns_line_3 will be worthwhile on the small track. But I've never run on 4xH100, so I don't have much intuition on how it translates.
I am getting an error on:
I'm also getting an error from the extra indent on the first line. PrimeIntellect often has 8xH100 spot instances for $9/hr, which may be helpful for identifying the issue.
I'm surprised to see this error located in an apparently unrelated part of the code (multiple operations are applied to v_chunk between the call to the added Triton kernel at L773 and the error raised at L825).
Logs in PR155:
Logs in PR154:
I'll take a look at those, and also check whether the increase from 3.2774 to 3.2791 also occurs on 8xH100 hardware.
I was testing on torch nightly 926. I have seen odd behavior before where, if the data exceeds the range of the data type, it bugs out only on the line that materializes the output in the torch call. This makes me think the Triton kernel is computing a gradient that is numerically incompatible, but I haven't actually looked at the kernel.
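(One quick way to test that hypothesis, sketched here with the placeholder name triton_ns_line_3 for the kernel under test and assuming ns_line_3 is the X ← aX + BX update of the quintic Newton–Schulz iteration, is to compare the kernel output against the plain PyTorch expression and check for non-finite values. This is an illustrative check, not code from this PR.)

```python
import torch

def check_ns_line_3(triton_ns_line_3, X: torch.Tensor, B: torch.Tensor, a: float):
    # Reference computation of the assumed ns_line_3 update in plain PyTorch.
    ref = a * X + B @ X
    out = triton_ns_line_3(X, B, a)   # kernel under test (placeholder name)
    assert torch.isfinite(out.float()).all(), "kernel output contains inf/nan"
    max_err = (out.float() - ref.float()).abs().max().item()
    print(f"max abs deviation from the PyTorch reference: {max_err:.3e}")
```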
I did a bit of testing, and it looks like the current record has so much cushion in the loss below 3.28 that it's giving a misleading impression of what counts as an improvement. A lot of changes give faster run times while staying under 3.28, like dropping 20 steps (I only checked 4 runs, so I can't 100% confirm). So I am in favor of reducing this cushion as long as the run count for the p-value stays around 10 or so. You can consume the cushion by dropping an NS iteration, reducing the step count, or in many other ways. No restrictions, but there is a possibility that a future PR will find a more efficient way of consuming that cushion.
After discussing with Varun, I think there's a chance my couple of tests here were atypically lucky. So take this with a grain of salt.
Marking as draft until we have a successful run on 8xH100, to reduce confusion when people are trying to build on top of the latest record.
Agreed, I'll take this opportunity to rebase the code onto the latest WR.
Cool, there is an open PR that you may want to build on top of as well, which will get merged in first if its timing holds.
@thib-s This may be of some use to you for obtaining possibly better Polar Express coefficients for fewer iterations: to obtain approximately the same coefficients as in v3 of the Polar Express paper, one can call this function https://github.com/NoahAmsel/PolarExpress/blob/5454910920ca8c65afda28820cdf9e49b9436ed0/polar_express.py#L41 with these parameters:
I'm not sure if different values for

Description:
Building atop #154, this PR introduces Turbo-Muon, a preconditioned version of the Newton–Schulz orthogonalization used in Muon. The method applies an Almost-Orthogonal Layer (AOL) preconditioning step that improves the initial approximation of the polar factor, enabling faster convergence and the removal of one NS iteration without loss of numerical precision (as described here and here).
Timing and Validation
Unfortunately, I do not have access to nodes with 8x H100 GPUs, so my verification was done on nodes with 4x H100 GPUs.
I then adjusted the runtimes to get a coarse estimate of the runtime on 8xH100 (likely too optimistic). I've added my measurements for PR154 on similar hardware (4xH100) for reference.
Two important notes:
Changes
For this implementation, we started from the Dion implementation of Newton–Schulz, which has a great Triton implementation of the Newton–Schulz algorithm.
Triton kernel for ns_line_3:
We noticed that the ns_line_3 function was taking a lot of time, so we wrote a Triton kernel to avoid loading the same data multiple times. This gives a marginal speedup on small matrices, where loading data is the bottleneck.
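The kernel itself is not reproduced here. As a rough reference for what it computes (assuming ns_line_3 is the X ← aX + BX update of the quintic Newton–Schulz iteration), the unfused operation can be written as a single addmm call; the Triton kernel additionally avoids re-reading the tiles of X from global memory:

```python
import torch

def ns_line_3_reference(X: torch.Tensor, B: torch.Tensor, a: float) -> torch.Tensor:
    # X <- a * X + B @ X in one call: torch.addmm(input, mat1, mat2, beta=, alpha=)
    # computes beta * input + alpha * (mat1 @ mat2). 2-D case shown; a batched
    # variant would use torch.baddbmm.
    return torch.addmm(X, B, X, beta=a, alpha=1.0)
```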
Fewer iterations:
We remove the previous normalization and switch to AOL rescaling, which is explained further in the paper https://arxiv.org/pdf/2208.03160.
This consists of computing W@W^t (using ns_line_1) and then computing the scaling factors fast_inv_sqrt(reduce_sum(abs(WW^t), axis=-1)), which form a vector.
Since the main operation needed to compute those factors corresponds to ns_line_1, we can fuse it with the first Newton–Schulz iteration. Furthermore, this gives a better starting point for the Newton–Schulz iterations, as the rescaled matrix is closer to orthogonal.
Thanks to this, we can save one iteration of Newton–Schulz.
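As a minimal PyTorch sketch of the rescaling just described (illustrative only; the actual PR fuses this Gram product with the first Newton–Schulz iteration inside the Triton code):

```python
import torch

def aol_rescale(W: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    # AOL rescaling: divide row i of W by sqrt(sum_j |(W W^T)_{ij}|).
    WWt = W @ W.mT                                    # same product as ns_line_1
    scale = torch.rsqrt(WWt.abs().sum(dim=-1) + eps)  # fast_inv_sqrt(reduce_sum(abs(WW^t), -1))
    return W * scale.unsqueeze(-1)                    # rescaled W has spectral norm <= 1
```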
However, the usual polynomial coefficients are no longer optimal, since the preconditioning changes the spectrum of the matrix: it allows for more aggressive coefficients, which would be suboptimal without preconditioning. The simplest way to obtain such coefficients is to simply drop the first iteration of the usual Newton–Schulz coefficients.
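Putting the pieces together, a hedged, unfused end-to-end sketch is shown below. The (a, b, c) triples are left as an argument since the exact schedule is what the PR tunes, and the line comments assume the three ns_line_* helpers correspond to the three steps of the standard quintic Newton–Schulz iteration:

```python
import torch

def turbo_orthogonalize(G: torch.Tensor, ns_coeffs, eps: float = 1e-7) -> torch.Tensor:
    # ns_coeffs: list of (a, b, c) triples, e.g. the usual schedule with its
    # first triple dropped, since the AOL rescale plays the role of that iteration.
    X = G.bfloat16()
    WWt = X @ X.mT                                               # Gram matrix (ns_line_1)
    X = X * torch.rsqrt(WWt.abs().sum(-1, keepdim=True) + eps)   # AOL rescale
    for a, b, c in ns_coeffs:
        A = X @ X.mT               # ns_line_1
        B = b * A + c * (A @ A)    # ns_line_2
        X = a * X + B @ X          # ns_line_3
    return X
```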
Finally, this can be applied out of the box on the medium track, where more significant runtime improvements are to be expected.