
Commit 3f43c27

Merge pull request #134 from varunneal/polar-express
New WR: Polar express (-10 steps) + packaged Flash Attention quality of life improvement
2 parents fab427a + 6504fee commit 3f43c27

9 files changed: +22888 -24 lines

records/092925_PolarExpress/0e3f0af5-ad08-47a6-813d-0c709b50d422.txt: 3252 additions, 0 deletions
records/092925_PolarExpress/16ae9716-24a6-4b5f-ad2e-ce0986903334.txt: 3252 additions, 0 deletions
records/092925_PolarExpress/188c5c21-a850-4b45-ab17-d168a5bec7e7.txt: 3252 additions, 0 deletions
records/092925_PolarExpress/3bb6c2eb-1935-46d5-9f07-40b98223cfaa.txt: 3252 additions, 0 deletions
records/092925_PolarExpress/730671d8-2fca-498a-819a-0bdf0f3aa76c.txt: 3252 additions, 0 deletions

Lines changed: 85 additions & 0 deletions

# New record 09/29/25

This submission reflects all recent WR changes up to [PR#133](https://github.com/KellerJordan/modded-nanogpt/pull/133).

The main improvement in this PR is using the [Polar Express](https://arxiv.org/pdf/2505.16932)
sign method in Muon instead of Newton-Schulz. The paper was designed with reference to ModdedNanoGPT, so it was very easy to implement,
and I refer the reader to the paper itself for details. Using Polar Express, I've reduced the train steps by 10.

The other change in this PR is packaging Flash Attention 3 via [Huggingface's Kernels](https://huggingface.co/docs/kernels/en/index).
This does not affect timing, but it should make development easier for anyone working on this project.

## Timing and Validation

This PR reduces the final training run by 10 steps, with no change in the time per step.

```
import scipy.stats
import torch

losses = [3.2789, 3.2792, 3.2796, 3.2776, 3.2797, 3.2787, 3.2792]
times = [148.617, 148.580, 148.569, 148.653, 148.578, 148.542, 148.587]

print("p=%.4f" % scipy.stats.ttest_1samp(losses, 3.28, alternative="less").pvalue)
# p=0.0045

print("losses:", torch.std_mean(torch.tensor(losses)))
# losses: (std=0.0007057, mean=3.2789857)

print("time:", torch.std_mean(torch.tensor(times)))
# time: (std=0.0358076, mean=148.5894318)
```

You may notice that this PR shows a 0.2 second mean *increase* in timing over the result in PR#133.
However, that PR was timed on a very fast machine. To demonstrate that this PR genuinely represents
a decrease in train time, I timed PR#133 on the same machine as above:

```
import scipy.stats
import torch

times = [149.714, 149.676, 149.659, 149.716, 149.649, 149.569, 149.521]

print("time:", torch.std_mean(torch.tensor(times)))
# time: (std=0.0732, mean=149.6434)
```
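
As a quick sanity check, the gap between the two runs can be computed directly from the arrays above (this is just arithmetic on the timings already shown, not a separate measurement):

```
new_times = [148.617, 148.580, 148.569, 148.653, 148.578, 148.542, 148.587]  # this PR
old_times = [149.714, 149.676, 149.659, 149.716, 149.649, 149.569, 149.521]  # PR#133, same machine

print("mean speedup: %.3f s" % (sum(old_times) / len(old_times) - sum(new_times) / len(new_times)))
# mean speedup: 1.054 s
```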

Therefore, I believe that this PR represents at least a 1 second improvement.

Thank you to Prime Intellect for compute credits, which made this PR possible.

## Polar Express

All credit to the original authors (Noah Amsel, David Persson, Christopher Musco, Robert M. Gower)
for the discovery and implementation of this method. I adapted their code from https://github.com/NoahAmsel/PolarExpress/tree/main.

I found the following parameters to work best (a rough sketch of the orthogonalization step follows this list):
- `num_iters=5`: each additional iteration adds about a second to train time
- `muon_lr=0.06`: bumping the Muon LR seems to perform slightly better
- `safety_factor=1.02`: a hyperparameter for the Polar Express coefficients

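For a concrete picture, here is a minimal sketch of the orthogonalization step that Polar Express swaps in for Newton-Schulz inside Muon. The iteration structure (an odd quintic matrix polynomial applied `num_iters` times) mirrors the existing Muon code; the coefficient schedule below is only a stand-in (the old fixed Newton-Schulz values), since the actual per-iteration Polar Express coefficients, and the way `safety_factor` and `muon_lr` enter, come from the authors' code linked above.

```
import torch

# Stand-in coefficient schedule: these are the fixed Newton-Schulz quintic values
# Muon used previously, repeated num_iters=5 times. The real Polar Express schedule
# uses a different (a, b, c) triple per iteration, taken from the authors' code.
COEFFS = [(3.4445, -4.7750, 2.0315)] * 5

def orthogonalize(G: torch.Tensor, coeffs=COEFFS) -> torch.Tensor:
    # Approximate the orthogonal polar factor of G via odd quintic iterations.
    X = G.bfloat16()
    transposed = X.size(-2) > X.size(-1)
    if transposed:  # iterate on the wide orientation, as in Muon
        X = X.mT
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)  # bring singular values into range
    for a, b, c in coeffs:
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X
    if transposed:
        X = X.mT
    return X.type_as(G)

# Example: orthogonalize a gradient-shaped matrix before the Muon weight update.
g = torch.randn(768, 3072)
print(orthogonalize(g).shape)  # torch.Size([768, 3072])
```
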
Despite the paper explicitly referencing and showing improvements on Modded NanoGPT,
I was unable to replicate the level of success reported there. However, it may
be possible to tune the parameters further and achieve a better result.
Additionally, like [Cesista 2025](https://leloykun.github.io/ponder/muon-opt-coeffs/), I believe this approach may be more promising on the GPT Medium track.

## Flash Attention 3 Huggingface Kernel

A couple of weeks ago, Flash Attention merged [ABI stability](https://github.com/Dao-AILab/flash-attention/pull/1791)
into the main FA3 repo. This allows builds of Flash Attention against PyTorch nightlies after 08/30 to be compatible with each other.
Since [PR#118](https://github.com/KellerJordan/modded-nanogpt/pull/118), we have been using
[a variant](https://github.com/Dao-AILab/flash-attention/pull/1769) of FA3 by @Guilhermeleobas that is compatible with `torch.compile`.
I have written a [fork](https://github.com/varunneal/flash-attention/tree/stable) that combines
these changes and uploaded its build to Huggingface at https://huggingface.co/varunneal/flash-attention-3.

I have modified the training script to fetch these builds via Huggingface's `get_kernel`.
Developers therefore no longer need to build Flash Attention manually.

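As a rough illustration of what this looks like (the repo id matches the link above, but the exact function name exposed by the kernel module and its return format are assumptions; see the training script for the real call):

```
import torch
from kernels import get_kernel  # Huggingface's `kernels` package

# Download (and cache) the prebuilt FA3 binary matching the local torch/CUDA version.
flash_attn = get_kernel("varunneal/flash-attention-3")

# Hypothetical call: FA3 exposes a flash_attn_func-style entry point, but the exact
# name and signature on the returned module may differ from this sketch.
q = torch.randn(1, 4096, 12, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 4096, 12, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 4096, 12, 128, device="cuda", dtype=torch.bfloat16)
out = flash_attn.flash_attn_func(q, k, v, causal=True)
```
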
I have packaged this kernel for both CUDA 12.6 and 12.8, for the following PyTorch versions:
- `2.8.0`
- `2.9` nightlies after 8/30
- `2.10` nightlies

Note that the actual build `.so` is identical for all Torch nightly versions.

This most recent record uses the same `2.10` nightly as PR#133.
