Conversation

@varunneal (Contributor) commented Nov 10, 2025

Cautious Weight Decay

This record implements Cautious Weight Decay, replacing previous PR #147.

Timing and Validation

This record improves on the previous training run by 40 steps, at the cost of a slight increase in step time.

This PR:

import scipy.stats
import torch

# Final validation losses and wall-clock times (seconds) from 10 runs of this PR.
losses = [3.2784, 3.2771, 3.2777, 3.2790, 3.2794, 3.2813, 3.2772, 3.2772, 3.2785, 3.2783]
times = [137.582, 137.753, 137.636, 137.507, 137.639, 137.708, 137.722, 137.677, 137.456, 137.705]

print("p=%.4f" % scipy.stats.ttest_1samp(losses, 3.28, alternative="less").pvalue)
# p=0.0018

print("losses:", torch.std_mean(torch.tensor(losses)))
# losses: (std=0.0013, mean=3.2784)

print("time:", torch.std_mean(torch.tensor(times)))
# time: (std=0.0970, mean=137.6385)

Previous PR (timed on same machine):

import scipy.stats
import torch

# Wall-clock times (seconds) from 5 runs of the previous PR.
times = [139.813, 139.832, 139.877, 139.839, 139.939]

print("time:", torch.std_mean(torch.tensor(times)))
# time: (std=0.0499, mean=139.8600)

These timings show an improvement of ~2.22 seconds.

Thank you to Prime Intellect for sponsoring my research.

"Cautious" weight decay

I found that weight decay leads to stable training dynamics, but performance seems to suffer. I stumbled upon the paper Cautious Weight Decay, which proposes applying weight decay only to the parameters that are growing in magnitude, and this proved effective.
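
As a rough sketch of the masking (illustrative names and step structure, not the exact code in train_gpt.py), a cautious decoupled weight-decay step looks like this, where delta is the update that will be added to the parameter:

import torch

@torch.no_grad()
def cautious_decay_step_(param: torch.Tensor, delta: torch.Tensor, effective_weight_decay: float) -> None:
    # "Cautious" mask: decay only the coordinates whose pending update pushes
    # them away from zero, i.e. the parameters that are growing in magnitude.
    growing = (delta * param) > 0
    param.mul_(1.0 - effective_weight_decay * growing.to(param.dtype))
    # Decoupled update, applied to every coordinate regardless of the mask.
    param.add_(delta)

# e.g. p = torch.randn(4, 4); cautious_decay_step_(p, -0.02 * torch.randn(4, 4), 0.036)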

Based on a suggestion from @ClassicLarry, I kept weight decay on a schedule. After trying various combinations, I found that the same schedule as the learning rate works well, so I kept the previous calculation of effective_weight_decay = learning_rate * weight_decay. Scheduling the weight decay improves CWD by 10-15 steps.

The choice of wd=1.2 is well tuned. In practice, it corresponds to a starting effective_weight_decay = 1.2 x 0.03 = 0.036.
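
In code, the scheduled decay reduces to one multiplication (the 0.03 starting learning rate is taken from the arithmetic above; the actual schedule lives in train_gpt.py):

weight_decay = 1.2  # the tuned wd from this PR

def effective_weight_decay(lr: float) -> float:
    # Weight decay follows the same schedule as the learning rate,
    # so the effective decay shrinks as the learning rate decays.
    return lr * weight_decay

print(effective_weight_decay(0.03))  # 0.036 at the start of training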

Cautious weight decay might be better called "masked decoupled weight decay". While it should be an unbiased estimator, I noticed that this weight decay has a very different training dynamic than the baseline:

[Figure: validation loss, CWD vs. baseline]

In particular, we find that CWD has higher validation loss for the majority of the training run. There is an inflection point when the learning rate decreases, and CWD only "catches up" to the baseline in the final steps of training. I noticed this dynamic irrespective of whether WD is placed on a schedule.

Parameters under CWD have a mean square magnitude less than 20% of that under the baseline. I found this pattern consistently for both MLP and ATTN parameters.
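
(For reference, a ratio like the one above can be computed with a couple of helpers; this is a hypothetical measurement, not the code used for the numbers quoted here.)

import torch

def mean_square(t: torch.Tensor) -> float:
    # Mean squared magnitude of a parameter tensor.
    return t.float().pow(2).mean().item()

def magnitude_ratios(cwd_params: dict, baseline_params: dict) -> dict:
    # Per-tensor ratio of CWD mean-square magnitude to the baseline's.
    return {name: mean_square(p) / mean_square(baseline_params[name])
            for name, p in cwd_params.items()}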

I found that the condition number after CWD is virtually identical to the condition number after NorMuon:

[Figure: condition numbers after CWD vs. after NorMuon]
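
(The condition numbers above can be measured from singular values; this is an assumed measurement, not necessarily the exact code behind the figure.)

import torch

def condition_number(weight: torch.Tensor) -> float:
    # Ratio of the largest to the smallest singular value of a 2D weight matrix.
    s = torch.linalg.svdvals(weight.float())
    return (s.max() / s.min()).item()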

I believe this PR opens the door for rich future work, including tuning the WD schedule and CWD for Adam.

@ClassicLarry (Collaborator)

Slight delay on the validations here: it looks like slightly increasing the total runtime of the Muon optimizer is bringing back the hiccupping on some hardware (where 1-3 steps throughout the run take 400ms instead of 60ms). I'm looking to see if there is a robust way to handle this that removes the dependency on small hardware variance, which will require a deep dive into what is causing the 400ms stall.

@varunneal (Contributor, Author)

@ClassicLarry Just a couple of things that might help (they don't solve the underlying issue, but might mitigate it):

  • run with OMP_NUM_THREADS=8, e.g. `OMP_NUM_THREADS=8 torchrun --standalone --nproc_per_node=8 train_gpt.py`
  • try torch.compile(dynamic=False) on Muon's step instead of on polar_express (see the sketch below)
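
For the second point, a minimal sketch with a stand-in optimizer (the actual Muon instance and step in train_gpt.py are set up differently):

import torch

model = torch.nn.Linear(8, 8)
opt = torch.optim.SGD(model.parameters(), lr=0.01)  # stand-in for Muon

# Compile the whole optimizer step with static shapes rather than only the
# inner polar_express routine; dynamic=False avoids shape-dependent recompiles
# that can show up as intermittent stalls.
compiled_step = torch.compile(opt.step, dynamic=False)

loss = model(torch.randn(4, 8)).sum()
loss.backward()
compiled_step()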

@ClassicLarry ClassicLarry merged commit 29aefcf into KellerJordan:master Nov 14, 2025
@ClassicLarry (Collaborator)

Merged at 137.022. The stall did not occur on the latest test, so I'm merging this now; that way people who are testing don't have to piece together a bunch of PRs.
