
Conversation


@dominikkallusky dominikkallusky commented Sep 16, 2025

Summary

This PR builds on PR #124 and adds the Snoo optimizer (Sparse Nesterov Outer Optimizer), which improves the medium-track WR by 60 steps (~10 s).
Snoo is a look-ahead, momentum-based wrapper that can improve the quality of large language models (LLMs) and other models. It implicitly smooths the training trajectory and instills a bias toward flatter minima. Snoo is computationally efficient, incurring minimal compute overhead and moderate memory usage.
@dominikkallusky, @vishal9-team, @vinaysrao

Code

@torch.no_grad()
def step(
    self,
) -> None:
    # Every k inner steps, take one outer (look-ahead) step.
    if self.current_step % self.k == 0:
        # Pseudo-gradient: outer copy minus current params, i.e. the negative of
        # the displacement accumulated over the last k inner steps; then reset
        # the live params to the outer copy.
        for p_new, p_old in zip(self.model_params, self.outer_buf):
            p_new.grad = p_old.data - p_new.data
            p_new.copy_(p_old, non_blocking=True)
        # Outer optimizer (SGD with Nesterov momentum) consumes the pseudo-gradients.
        self.optimizer.step()
        # Refresh the outer buffer with the updated params.
        for p_new, p_old in zip(self.model_params, self.outer_buf):
            p_old.copy_(p_new, non_blocking=True)
    self.current_step += 1
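
For context, here is a minimal sketch of how such a wrapper might be set up and driven from a training loop. The constructor, argument names, and default values are illustrative assumptions, not code from this PR; only the outer-SGD-with-Nesterov shape and the k-step cadence follow the discussion in this thread.

import torch

class Snoo:
    # Hypothetical constructor for illustration; names and defaults are not from the PR.
    def __init__(self, model_params, k=28, outer_lr=0.5, outer_momentum=0.9):
        self.model_params = list(model_params)
        # Outer ("slow") copy of the parameters, refreshed every k inner steps.
        self.outer_buf = [p.detach().clone() for p in self.model_params]
        # Outer optimizer applied to the pseudo-gradients built in step().
        self.optimizer = torch.optim.SGD(
            self.model_params, lr=outer_lr, momentum=outer_momentum, nesterov=True
        )
        self.k = k
        self.current_step = 0

    # step() from the PR body above goes here.

# Usage sketch: call snoo.step() once per inner-optimizer step, e.g.
#   snoo = Snoo(model.parameters(), k=28)
#   for batch in loader:
#       loss = compute_loss(model, batch)
#       loss.backward()
#       inner_opt.step()
#       snoo.step()
#       inner_opt.zero_grad(set_to_none=True)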

Stats:

Train time (ms)    Val Loss
1425323 2.9200130000
1399753 2.9192430000
1400871 2.9196530000
1432667 2.9202780000
1400392 2.9209690000
1399248 2.9187820000
1399035 2.9204980000
1399037 2.9198950000
1399258 2.9191080000
1399369 2.9205680000
1398911 2.9203090000
1399043 2.9192930000
1398410 2.9185420000
1398703 2.9196380000
1399135 2.9197590000
1399177 2.9176600000
1398529 2.9196670000
1398326 2.9213670000
1398652 2.9195370000
1399108 2.9179940000
1398671 2.9200720000
1398786 2.9207020000
1423478 2.9210300000
1398958 2.9192110000
1398334 2.9189310000
1431761 2.9182240000
1398236 2.9214560000
1398511 2.9195970000
1397775 2.9194410000
1398159 2.9199210000
1398501 2.9211470000
1398141 2.9178630000
1397903 2.9186550000
1399010 2.9195980000
1397790 2.9190140000
1398512 2.9209280000
1398639 2.9202180000
1398457 2.9196910000
1398093 2.9195350000
1397701 2.9185550000
1397840 2.9196240000
1397735 2.9193710000

Count: 42
Train Time (ms):

  • Mean: 1401522.33
  • Std: 8906.82
  • Min: 1397701
  • Max: 1432667

Val Loss:

  • Mean: 2.919656
  • Std: 0.0009364946890159557
  • Min: 2.91766
  • Max: 2.921456
  • P_Val(<2.92): 0.011
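
For reference, P_Val(<2.92) here is presumably the one-sided, one-sample t-test against the 2.92 target (the same scipy call shown with the follow-up run below); a minimal sketch:

import scipy.stats

def p_val_below_target(losses, target=2.92):
    # One-sided, one-sample t-test: is the mean final val loss below `target`?
    return scipy.stats.ttest_1samp(losses, target, alternative='less').pvalue

# e.g. p_val_below_target(val_losses_from_the_table_above)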

@ClassicLarry
Collaborator

This is cool. I know little about optimizers. Looking at the code, I'm reading it as: every 28 steps, compute the distance traveled over those steps, undo it, and then move in that direction more smoothly with Nesterov+momentum+SGD (sketched below).

The p-value is not below the 0.01 requirement. Adding back some of the 60 steps might be worthwhile here. Otherwise the challenge ends up in a state where nobody can contribute without burning money on 80+ runs, because the record sits so close to the cutoff that it's impossible to get p<0.01. It's hard for me to tell exactly how much of an improvement this is, since the prior record had a mean of 2.9191 and this one has a mean of 2.919656; I'd estimate there is some improvement, but not quite 60 steps' worth.
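
Spelling that reading out, here is a rough tensor-level sketch of one outer step, following PyTorch's SGD-with-Nesterov update rule; the function name and the lr/mu values are placeholders, not from the PR:

import torch

def outer_step(theta_outer, theta_inner, buf, lr=0.5, mu=0.9):
    # Pseudo-gradient: the negative of the distance traveled over the last k
    # inner steps ("undo it", then move back along it more smoothly).
    g = theta_outer - theta_inner
    buf = mu * buf + g              # outer momentum buffer (dampening = 0)
    d = g + mu * buf                # Nesterov look-ahead direction
    return theta_outer - lr * d, buf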

@dominikkallusky
Author

dominikkallusky commented Sep 17, 2025

Ok, yeah, that's fair. I reran with 5640 steps and now the p-value is below the 0.01 requirement.
I'm wondering if it might make sense to mandate torch.use_deterministic_algorithms(True), at least for the medium track where reruns are expensive.
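
For reference, a minimal sketch of what opting into determinism typically involves in PyTorch (standard settings, not code from this repo):

import os
import torch

# cuBLAS needs this set before the CUDA context is created to allow deterministic GEMMs.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

torch.use_deterministic_algorithms(True)  # error on ops without a deterministic kernel
torch.backends.cudnn.benchmark = False    # disable autotuning, which can vary run to run
torch.manual_seed(0)                      # fix RNG state as well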

This is cool. I know little about optimizers. Looking at the code, I'm reading it as: every 28 steps, compute the distance traveled over those steps, undo it, and then move in that direction more smoothly with Nesterov+momentum+SGD.

Yes, basically.

import scipy.stats

# df_nanogpt_med_5640: pandas DataFrame with one row per rerun ('loss', 'train_time').
print(df_nanogpt_med_5640[['loss', 'train_time']].reset_index(drop=True))
print(f"{df_nanogpt_med_5640['train_time'].mean()=}")
print(f"{df_nanogpt_med_5640['train_time'].std()=}")
print(f"{df_nanogpt_med_5640['train_time'].count()=}")
print(f"{df_nanogpt_med_5640['train_time'].min()=}")
print(f"{df_nanogpt_med_5640['train_time'].max()=}")
print(f"{df_nanogpt_med_5640['loss'].mean()=}")
print(f"{df_nanogpt_med_5640['loss'].min()=}")
print(f"{df_nanogpt_med_5640['loss'].max()=}")
print(f"{df_nanogpt_med_5640['loss'].std()=}")
print(f"{scipy.stats.ttest_1samp(df_nanogpt_med_5640['loss'].to_numpy().tolist(), 2.92, alternative='less').pvalue=}")
Val Loss    Train time (ms)
2.920866 1404429
2.919848 1405527
2.920058 1405527
2.919486 1404794
2.919221 1404749
2.919295 1403139
2.920905 1402907
2.919472 1403525
2.919485 1410229
2.918337 1403318
2.918905 1403348
2.921223 1403409
2.919060 1403128
2.919060 1405181
2.919242 1405181
2.918973 1403238
2.919746 1403648
2.919541 1403085
2.919301 1403406
2.919458 1402836
2.919502 1402963
2.917969 1403876
2.920135 1402608
2.920262 1404008
2.919758 1403351
2.920201 1403465
2.919559 1403964
2.920641 1402720
2.919203 1406409
2.919769 1402930

df_nanogpt_med_5640['train_time'].mean()=np.float64(1404029.9333333333)
df_nanogpt_med_5640['train_time'].std()=1523.797296985019
df_nanogpt_med_5640['train_time'].count()=np.int64(30)
df_nanogpt_med_5640['train_time'].min()=1402608.0
df_nanogpt_med_5640['train_time'].max()=1410229.0
df_nanogpt_med_5640['loss'].mean()=np.float64(2.919616033333334)
df_nanogpt_med_5640['loss'].min()=2.917969
df_nanogpt_med_5640['loss'].max()=2.921223
df_nanogpt_med_5640['loss'].std()=0.0007156313095892154
scipy.stats.ttest_1samp(df_nanogpt_med_5640['loss'].to_numpy().tolist(), 2.92, alternative='less').pvalue=np.float64(0.0032015979421488247)

@dominikkallusky dominikkallusky changed the title from "New medium track WR: 1401s. Snoo Optimizer. Includes #124 and #119" to "New medium track WR: 1404s. Snoo Optimizer. Includes #124 and #119" on Sep 17, 2025
@ClassicLarry
Collaborator

Awesome! 50 steps is big.

I'm wondering if it might make sense to mandate torch.use_deterministic_algorithms(True)

At one point I tried testing with this, because I figured it would make it way easier to assess changes, but unfortunately the runtime was substantially worse and the loss curve followed a different trajectory, iirc. That made it so a change could be good under deterministic algos but bad under stochastic ones. As a result, I never went back to testing with deterministic algos. I think the issue was primarily the bfloat16 params and the fp8 lm_head on the short track. In addition, I think deterministic algos create a risk that people will curve-fit every parameter to the validation set. There is already some amount of curve fitting, but stochastic algos help mitigate that.

