New WR: Polar Express (-10 steps) + packaged Flash Attention quality-of-life improvement #134
This submission reflects all recent WR changes up to PR#133.
The main improvement in this PR is using the Polar Express sign method in Muon in place of Newton-Schulz. The paper was written with direct reference to ModdedNanoGPT, so it was very easy to implement, and I refer the reader to the paper itself for details. Using Polar Express, I've reduced the train steps by 10.
The other change in this PR is packaging Flash Attention 3 via Huggingface's Kernels library. This does not affect timing, but it should make development easier for anyone working on this project.
Timing and Validation
This PR reduces the number of training steps by 10, with no change in the time per step.
You may notice that this PR shows a 0.2 second mean increase in timing over the result in PR#133. However, that PR was timed on a very fast machine. To demonstrate that this PR accurately represents a decrease in train time, I timed PR#133 on the same machine as above:
Therefore, I believe that this PR represents around a 1 second improvement.
Thank you to Prime Intellect for compute credits, which made this PR possible.
Polar Express
All credit to the original authors (Noah Amsel, David Persson, Christopher Musco, Robert M. Gower) for discovery and implementation of this method. I adapted their code from https://github.com/NoahAmsel/PolarExpress/tree/main.
I found optimal parameters with:

- `num_iters=5`: each iteration adds about a second to train time
- `muon_lr=0.06`: I found bumping the Muon LR seems to perform slightly better
- `safety_factor=1.02`: hyperparameter for the Polar Express coefficients

Despite the paper explicitly referencing and showing improvements on Modded NanoGPT, I was unable to replicate the level of success reported there. However, it may be possible to further tune parameters to achieve a better result.
Additionally, like Cesista 2025, I believe this method may be more promising on the GPT Medium track.
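To make the swap concrete, here is a minimal sketch of the kind of per-iteration polynomial schedule Polar Express substitutes for Muon's fixed Newton-Schulz quintic. The coefficient tuples and the way `safety_factor` is folded in below are placeholders, not the published values; consult the paper and the linked repository for the real ones.

```python
import torch

# Minimal sketch (not this PR's actual code) of a Polar Express style
# orthogonalization replacing Newton-Schulz inside Muon. The coefficients
# are PLACEHOLDERS -- the real per-iteration values come from the Polar
# Express paper / NoahAmsel/PolarExpress repo -- and dividing them by
# `safety_factor` is a simplification of how the authors apply their margin.
PLACEHOLDER_COEFFS = [(3.0, -3.0, 1.0)] * 5  # num_iters=5, as used in this PR


@torch.no_grad()
def polar_express(G: torch.Tensor, coeffs=PLACEHOLDER_COEFFS,
                  safety_factor: float = 1.02) -> torch.Tensor:
    """Approximate the polar factor of G (the 'matrix sign' step in Muon)
    with a fixed schedule of odd quintic maps X <- a*X + (b*A + c*A@A) @ X,
    where A = X @ X^T and (a, b, c) changes every iteration."""
    assert G.ndim == 2
    X = G.bfloat16()
    transposed = X.size(0) > X.size(1)
    if transposed:              # work in the wide orientation, as Muon does
        X = X.T
    X = X / (X.norm() + 1e-7)   # bring all singular values into [0, 1]
    for a, b, c in coeffs:
        a, b, c = a / safety_factor, b / safety_factor, c / safety_factor
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X
```

In Muon this step is applied to each 2-D momentum matrix right before the weight update, i.e. where the Newton-Schulz call sits today; only the coefficient schedule changes.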
Flash Attention 3 Huggingface Kernel
A couple of weeks ago, Flash Attention merged ABI stability into the main FA3 repo. This allows builds of Flash Attention on PyTorch nightlies after 08/30 to be compatible with each other. Since PR#118, we have been using a variant of FA3 by @guilhermeleobas that is compatible with `torch.compile`. I have written a fork that combines these changes and uploaded its build to Huggingface at https://huggingface.co/varunneal/flash-attention-3. I have modified the training script to fetch these builds via Huggingface's `get_kernel`. Therefore, it is no longer required to manually build Flash Attention.

I have packaged this kernel for both CUDA 12.6 and 12.8 for the following PyTorch versions:
- 2.8.0
- 2.9 nightlies after 8/30
- 2.10 nightlies

Note that the actual build `.so` file is identical for all Torch Nightly versions. This most recent record uses the same 2.10 nightly as PR#133.
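For anyone updating their own scripts, the new code path looks roughly like the following, assuming the `kernels` package is installed (`pip install kernels`). The `flash_attn_func` attribute name and its return convention are assumptions based on the usual FA3 interface rather than something verified against the uploaded build, so check the kernel repo if they differ.

```python
import torch
from kernels import get_kernel  # Huggingface's `kernels` package

# Fetch the prebuilt FA3 binary from the Hub instead of compiling it locally.
# The repo id matches the upload linked above.
flash_attn3 = get_kernel("varunneal/flash-attention-3")

# Dummy (batch, seq, heads, head_dim) inputs, bf16 on GPU.
q, k, v = (torch.randn(1, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")
           for _ in range(3))

result = flash_attn3.flash_attn_func(q, k, v, causal=True)
out = result[0] if isinstance(result, tuple) else result  # some FA3 builds also return the LSE
```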