New WR 148.3s: Compute cross entropy in BF16 during training #133
Conversation
Would love to remove the very last entry. Result without the very last entry:
p=0.0005
losses: (tensor(0.0015), tensor(3.2788))
time: (tensor(0.1799), tensor(148.3504))
Awesome, great improvement! Looks like loss is still below 0.01, so no issues. I see you are using PyTorch 2.10.0.dev20250926+cu126. Either your GPUs happen to run slightly faster, or there is a small boost from the PyTorch version as well. I think I will remove iteration_extension in a future PR when I have a meaningful improvement to compensate, because even though it seems to improve mean loss, it also seems to add more variance to runs, which makes testing other updates more challenging.
I got lucky with GPUs on PrimeIntellect! I started multiple instances and selected the fastest based on the test run.
This PR does not include the prior record .txt or README files. The pull request log is starting to get rather long, and is probably very hard to follow for people seeing the repo for the first time. @KellerJordan, is there any plan to perform merges or add maintainers? Do you have any thoughts on spinning up an actively maintained, community-supported branch?
It won't be a problem once previous PRs are merged |
Good to know. The next record may have a higher raw runtime and will just need to be benchmarked as faster relative to this one when controlling for GPUs. 148.3 is crazy.
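(Aside for readers: the p-values and (loss, time) tuples quoted in this thread come from the project's own validation runs. The sketch below only illustrates the general idea of checking that one set of run times is faster than another on the same GPUs, using a one-sided Welch's t-test with made-up numbers; it is not the repo's actual acceptance script.)

```python
# Hypothetical illustration only: compare a candidate record's per-run times
# against the current record's, measured on the same hardware. The timings
# below are made up, and the repo's real validation may use a different test.
from scipy import stats

current_times   = [149.0, 148.8, 149.2, 148.9, 149.1]  # made-up example runs
candidate_times = [148.3, 148.4, 148.2, 148.5, 148.3]  # made-up example runs

# One-sided Welch's t-test: H1 is that the candidate's mean time is lower.
result = stats.ttest_ind(candidate_times, current_times,
                         equal_var=False, alternative="less")
print(f"p={result.pvalue:.4f}")
```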
I'm having some issues building Flash Attention on this version of Torch Nightly. Did you encounter any issues? |
Try following the instructions from #118. And yes, I had a few additional problems; sorry that I forgot to mention that.
This PR builds on all recent improvements, up to #132.

Removed a `.float()` cast before the loss so training keeps logits in BF16 all the way into `F.cross_entropy`. Validation still casts logits to FP32 to prevent BF16 rounding noise in the reported CE and to keep results comparable with prior runs.
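A minimal sketch of what this kind of change looks like in a generic PyTorch training step is shown below; the variable names and shapes are illustrative and not taken from `train_gpt.py`.

```python
import torch
import torch.nn.functional as F

# Made-up sizes for illustration; the real model's batch/sequence/vocab differ.
batch, seq, vocab = 4, 16, 50304
logits = torch.randn(batch, seq, vocab, dtype=torch.bfloat16, requires_grad=True)
targets = torch.randint(0, vocab, (batch, seq))

# Training: BF16 logits go straight into F.cross_entropy
# (previously they were upcast with .float() first).
train_loss = F.cross_entropy(logits.view(-1, vocab), targets.view(-1))

# Validation: keep the FP32 cast so the reported CE avoids BF16 rounding noise
# and stays comparable with prior runs.
with torch.no_grad():
    val_loss = F.cross_entropy(logits.float().view(-1, vocab), targets.view(-1))
```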
Commit with the `train_gpt.py` change: 346f4cb

Validation for #132:
Validation for this WR: