
Conversation

@varunneal
Contributor

varunneal commented Sep 29, 2025

This submission reflects all recent WR changes up to PR#133.

The main improvement in this PR is using the Polar Express sign method in Muon instead of Newton-Schulz. The method was designed with reference to Modded NanoGPT, so it was very easy to implement; I refer the reader to the paper for details. Using Polar Express, I've reduced the train steps by 10.
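
For context, here is a minimal sketch of the orthogonalization step, assuming a generic coefficient schedule: Newton-Schulz applies one fixed odd-quintic (a, b, c) triple at every iteration, while Polar Express swaps in a different, precomputed triple per iteration. The schedule below is a placeholder (the familiar Newton-Schulz triple repeated); the real values come from the authors' code and are not reproduced here.

import torch

def polar_express_orthogonalize(G: torch.Tensor, coeffs, eps: float = 1e-7) -> torch.Tensor:
    # Approximate the orthogonal (polar) factor of G via the odd-quintic
    # iteration X <- a*X + b*(X X^T) X + c*(X X^T)^2 X, with a per-iteration
    # (a, b, c) schedule rather than a single fixed Newton-Schulz triple.
    X = G.bfloat16()
    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.mT  # work on the wide orientation so X @ X^T stays small
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)  # Frobenius norm bounds the spectral norm
    for a, b, c in coeffs:
        A = X @ X.mT
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.mT if transposed else X

# Placeholder schedule: repeating the quintic Newton-Schulz triple recovers the
# previous behaviour; this PR instead uses 5 per-iteration triples generated
# from the PolarExpress repo.
coeffs = [(3.4445, -4.7750, 2.0315)] * 5
update = polar_express_orthogonalize(torch.randn(1024, 768), coeffs)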

The other change in this PR is packaging Flash Attention 3 via Hugging Face's kernels library. This does not affect timing, but it should make development easier for anyone working on this project.

Timing and Validation

This PR reduces the final training run by 10 steps, with no change in the time per step.

import scipy.stats
import torch

losses = [3.2789, 3.2792, 3.2796, 3.2776, 3.2797, 3.2787, 3.2792]
times = [148.617, 148.580, 148.569, 148.653, 148.578, 148.542, 148.587]

print("p=%.4f" % scipy.stats.ttest_1samp(losses, 3.28, alternative="less").pvalue)
# p=0.0045

print("losses:", torch.std_mean(torch.tensor(losses)))
# losses: (std=0.0007057, mean=3.2789857)

print("time:", torch.std_mean(torch.tensor(times)))
# time: (std=0.0358076, mean=148.5894318)

You may notice that this PR shows a 0.2 second mean increase in timing over the result in PR#133. However, that PR was timed on a very fast machine. To demonstrate that this PR genuinely represents a decrease in train time, I timed PR#133 on the same machine as above:

import torch

times = [149.714, 149.676, 149.659, 149.716, 149.649, 149.569, 149.521]

print("time:", torch.std_mean(torch.tensor(times)))
# time: (std=0.0732, mean=149.6434)

Therefore, I believe that this PR represents around a 1 second improvement.
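
As a rough cross-check of that estimate, here is a two-sample comparison of the two timing runs above (same numbers as printed earlier; nothing new is assumed beyond them):

import scipy.stats
import torch

times_new = [148.617, 148.580, 148.569, 148.653, 148.578, 148.542, 148.587]  # this PR
times_old = [149.714, 149.676, 149.659, 149.716, 149.649, 149.569, 149.521]  # PR#133, same machine

# mean difference is roughly 1.05 seconds in favour of this PR
print("mean diff: %.3f s" % (torch.tensor(times_old).mean() - torch.tensor(times_new).mean()))

# Welch's t-test; with these samples the p-value comes out far below 0.001
print("p=%.2e" % scipy.stats.ttest_ind(times_new, times_old, equal_var=False, alternative="less").pvalue)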

Thank you to Prime Intellect for compute credits, which made this PR possible.

Polar Express

All credit to the original authors (Noah Amsel, David Persson, Christopher Musco, Robert M. Gower) for discovery and implementation of this method. I adapted their code from https://github.com/NoahAmsel/PolarExpress/tree/main.

I found the following parameters to work best:

  • num_iters=5: each iteration adds about a second to train time
  • muon_lr=0.06: bumping the Muon LR seems to perform slightly better
  • safety_factor=1.02: hyperparameter for Polar Express coefficients

Despite the paper explicitly referencing and showing improvements on Modded NanoGPT, I was unable to replicate the level of success it reports. However, it may be possible to tune the parameters further to achieve a better result.

Additionally, like Cesista 2025, I believe it may be more promising on the GPT Medium track.

Flash Attention 3 Huggingface Kernel

A couple of weeks ago, the Flash Attention maintainers merged ABI stability into the main FA3 repo. This allows builds of Flash Attention on PyTorch nightlies after 08/30 to be compatible with each other. Since PR#118, we have been using a variant of FA3 by @guilhermeleobas that is compatible with torch.compile. I have written a fork that combines these changes and uploaded its build to Hugging Face at https://huggingface.co/varunneal/flash-attention-3.

I have modified the training script to fetch these builds via Hugging Face's get_kernel. It is therefore no longer necessary to build Flash Attention manually.
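
A minimal sketch of the fetch path, assuming the kernels package is installed and that the repo id matches the Hugging Face upload linked above:

from kernels import get_kernel

# Downloads (and caches) the prebuilt FA3 kernel matching the local torch and
# CUDA versions, so no local compilation of Flash Attention is needed.
flash_attn = get_kernel("varunneal/flash-attention-3")

# The returned module exposes the packaged FA3 entry points (exact names depend
# on the build), which the training script uses in place of a locally built extension.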

I have packaged this kernel for both CUDA 12.6 and 12.8 for the following PyTorch versions:

  • 2.8.0
  • 2.9 nightlies after 8/30
  • 2.10 nightlies

Note that the actual built .so file is identical across all Torch nightly versions.

This most recent record uses the same 2.10 nightly as PR#133.

@Gusarich
Contributor

Incredible! Did you try other combinations of parameters?

@varunneal
Contributor Author

varunneal commented Sep 29, 2025

@Gusarich which parameters? E.g. which ones to put onto Muon, or the hyperparameters for Polar Express? I tried a bit for the latter -- the changes from the original paper that I found were:

  • a safety_factor of 1.02 seemed better than 1.01
  • Taking 5 iterations was better than 6 iterations -- convergence was similar but step time was lower for the former
  • A Muon learning rate of 0.06 was slightly better than the previous LR of 0.05. The LR may be tunable further under Polar Express

@sozforex

Just in case, to obtain approximately the same coefficients as in v3 of the Polar Express paper, one can call this function https://github.com/NoahAmsel/PolarExpress/blob/5454910920ca8c65afda28820cdf9e49b9436ed0/polar_express.py#L41 with these parameters:

optimal_composition(l=1e-3, num_iters=8, safety_factor_eps=0, cushion=0.02407327424182761)
