New Record - Smoothed Scalars and Smear Gate #177
Merged
Smoothed Scalars and Smear Gate
This update includes:
Smear Gate Updates
I've been experimenting with logging the values of the scalar parameters in the model over the course of training, and am seeing some interesting behavior relating to the schedule transitions, and to the smear gate weights in particular.
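Roughly, the logging just records each scalar parameter's current value at a regular step interval. A minimal sketch (the 0-dim selection rule and metric names here are my assumptions, not the actual logging code):

import torch.nn as nn
import wandb

def log_scalar_params(model: nn.Module, step: int):
    # Log the current value of every 0-dim parameter under a "scalars/" prefix.
    # The selection rule and naming are illustrative, not the repo's exact code.
    scalars = {f"scalars/{name}": p.item()
               for name, p in model.named_parameters() if p.ndim == 0}
    wandb.log(scalars, step=step)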
For reference: each input token to our model is actually a sum of itself and the prior token. The multiplier on the prior token is determined by the smear gate: token t determines how much of token t-1 to add to itself, using the first 12 dimensions of its embedding as the input to the smear gate.
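As a rough illustration of that mechanism (a minimal sketch; the sigmoid, the Linear(12, 1) gate, and the zero-pad at position 0 are assumptions on my part, not the repo's exact implementation):

import torch
import torch.nn as nn

class SmearGate(nn.Module):
    # Illustrative sketch of the smearing described above.
    def __init__(self, gate_dims: int = 12):
        super().__init__()
        self.gate = nn.Linear(gate_dims, 1)  # the "smear gate" weights being logged

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, model_dim) token embeddings
        # Gate value for token t comes from the first 12 dims of token t's own embedding.
        g = torch.sigmoid(self.gate(x[..., :12]))                          # (batch, seq_len, 1)
        prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)   # token t-1 (zeros at t=0)
        return x + g * prev                                                # token t plus a gated share of t-1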
Here are a few of the smear gate weight values from recent records. For the key shift record (in blue), the smear gate values were small and noisy, but stable, until we hit the first schedule transition at ~700 steps, where the batch size, window size, and learning rate increase.
Then, with cautious weight decay added to Adam (in red), the weight values were again small and noisy, but stayed relatively stable throughout training.
With the most recent record, re-tying the embeddings (in green), the noise appears reduced, and the batch size jump doesn't destabilize it.
(Note: The long ramp corresponds to Muon's momentum warmup period of 300 steps.)
Moving the smear gate weights to Adam and reducing the learning rate to 0.01 brought them back into the value range they had before the destabilization point, and also smoothed them considerably.
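Something along these lines (a hypothetical sketch of the parameter routing; the name-matching rule and the placeholder learning rate for the other group are assumptions, and the matrix parameters would stay with Muon as before):

import torch
from torch import nn

def build_adam_for_smear(model: nn.Module) -> torch.optim.Adam:
    # Route anything with "smear" in its name to an Adam group at lr=0.01;
    # keep the other small (non-matrix) parameters in their existing group.
    smear_params = [p for n, p in model.named_parameters() if "smear" in n]
    other_params = [p for n, p in model.named_parameters()
                    if p.ndim < 2 and "smear" not in n]
    return torch.optim.Adam([
        {"params": smear_params, "lr": 0.01},
        {"params": other_params, "lr": 3e-4},  # placeholder lr for the existing Adam params
    ])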
I haven't worked through the math, but I could imagine that there's a strong relationship between the embedding table and the smear gate. Even though it's only 12 dimensions, its output is applied to every position of every embedding in the sequence.
Increased Smoothing on Scalars
The scalars in the model seem to more-or-less follow predictable / consistent trajectories across runs, but are noisy.
My thought was that, given how clear the intended trajectory of the scalars is, it might be beneficial to find a way to make them follow that path with less noise.
In order to increase the beta values on just the scalars, I landed on creating a second instance of DistAdam specifically for them.
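Something like the following (a minimal sketch with torch.optim.Adam standing in for DistAdam; the beta values and the rule for picking out the scalar parameters are placeholders, not the settings actually used):

import torch
from torch import nn

def build_scalar_adam(model: nn.Module) -> torch.optim.Adam:
    # Second optimizer instance just for the noisy scalars, with higher betas for
    # extra smoothing. The x0_lambdas stay with the main instance (see below).
    scalar_params = [p for n, p in model.named_parameters()
                     if p.ndim == 0 and "x0_lambda" not in n]
    return torch.optim.Adam(scalar_params, lr=1e-3, betas=(0.96, 0.99))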
Also, I found that the x0_lambdas (below) are unique in that they are already very smooth in comparison to the other scalars. Additional momentum altered their course, so to avoid that, I separated them out from the scalars block and kept them with the main instance of Adam.
DistAdam All Reduce
The changes to the parameters (separating x0_lambdas and moving the smear gate) complicate the sharding logic in Adam.
To avoid that, and because it "seems sensible", I adjusted DistAdam to use all_reduce for parameters with fewer than 1,024 values. This means that each GPU receives the full set of gradients for these weights and does redundant work to update them. The benefit is that we don't need to pad them to the world size, and we don't need to issue a gather operation for them at the end.
I figured I might have added overhead overall by splitting off x0_lambdas, but it doesn't appear to have made things worse:
(This compares to the prior record at the same number of steps)
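The communication split looks roughly like this (a sketch of the pattern only, not DistAdam's actual code; the real sharding and padding logic is more involved):

import torch
import torch.distributed as dist

SMALL_PARAM_THRESHOLD = 1024  # cutoff described above

def sync_grad(p: torch.Tensor, world_size: int) -> torch.Tensor:
    """Sketch of the gradient-communication pattern.

    Small parameters: all_reduce the full (averaged) gradient, so every rank does the
    same redundant update -- no padding to the world size and no gather afterwards.
    Large parameters: reduce_scatter one shard per rank, update the shard locally,
    then all_gather the updated shards back together at the end (not shown).
    """
    if p.numel() < SMALL_PARAM_THRESHOLD:
        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
        p.grad.div_(world_size)
        return p.grad                       # full gradient on every rank
    else:
        flat = p.grad.flatten()             # assumed already padded to a multiple of world_size
        shard = torch.empty(flat.numel() // world_size,
                            dtype=flat.dtype, device=flat.device)
        dist.reduce_scatter_tensor(shard, flat, op=dist.ReduceOp.SUM)
        shard.div_(world_size)
        return shard                        # each rank updates only its shard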
Pausing Scalar Updates
I tried several approaches, such as gradient clipping, to prevent the big swings that occur in the values when the schedule changes.
I haven't been able to smooth them out fully, but what seemed to work best was to zero the gradients for the scalars for 40 steps, while preserving their momentum buffers. Preserving the buffers means they still have that healthy momentum when jumping back into the turbulence.
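Concretely, the pause is just zeroing those gradients between backward() and the optimizer step (a minimal sketch; the transition step and the scalar parameter list are placeholders):

PAUSE_STEPS = 40  # number of steps to pause, per the description

def maybe_pause_scalars(step: int, transition_step: int, scalar_params):
    # Zero the scalar gradients for a window of steps after a schedule transition.
    # The optimizer state (exp_avg / exp_avg_sq) is left untouched, so momentum is
    # intact when updates resume.
    if transition_step <= step < transition_step + PAUSE_STEPS:
        for p in scalar_params:
            if p.grad is not None:
                p.grad.zero_()  # no gradient contribution this step; buffers preserved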
Further Work
Other Changes
pip install --pre torch==2.10.0.dev20251210+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
You can explore all of these plots and the other scalar values in the wandb project, here.
Towards the bottom I also included dim0 of all attention heads, and some dimensions for a handful of token embeddings (I only added these in my run, though).