New Record: Multi-token prediction and Untie LM Head 2/3rds through training (119.76 seconds) #178
This record implements:
- Multi-token prediction (MTP)
- Untying the LM head from the embedding weights 2/3rds of the way through training
There are additionally the following minor changes:
- Changing a `>=` to `>` (contribution from @ClassicLarry)

Timing and Validation
This PR has 80 fewer steps than PR177 at a slightly higher step time.
Previous record (timed on same machine):
These timings show $\approx 3$ seconds of improvement.
Thank you to Prime Intellect for sponsoring my research with GPU credits.
Multi-token prediction
This record implements multi-token prediction (MTP) without adding any additional parameters. It uses a weighted average of the cross-entropy loss over the next $k$ tokens as the total step loss. In line with the batch size schedule, MTP follows three phases:
- Phase 1: start with weights [1, 0.5, 0.25] and decay to [1, 0.5, 0.0].
- Phase 2: start with weights [1, 0.5] and decay to [1, 0.0].
- Phase 3: use weights [1] (standard next-token prediction).

In shorthand, this schedule can be described as [1, 0.5, 0.25->0] -> [1, 0.5->0] -> [1]; a sketch of the loss is shown below. I experimented with various other ideas for multi-token prediction, including adding trainable weights to distinguish the $k$-th token predictions. Ultimately, the parameter-free approach seems effective. Intuitively, this naive approach may work because early training is largely learning the $n$-gram distribution. For this task, MTP increases the signal and can be considered a cheap proxy for increasing the batch size.
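To make the schedule concrete, here is a minimal sketch of the weighted MTP loss and the phase schedule. It assumes `logits` of shape (B, T, V) and `targets` of shape (B, T) already shifted by one token; the helper names, the linear decay, and the evenly spaced phase boundaries are illustrative assumptions, not the exact PR implementation.

```python
import torch
import torch.nn.functional as F

def mtp_weights(step: int, total_steps: int) -> list[float]:
    """Per-offset loss weights following [1, 0.5, 0.25->0] -> [1, 0.5->0] -> [1]."""
    frac = step / total_steps
    if frac < 1 / 3:
        # Phase 1: the weight on the 3rd-token loss decays linearly from 0.25 to 0.
        return [1.0, 0.5, 0.25 * (1.0 - 3.0 * frac)]
    elif frac < 2 / 3:
        # Phase 2: the weight on the 2nd-token loss decays linearly from 0.5 to 0.
        return [1.0, 0.5 * (1.0 - 3.0 * (frac - 1 / 3))]
    else:
        # Phase 3: plain next-token prediction.
        return [1.0]

def mtp_loss(logits: torch.Tensor, targets: torch.Tensor, weights: list[float]) -> torch.Tensor:
    """Weighted average of cross-entropy over the next k tokens."""
    B, T, V = logits.shape
    total, weight_sum = 0.0, 0.0
    for k, w in enumerate(weights):  # k = 0 predicts t+1, k = 1 predicts t+2, ...
        if w == 0.0:
            continue
        valid = T - k  # position t has a target k tokens ahead only up to T - k
        loss_k = F.cross_entropy(
            logits[:, :valid].reshape(-1, V),
            targets[:, k:k + valid].reshape(-1),
        )
        total = total + w * loss_k
        weight_sum += w
    return total / weight_sum
```

During training this would be called as `mtp_loss(logits, targets, mtp_weights(step, total_steps))`; once a weight decays to zero, its extra cross-entropy pass is skipped entirely.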
If there are additional ablations or experiments you'd like me to run, please let me know in this PR's discussion.
Untying the LM Head and Embedding Weights
This idea goes back to a comment by @ClassicLarry in PR#175:
I found that the 75x learning rate is probably needed more because of the high initial weight initialization of the original embedding weights than because of the sparsity of the updates. I also found that the untied embedding weights did not benefit from a learning rate multiplier, though it may be possible to include one, perhaps by using a schedule or by lowering the magnitude of the weights.
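As a rough illustration of the mechanics, the sketch below clones the tied weight into a separate parameter for the LM head partway through training and registers it with the optimizer at the base learning rate (no multiplier). The module names (`model.lm_head`), the single-optimizer handling, and the 2/3rds trigger are assumptions for illustration, not the PR's exact code.

```python
import torch
import torch.nn as nn

def untie_lm_head(model: nn.Module, optimizer: torch.optim.Optimizer) -> None:
    """Give the LM head its own copy of the (currently tied) embedding weights."""
    tied = model.lm_head.weight  # same tensor as the token embedding while tied
    model.lm_head.weight = nn.Parameter(tied.detach().clone())
    # Register the new parameter so it receives updates; note there is no
    # extra learning-rate multiplier on the untied weights.
    optimizer.add_param_group({"params": [model.lm_head.weight]})

# Example trigger, 2/3rds of the way through training:
# if step == (2 * total_steps) // 3:
#     untie_lm_head(model, optimizer)
```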
Additional Details
- Changing `>=` to `>` shows minor improvement in Adam.

TODO: Include ablations for the above.