New WR: Cautious weight decay (-40 steps) #154
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Cautious Weight Decay
This record implements Cautious Weight Decay. It is replacing a previous PR PR#147.
Timing and Validation
This record improves the final training 40 steps, with a slight increase in step time.
This PR:
Previous PR (timed on same machine):
These timings show an improvement of ~2.22 seconds.
Thank you to Prime Intellect for sponsoring my research.
"Cautious" weight decay
I found that weight decay leads to stable training dynamics, but performance seems to suffer. I stumbled upon the paper Cautious Weight Decay which proposes only applying weight decay on the parameters that are growing in magnitude, and this proved to be effective.
Based on suggestion from @classiclaryd, I kept weight decay on a schedule. After trying various combinations, I found that the same schedule as learning rate is quite good, so I kept the previous calculation of
effective_weight_decay = learning_rate * weight_decay. Scheduled weight decay improves performance on CWD by 10-15 steps.The choice of
wd=1.2is well tuned. In practice, it actually corresponds to startingeffective_weight_decay = 1.2 x 0.03 = 0.036.Cautious weight decay might be better called "masked decoupled weight decay". While it should be an unbiased estimator, I noticed that this weight decay has a very different training dynamic than the baseline:
In particular, we find that CWD has higher validation loss for the majority of the training run. There is an inflection point when the learning rate decreases, and CWD only "catches up" to the baseline in the final steps of training. I noticed this dynamic irrespective of whether WD is placed on a schedule.
Parameters under CWD have mean square magnitude
<20%of the magnitude under the baseline. I found this pattern consistently for both MLP and ATTN parameters.I found that the condition number after CWD is virtually identical the the condition number after NorMuon:
I believe this PR opens the door for rich future work, including tuning the WD schedule and CWD for Adam.