
Conversation

@shenberg
Contributor

I followed the recommendation @varunneal made in his cautious weight decay work and implemented it for Adam. There's no run-time cost, since Adam's run time is dominated by cross-GPU communication. This let me drop 20 steps; there's probably room for more.

I disabled weight decay on the scalars, since it seems to me that for some of them the 'natural' value is not 0, and it would make more sense to do cautious weight decay relative to that 'natural' value. I've left this as future work for now.
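For context, here is a minimal sketch of what cautious weight decay looks like inside an Adam-style step, using the mask discussed in this thread. Names like `p`, `grad`, `exp_avg`, `lr`, and `wd` are illustrative and not the repo's actual code; the exact scaling of the decay term follows the repo's convention and may include an extra factor of `lr`, as discussed further down.

```python
import torch

# Minimal sketch of cautious weight decay (CWD) inside a decoupled-WD,
# Adam-style step. Illustrative names only; not the repo's actual optimizer.
@torch.no_grad()
def adam_step_cwd(p, grad, exp_avg, exp_avg_sq, step, lr, wd,
                  betas=(0.9, 0.95), eps=1e-8):
    beta1, beta2 = betas
    exp_avg.lerp_(grad, 1 - beta1)                                # first moment
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # second moment
    bias1 = 1 - beta1 ** step
    bias2 = 1 - beta2 ** step
    update = (exp_avg / bias1) / ((exp_avg_sq / bias2).sqrt() + eps)

    # Cautious weight decay: only decay entries where the Adam update is
    # already pushing the weight toward zero, i.e. sign(update) == sign(p).
    mask = (update * p) >= 0
    p.sub_(lr * update)
    p.sub_(lr * wd * p * mask)   # decoupled decay, applied only under the mask
```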


A small note: the first run had an exceptionally large loss (3.287) which skewed the mean; discounting that run, we would have a much lower average. I noticed that baseline runs had relatively high but valid losses: [3.2793, 3.2778, 3.2802, 3.279].

@linux-leo

Maybe the natural value could be calculated using some sort of moving average...

@shenberg
Contributor Author

I decided to make this attempt because I was confused by the code in Muon that multiplies the weight-decay factor by the LR twice (I thought it was a bug). I discovered why it's done when reading the readme by @varunneal (it acts as a schedule for the WD factor), but not before finding that I could get roughly the same score without this scheduling (I set the factor for NorMuon a bit higher, at a constant 0.042, and got pretty much the same results). I'm unconvinced that WD scheduling is necessary at the moment, but I haven't run the tuning needed to get the correct values without it, so I'll leave this for sometime in the future.
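For context, the pattern in question looks roughly like this. This is a sketch with illustrative names and values, not the repo's exact Muon code.

```python
import torch

# Sketch of the "WD factor multiplied by the LR twice" pattern discussed above.
p = torch.randn(64, 64)
wd = 0.005          # nominal weight-decay factor
lr = 0.05           # current learning rate from the LR schedule

# Effective decay per step is lr**2 * wd. Because lr follows the LR schedule,
# the effective decay is implicitly scheduled too, shrinking toward 0 as lr
# decays at the end of training.
eff_wd = lr * wd
p.sub_(lr * eff_wd * p)
```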

@varunneal
Contributor

There's some theoretical motivation for multiplying by the LR twice: https://arxiv.org/abs/2512.08217

But ultimately, imo, the schedule's important just for decaying WD to 0 by the end of training. I found it effective on Muon. I'm not sure it's as useful on Adam because of the nature of the embeddings.

@varunneal
Contributor

varunneal commented Dec 18, 2025

> The first run had an exceptionally large loss which skewed the mean

3.287 is very strange, but of course possible as the tail end of a normal distribution. Hopefully it is not being caused by CWD.

@ClassicLarry
Collaborator

> The first run had an exceptionally large loss which skewed the mean
>
> 3.287 is very strange, but of course possible as the tail end of a normal distribution. Hopefully it is not being caused by CWD.

Looks like the runs are higher variance in general. 3.272 is absurdly low and would never have happened before. Better understanding what is causing this could unlock a couple of seconds.

@varunneal
Contributor

@ClassicLarry yeah, the std dev of 0.0036 turns out to be about 3 times higher than the present record's, or most records'. I'm looking into whether this is caused by the fp8 scales being improperly calibrated for CWD.

@ClassicLarry
Collaborator

ClassicLarry commented Dec 19, 2025

> @ClassicLarry yeah, the std dev of 0.0036 turns out to be about 3 times higher than the present record's, or most records'. I'm looking into whether this is caused by the fp8 scales being improperly calibrated for CWD.

Good idea. The embed and lm_head are quite different: 75x different learning rates, and each embedding only activates about 1/50,000 of the time. I doubt these should have the same decay rate. I haven't validated the runtime yet, but I'm confident the record will hold and this is worth merging in. But I expect I will then prefer testing new changes after temporarily removing anything that introduces high variance, like this change.

I wonder how CWD is impacting sparse embeddings. If an embedding does not show up in a batch, it will have a gradient of zero, and `mask = (update * p_slice) >= 0` will drive that embedding to zero. We might be better off changing `>=` to `>` so that sparse embeddings maintain their size.
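A minimal, self-contained sketch of that edge case (illustrative names; not the repo's actual optimizer code):

```python
import torch

# Sparse-embedding edge case: rows absent from the batch get a zero update.
p_slice = torch.randn(4, 8)          # a slice of the embedding table
update = torch.zeros_like(p_slice)   # zero update for rows not in the batch
update[0] = torch.randn(8)           # pretend only row 0 appeared in the batch

# With '>=', rows with a zero update satisfy (0 * p) >= 0 everywhere,
# so they are decayed every step and drift toward zero.
mask_ge = (update * p_slice) >= 0

# With '>', rows with a zero update are excluded from decay and keep their norm.
mask_gt = (update * p_slice) > 0

print(mask_ge[1].all().item())   # True:  untouched row would still be decayed
print(mask_gt[1].any().item())   # False: untouched row is left alone
```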

On the topic of the scalars, these might be better off with different learning rates. Some end up around 25 and some stay between 0 and 1. When I was testing updates to the x0 lambda, I found the loss changed by a decent amount depending on the configuration.

I am curious how CWD is affecting the lm_head vectors of the 300 tokens that never occur during training.

@ClassicLarry
Collaborator

The current decay weight of 0.005 gives:
(.008 * 75)^2 * 0.005 = 0.0018 decay per step for embed
(.008)^2 * 0.005 = 3.2e-7 decay per step for lm_head.

So there may be effectively no impact on the lm_head. Planning to look more into what this is actually doing. If we can replicate the 3.272 scenario, it's another couple of seconds off.
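A quick reproduction of that arithmetic, assuming the effective per-step decay is lr**2 * wd with an lm_head LR of 0.008 and an embedding LR 75x larger, as mentioned above:

```python
# Back-of-the-envelope effective decay per step, assuming decay = lr**2 * wd.
wd = 0.005
lm_head_lr = 0.008
embed_lr = 0.008 * 75            # the ~75x gap mentioned above

print(embed_lr ** 2 * wd)        # ~0.0018  decay per step for embed
print(lm_head_lr ** 2 * wd)      # 3.2e-07  decay per step for lm_head
```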

@ClassicLarry
Collaborator

Confirmed a 1.1s decrease; updating the main readme to 127.7s to maintain the 1.1s gap from the prior record.

@ClassicLarry merged commit 49465cc into KellerJordan:master on Dec 20, 2025
@shenberg
Contributor Author

shenberg commented Dec 24, 2025

> > @ClassicLarry yeah, the std dev of 0.0036 turns out to be about 3 times higher than the present record's, or most records'. I'm looking into whether this is caused by the fp8 scales being improperly calibrated for CWD.
>
> Good idea. The embed and lm_head are quite different: 75x different learning rates, and each embedding only activates about 1/50,000 of the time. I doubt these should have the same decay rate. I haven't validated the runtime yet, but I'm confident the record will hold and this is worth merging in. But I expect I will then prefer testing new changes after temporarily removing anything that introduces high variance, like this change.
>
> I wonder how CWD is impacting sparse embeddings. If an embedding does not show up in a batch, it will have a gradient of zero, and `mask = (update * p_slice) >= 0` will drive that embedding to zero. We might be better off changing `>=` to `>` so that sparse embeddings maintain their size.
>
> On the topic of the scalars, these might be better off with different learning rates. Some end up around 25 and some stay between 0 and 1. When I was testing updates to the x0 lambda, I found the loss changed by a decent amount depending on the configuration.
>
> I am curious how CWD is affecting the lm_head vectors of the 300 tokens that never occur during training.

I read the CWD paper a bit more carefully and tried a few things on Adam that left me thinking that all the value is in dealing with sparse updates correctly, and not in the cautious part at all:

  1. Giving an additional margin such that weight decay will never flip the sign of an element if the update wouldn't flip it itself. Seemed maybe a teeny bit better. I thought maybe I had something here, but...
  2. Reversing the sign of the mask, so `mask = (update * p_slice) < 0`, seemed equally OK...
  3. Assuming that only sparse gradients are the problem, so `mask = (update * p_slice) != 0` - also seems fine.

It seems to me that the main insight of the CWD paper is to ensure that weight decay doesn't flip the sign of the update to any specific parameter, so we're still optimizing the original function and not a surrogate. I think one can be less cautious, i.e. weight-decay more, and still optimize the same objective. Something along the lines of `mask = ((update * p_slice) > 0) | ((update * p_slice) < p_slice.square() * (-eff_weight_decay * lr))` ("decay if the directions agree, or if the weight-decay step has smaller magnitude than the optimizer update"). No spot instances are available at the moment, so I'll see in a bit whether this helps at all. I'm not optimistic, given the non-zero-update experiment. Maybe this line of reasoning makes more sense for Muon, though the more expensive mask may make it not worthwhile there; Adam being comms-constrained makes it ripe for experimentation.
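For reference, here is a sketch of the mask variants above as boolean masks. Names are illustrative; `eff_weight_decay` and `lr` stand in for whatever the optimizer actually uses, and variant 1's extra margin is omitted.

```python
import torch

# The mask variants discussed above, written out as a sketch.
p_slice = torch.randn(16, 32)
update = torch.randn(16, 32)
eff_weight_decay, lr = 0.005, 0.008
prod = update * p_slice

mask_cautious = prod >= 0          # original CWD mask
mask_reversed = prod < 0           # variant 2: sign reversed
mask_nonzero  = prod != 0          # variant 3: only exclude zero (sparse) updates

# "Less cautious" mask from the paragraph above: decay where directions agree,
# or where the decay step is smaller in magnitude than the optimizer update.
mask_less_cautious = (prod > 0) | (prod < p_slice.square() * (-eff_weight_decay * lr))

# In all cases, the decay is then applied only where the mask is True, e.g.
# p_slice.sub_(lr * eff_weight_decay * p_slice * mask_less_cautious)
```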

Update: my less-cautious WD mask seems to be a beneficial change. I'll rebase once the <2min record is merged and work on top of it; it should allow dropping some more steps, but it needs a nicer parametrization, as even in Adam it measurably slows down steps.
