# New record 09/29/25

This submission reflects all recent WR changes up to [PR#133](https://github.com/KellerJordan/modded-nanogpt/pull/133).

The main improvement in this PR is using the [Polar Express](https://arxiv.org/pdf/2505.16932)
sign method in Muon instead of Newton-Schulz. The method was designed with ModdedNanoGPT as a reference, so it was easy to implement;
I refer the reader to the paper for details. Using Polar Express, I've reduced the train steps by 10.

The next change in this PR is packaging Flash Attention 3 via [Huggingface's Kernels](https://huggingface.co/docs/kernels/en/index) library.
This does not impact timing, but it should make development easier for anyone working on this project.

## Timing and Validation

This PR reduces the number of training steps by 10, with no change in the time per step.

```
import scipy.stats
import torch

losses = [3.2789, 3.2792, 3.2796, 3.2776, 3.2797, 3.2787, 3.2792]
times = [148.617, 148.580, 148.569, 148.653, 148.578, 148.542, 148.587]

print("p=%.4f" % scipy.stats.ttest_1samp(losses, 3.28, alternative="less").pvalue)
# p=0.0045

print("losses:", torch.std_mean(torch.tensor(losses)))
# losses: (std=0.0007057, mean=3.2789857)

print("time:", torch.std_mean(torch.tensor(times)))
# time: (std=0.0358076, mean=148.5894318)
```

You may notice that this PR shows a 0.2 second mean *increase* in timing over the result in PR#133.
However, that PR was timed on a very fast machine. To demonstrate that this PR accurately represents
a decrease in train time, I timed PR#133 on the same machine as above:

```
import torch

times = [149.714, 149.676, 149.659, 149.716, 149.649, 149.569, 149.521]

print("time:", torch.std_mean(torch.tensor(times)))
# time: (std=0.0732, mean=149.6434)
```

The mean time here is 149.64 s versus 148.59 s above, a difference of roughly 1.05 s, so I believe this PR represents at least a 1 second improvement.

Thank you to Prime Intellect for compute credits, which made this PR possible.

## Polar Express

All credit to the original authors (Noah Amsel, David Persson, Christopher Musco, Robert M. Gower)
for the discovery and implementation of this method. I adapted their code from https://github.com/NoahAmsel/PolarExpress/tree/main.

I found the following parameters to work best (a structural sketch of the iteration follows this list):
- `num_iters=5`: each iteration adds about a second to train time
- `muon_lr=0.06`: I found that bumping the Muon LR performs slightly better
- `safety_factor=1.02`: hyperparameter for the Polar Express coefficients
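
For readers unfamiliar with the method, here is a minimal structural sketch of the iteration Polar Express performs in place of Newton-Schulz inside Muon's orthogonalization step: a short sequence of odd matrix polynomials X <- a*X + (b*A + c*A^2)X with A = X X^T. The coefficients below are placeholders for illustration (roughly the fixed quintic values Muon has used for Newton-Schulz); the record instead uses the per-iteration coefficients from the paper and repo above, adjusted by `safety_factor`.

```
import torch

# Placeholder coefficients for illustration: roughly the fixed quintic values Muon has used
# for Newton-Schulz. Polar Express keeps this iteration shape but substitutes per-iteration
# coefficients from the authors' construction, rescaled by `safety_factor`.
EXAMPLE_COEFFS = [(3.4445, -4.7750, 2.0315)] * 5  # num_iters=5

def orthogonalize_sketch(G, coeffs=EXAMPLE_COEFFS):
    # Approximates the polar factor of G (the role this step plays inside Muon)
    # by iterating X <- a*X + (b*A + c*A@A) @ X with A = X @ X^T.
    X = G.bfloat16()
    X = X / (X.norm() + 1e-7)       # start with singular values in a safe range
    transposed = X.size(-2) > X.size(-1)
    if transposed:
        X = X.mT                    # work in the wide orientation so A stays small
    for a, b, c in coeffs:
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X
    return X.mT if transposed else X
```

The record's actual implementation follows the authors' repo; the sketch only shows where the coefficients enter.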

Despite the paper explicitly referencing and showing improvements on Modded NanoGPT,
I was unable to replicate the level of success it reports. However, it may
be possible to tune the parameters further to achieve a better result.
Additionally, like [Cesista 2025](https://leloykun.github.io/ponder/muon-opt-coeffs/), I believe the method may be more promising on the GPT Medium track.

## Flash Attention 3 Huggingface Kernel

A couple of weeks ago, Flash Attention merged [ABI-stability](https://github.com/Dao-AILab/flash-attention/pull/1791)
into the main FA3 repo. This allows builds of Flash Attention on PyTorch nightlies after 08/30 to be compatible with each other.
Since [PR#118](https://github.com/KellerJordan/modded-nanogpt/pull/118), we have been using
[a variant](https://github.com/Dao-AILab/flash-attention/pull/1769) of FA3 by @Guilhermeleobas that is compatible with `torch.compile`.
I have written a [fork](https://github.com/varunneal/flash-attention/tree/stable) that combines
these changes and uploaded its build to Huggingface at https://huggingface.co/varunneal/flash-attention-3.

I have modified the training script to fetch these builds via Huggingface's `get_kernel`.
Developers will therefore no longer need to build Flash Attention manually.
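
As a rough illustration of how the kernel is now pulled in (the repo id is the one above; the `flash_attn_func` attribute name follows upstream FA3 naming and is an assumption here, so see the training script for the exact call):

```
from kernels import get_kernel

# Fetches a prebuilt, ABI-stable FA3 binary from the Hub instead of compiling it locally.
flash_attn3 = get_kernel("varunneal/flash-attention-3")

def fa3(q, k, v):
    # q, k, v: (batch, seqlen, nheads, headdim) bf16 CUDA tensors, per the FA3 convention.
    # Some FA3 builds return (out, lse), so unpack defensively.
    out = flash_attn3.flash_attn_func(q, k, v, causal=True)
    return out[0] if isinstance(out, tuple) else out
```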

I have packaged this kernel for both CUDA 12.6 and 12.8 for the following PyTorch versions:
- `2.8.0`
- `2.9` nightlies after 8/30
- `2.10` nightlies

Note that the built `.so` is identical across all Torch nightly versions.

This most recent record uses the same `2.10` nightly as PR#133.