New WR 151.5s: Drop first attn layer, extend all long windows for validation, update schedule #131
This PR builds on all recent WR improvements including PR #130. Updates:
Several factors led to dropping the first attention layer:
Reason for iteration_extension:
Future Opportunities
This change brings the total number of [4,768,768] attention variables to 10. There are 22 MLP variables of size [768x4,768]. In Muon, attention is batched such that 6/16ths of the gradient calcs are on padding tokens. There may be a way to move 2 of the attention variables into the MLP batch, so that MLP is 24/24 and attn is 8/8, instead of MLP being 22/24 and attn being 10/16.
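As a rough illustration of the overhead described above (plain arithmetic only, not repo code; `padding_fraction` is a made-up helper), a padded batch slot is wasted Newton-Schulz work:

```python
# Sketch of the batching overhead using the variable counts quoted in this PR.
def padding_fraction(real_vars: int, batch_size: int) -> float:
    """Fraction of the batched gradient computation spent on padding slots."""
    return (batch_size - real_vars) / batch_size

# Current layout: attention 10/16, MLP 22/24.
print(padding_fraction(10, 16), padding_fraction(22, 24))  # 0.375, ~0.083
# Possible layout: move 2 attention variables into the MLP batch -> 8/8 and 24/24.
print(padding_fraction(8, 8), padding_fraction(24, 24))    # 0.0, 0.0
```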
Investigating Muon for 1D variables
Currently the attention gates and the smear gate are passed into Muon. From light inspection, the Newton-Schulz implementation appears to roughly apply F.normalize(x, p=2, dim=-1) to 1-D variables. This normalization makes every step cover roughly the same distance, regardless of the gradient magnitude. So for 1-D variables Muon turns into exponential smoothing over prior gradients, with each step normalized to roughly the same size, which seems reasonable. Swapping these variables over to Adam gave roughly a 0.5s runtime increase and no improvement in loss. Directly replacing Newton-Schulz with F.normalize(x, p=2, dim=-1) for these variables performed slightly worse. I do not yet understand the theory here, but empirically the performance is good.
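For concreteness, here is a minimal sketch of that observation. The quintic coefficients are the commonly published Muon ones and may differ from this repo's exact implementation; the `newton_schulz5` name and the gate shape are illustrative assumptions, not repo code.

```python
# Sketch: for a 1-D variable viewed as a 1 x n matrix, Newton-Schulz
# orthogonalization comes out roughly proportional to F.normalize(x, p=2, dim=-1).
import torch
import torch.nn.functional as F

def newton_schulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Quintic Newton-Schulz iteration with commonly published Muon coefficients (assumed).
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + 1e-7)  # scale so the largest singular value is <= 1
    if G.size(-2) > G.size(-1):
        X = X.mT
    for _ in range(steps):
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X

gate_grad = torch.randn(1, 768)                    # a 1-D gate gradient as a 1 x n matrix
out = newton_schulz5(gate_grad)
unit = F.normalize(gate_grad, p=2, dim=-1)
print(F.cosine_similarity(out, unit).item())       # ~1.0: same direction as the normalized gradient
print(out.norm().item())                           # roughly constant (~0.7-1.1 depending on steps)
```

Since Muon feeds the momentum buffer rather than the raw gradient through this step, normalizing a rank-1 update amounts to taking fixed-size steps along the smoothed gradient direction, which matches the interpretation above.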
Validation:
Code syntax/naming was lightly refactored after the validation runs. Loss is roughly 0.001 lower than the prior record, which is roughly equivalent to 1s of runtime.