New WR: cautious weight decay on Adam (-1.1 seconds) #172
Maybe the natural value could be calculated using some sort of moving average...
I decided to attempt this because I was confused by the code in Muon, which multiplies the weight-decay factor by the LR twice (I thought it was a bug). I discovered why it's done when reading the readme by @varunneal (it's a schedule for the WD factor), but not before finding that I could get roughly the same score without this scheduling (I set the factor for NorMuon a bit higher, at a constant 0.042, and got pretty much the same results). I'm unconvinced that WD scheduling is necessary at the moment, but I did not run the tuning needed to find the correct values without it, so I'll leave this for sometime in the future.
There's some theoretical motivation for multiplying by LR twice (https://arxiv.org/abs/2512.08217), but ultimately, imo, the schedule is important just for decaying WD to 0 by the end of training. I found it effective on Muon. I'm not sure it's as useful on Adam because of the nature of the embeddings.
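For reference, a minimal sketch of the "weight decay multiplied by the LR twice" pattern being discussed, assuming the second LR factor is what makes the decay follow the schedule and reach roughly zero by the end of training (illustrative only, not the actual Muon code):

```python
def decay_coefficient(lr: float, wd: float) -> float:
    # Decoupled weight decay would multiply wd by lr once; multiplying by the
    # (scheduled) lr a second time makes the decay strength itself follow the
    # lr schedule, so it decays to ~0 by the end of training.
    return lr * lr * wd

# e.g., inside the step (illustrative):  p -= lr * update + decay_coefficient(lr, wd) * p
```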
3.287 is very strange, but of course possible as the tail end of a normal distribution. Hopefully it is not being caused by CWD.
Looks like the runs are higher variance in general. 3.272 is absurdly low and would never have happened before. A better understanding of what is causing this could unlock a couple of seconds.
@ClassicLarry yeah, the std dev of 0.0036 turns out to be about 3 times higher than that of the present record, and of most records. I'm looking into whether this is caused by fp8 scales being improperly calibrated for CWD.
Good idea. The embed and lm_head are quite different: a 75x difference in learning rate, and each embedding only activates about 1/50,000 of the time. I doubt these should have the same decay rate. I haven't validated the runtime yet, but I'm confident the record will hold and this is worth merging in. I expect that afterwards I will prefer to test new changes with anything that introduces high variance, like this change, temporarily removed.

I wonder how CWD is impacting sparse embeddings. If an embedding does not show up in a batch it will have a gradient of zero, and `mask = (update * p_slice) >= 0` will drive that embedding toward zero. We might be better off changing `>=` to `>` so that sparse embeddings maintain their size.

On the topic of the scalars, these might be better off with different learning rates: some end up around 25 and some stay between 0 and 1. When I was testing updates to the x0 lambda, I found the loss changed by a decent amount depending on the configuration. I am also curious how CWD is affecting the lm_head vectors of the 300 tokens that never occur during training.
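For concreteness, here is a tiny standalone illustration of the `>=` vs `>` point (not the actual training code): an embedding row that never appears in the batch has a zero update, so the `>=` mask keeps decaying it every step, while `>` leaves it untouched.

```python
import torch

p_slice = torch.tensor([0.7, -1.3, 0.2])   # hypothetical embedding row
update  = torch.zeros_like(p_slice)        # row absent from the batch -> zero gradient/update

mask_ge = (update * p_slice) >= 0   # tensor([True, True, True]):   the unused row still gets decayed
mask_gt = (update * p_slice) > 0    # tensor([False, False, False]): the unused row keeps its size
print(mask_ge, mask_gt)
```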
The current decay weight of 0.005 gives … so there may be effectively no impact on the lm_head. Planning to look more into what this is actually doing. If we can replicate the 3.272 scenario, that's another couple of seconds off.
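As a rough back-of-the-envelope for why a decay weight of 0.005 might barely touch the lm_head, here is the compounded shrinkage under decoupled decay; the learning rate and step count below are placeholders for illustration, not the actual run configuration.

```python
lr, wd, num_steps = 0.008, 0.005, 1500      # placeholder values, not the real run config
per_step = 1 - lr * wd                      # decoupled decay: p <- p * (1 - lr*wd) each step
total = per_step ** num_steps               # compounded over the whole run
print(f"per-step factor {per_step:.6f}, total shrink over the run {total:.4f}")
# with these placeholder numbers a weight shrinks only ~6% over the entire run,
# and the cautious mask skips many of those decay steps on top of that
```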
Confirmed the 1.1s decrease; updating the main readme to 127.7s to maintain a 1.1s gap from the prior record.
I read the CWD paper a bit more carefully and tried a few things on Adam that left me thinking that all the value is in dealing with sparse updates correctly, and not in the cautious part at all:
It seems to me like the main insight of the CWD paper is to ensure that weight decay doesn't flip the sign of the update to any specific parameter, so we're still optimizing the original function and not a surrogate. I think one can be less cautious, i.e. weight-decay more, and still preserve the same objective. Something along the lines of …

Update: my less-cautious WD mask seems to be a beneficial change. I'll rebase once the <2min record is merged and work on top of it; it should allow dropping some more steps, but it needs a nicer parametrization, as even in Adam it measurably slows down steps.
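For illustration, here is one way such a less-cautious mask could look; this is a sketch of the idea described above, not the author's actual code: rather than dropping weight decay entirely on coordinates where it opposes the update, cap its per-coordinate magnitude so the combined step can never flip the update's sign.

```python
import torch

def less_cautious_decay(update: torch.Tensor, p: torch.Tensor, wd: float) -> torch.Tensor:
    """One possible less-cautious variant: cap the decay instead of dropping it."""
    decay = wd * p
    agree = (update * p) >= 0
    # Where the decay opposes the update, cap its magnitude at |update| so the
    # combined step (update + decay) can never flip the sign of the update.
    capped = decay.clamp(min=-update.abs(), max=update.abs())
    return torch.where(agree, decay, capped)

# usage (illustrative): p.add_(update + less_cautious_decay(update, p, wd), alpha=-lr)
```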
I followed the recommendation of @varunneal on cautious weight decay and implemented it for Adam. There's no runtime cost, as Adam's runtime is dominated by cross-GPU communication. Reduced 20 steps; there's probably room for more.
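As a rough sketch of how the cautious mask slots into a decoupled-weight-decay Adam step (the function, parameter names, and hyperparameters below are illustrative, not the speedrun's actual code):

```python
import torch

@torch.no_grad()
def adam_step_with_cwd(p, grad, exp_avg, exp_avg_sq, step,
                       lr=3e-4, betas=(0.9, 0.95), eps=1e-8, wd=0.005):
    beta1, beta2 = betas
    exp_avg.lerp_(grad, 1 - beta1)                                  # first moment
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)    # second moment
    update = (exp_avg / (1 - beta1 ** step)) / \
             ((exp_avg_sq / (1 - beta2 ** step)).sqrt() + eps)      # bias-corrected Adam update

    # Cautious weight decay: only decay coordinates where the decay direction
    # agrees with the update direction (update * p >= 0), so the decay never
    # flips the sign of the step for any individual parameter.
    mask = (update * p) >= 0
    p.add_(update + wd * p * mask, alpha=-lr)
```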
I disabled weight decay on the scalars, since it seems to me that for some of them the 'natural' value is not 0, and it would make more sense to do cautious weight decay relative to that 'natural' value. Left this as future work for now.
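As a sketch of that future-work direction, one could track a running estimate of each scalar's 'natural' value and apply the cautious rule to the deviation from it; everything below (names, EMA coefficient) is assumed for illustration, not something implemented in this PR.

```python
import torch

@torch.no_grad()
def cautious_decay_toward_ema(p, update, ema, wd=0.005, ema_beta=0.99):
    # maintain a running estimate of each scalar's "natural" value
    ema.mul_(ema_beta).add_(p, alpha=1 - ema_beta)
    offset = p - ema                  # decay the deviation from the natural value, not the value itself
    mask = (update * offset) >= 0     # same cautious rule, applied to the offset
    return wd * offset * mask         # add this term to `update` before the lr step
```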
A small note: the first run had an exceptionally large loss (3.287), which skewed the mean; discounting that run, we would have a much lower average. I noticed that baseline runs had relatively high but valid losses:
[3.2793, 3.2778, 3.2802, 3.279]