
Conversation

@varunneal
Contributor

@varunneal commented Dec 22, 2025

This record implements:

  • a form of multi-token prediction
  • Untying the LM Head and Embed weight two-thirds of the way through training (contribution from @ClassicLarry)

Additionally, there are the following minor changes:

  • Changing CWD on Adam from >= to > (contribution from @ClassicLarry)
  • Doubling the weight initialization for the attention heads
  • Slightly decreasing the magnitude of the second learning rate bump

Timing and Validation

This PR has 80 fewer steps than PR#177 at a slightly higher step time.

import scipy.stats
import torch

losses = [3.2783, 3.2809, 3.2784, 3.2783, 3.2787, 3.2776, 3.2781, 3.2780, 3.2768, 3.2798, 3.2789, 3.2797, 3.2795, 3.2794]
times = [119.694, 119.820, 119.762, 119.794, 119.811, 119.760, 119.801, 119.923, 119.677, 119.754, 119.726, 119.865, 119.685, 119.578]

print("p=%.4f" % scipy.stats.ttest_1samp(losses, 3.28, alternative="less").pvalue)
# p=0.0003

print("losses:", torch.std_mean(torch.tensor(losses)))
# losses: (std=0.0010, mean=3.2787)

print("time:", torch.std_mean(torch.tensor(times)))
# time: (std=0.0867, mean=119.7607)

Previous record (timed on same machine):

import scipy.stats
import torch

times = [122.445, 122.882, 122.913, 122.940, 122.939]

print("time:", torch.std_mean(torch.tensor(times)))
# time: (std=0.2131, mean=122.8238)

These timings show $\approx 3$ seconds of improvement.

Thank you to Prime Intellect for sponsoring my research with GPU credits.

Multi-token prediction

This record implements multi-token prediction (MTP) without adding any additional parameters. It uses a weighted average of the cross-entropy loss over the next $k$ tokens as the total step loss. In line with the batch size schedule, MTP follows three phases:

  1. First phase: weighted average over the next three tokens. The weights start at [1, 0.5, 0.25] and decay to [1, 0.5, 0.0].
  2. Second phase: weighted average over the next two tokens. The weights start at [1, 0.5] and decay to [1, 0.0].
  3. Third/extension phase: regular next-token prediction.

In shorthand, this schedule can be described as [1, 0.5, 0.25->0] -> [1, 0.5->0] -> [1].
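
For concreteness, here is a minimal sketch of what a weighted multi-token cross-entropy loss of this form can look like (shapes and names are illustrative, not the exact code in this record):

import torch.nn.functional as F

def mtp_loss(logits, targets, weights):
    # logits: (B, T, V), where logits[:, t] scores the token at position t+1
    # targets: (B, T) next-token ids
    # weights[k] scales the loss for the token (k+1) positions ahead,
    # e.g. weights = [1.0, 0.5, 0.25]
    total, denom = 0.0, 0.0
    for k, w in enumerate(weights):
        if w == 0.0:
            continue
        # predict the token (k+1) steps ahead: drop the last k positions
        shifted_logits = logits[:, : logits.size(1) - k, :]
        shifted_targets = targets[:, k:]
        ce = F.cross_entropy(
            shifted_logits.reshape(-1, shifted_logits.size(-1)),
            shifted_targets.reshape(-1),
        )
        total = total + w * ce
        denom = denom + w
    return total / denom

The phase schedule then amounts to swapping out weights over training, e.g. decaying [1, 0.5, 0.25] toward [1, 0.5, 0] in the first phase, and so on.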

I experimented with various other ideas for multi-token prediction, including adding trainable weights to distinguish the $k$-th token predictions. Ultimately, the parameter-free approach seems effective. Intuitively, this naive approach may work because early training is largely learning the $n$-gram distribution; for that task, MTP increases the signal and can be considered a cheap proxy for increasing the batch size.

If there are additional ablations or experiments you'd like me to run, please let me know in this PR's discussion.

Untying the LM Head and Embed weight

This idea goes back to a comment by @ClassicLarry in PR#175:

My shaky hypothesis here is that since the embed gradient is very sparse, it benefits from tagging along with the lm_head gradient early, since there are some core similarities between token relationships in lm_head space and embed space. This sparsity was previously handled by scaling up the lr 75x, which is less stable than the approach here. But later on in training the embed and lm_head are fundamentally different, and benefit from having their own representations.

I found that the 75x learning rate is probably due more to the large initial magnitude of the original embedding weights than to the sparsity of the updates. The untied embedding weights did not benefit from a learning-rate multiplier, though it may be possible to include one, perhaps by using a schedule or by lowering the magnitude of the weights.
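
For intuition, here is a hypothetical sketch of what untying could look like mid-training (the attribute names, optimizer handling, and cutoff are illustrative, not this record's actual code):

import torch.nn as nn

def untie_lm_head(model, optimizer, lr):
    # before this point the head shares its weight with the embedding:
    # model.lm_head.weight is model.embed.weight
    new_head = nn.Parameter(model.embed.weight.detach().clone())
    model.lm_head.weight = new_head
    # give the fresh parameter its own optimizer state / learning rate
    optimizer.add_param_group({"params": [new_head], "lr": lr})

# hypothetical call site inside the training loop:
# if step == (2 * num_steps) // 3:
#     untie_lm_head(model, optimizer, lr=head_lr)

From that step onward the embedding and the LM head receive independent updates, which matches the hypothesis above that they benefit from separate representations late in training.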

Additional Details

  • The change to CWD in Adam from >= to > shows a minor improvement.
  • Doubling the weight initialization for the attention heads shows empirical improvement, and now corresponds to Kaiming/Xavier initialization for QKV weights assuming they are treated as square $(768, 768)$ parameters (see the small check after this list).
  • Slightly decreasing the magnitude of the second learning rate bump seems to slightly decrease variance. This idea takes a bit of inspiration from the graphs of PR#177.
  • Matching PR#177, I'm using nightly 12/10 and Python version 3.12.
  • My timing for this record seems a bit slower than both PR#175 and PR#177.
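
As a quick sanity check on the Kaiming/Xavier claim (purely illustrative, not the record's init code): for a square $(768, 768)$ weight, the Xavier-normal standard deviation is $\sqrt{2 / (768 + 768)} = 1/\sqrt{768} \approx 0.036$.

import math
import torch

fan = 768
print(math.sqrt(2.0 / (fan + fan)))  # analytic Xavier std, ~0.0361

w = torch.empty(fan, fan)
torch.nn.init.xavier_normal_(w)
print(w.std().item())                # empirical std, ~0.036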

TODO: Include ablations for the above

@chrisjmccormick
Contributor

Under 2 minutes, you did it!! 😃

@shenberg
Contributor

Re: CWD, I was thinking of making the same change; my logic was that for embedding params with sparse gradients, it undoes the 'cautious' part of CWD when gradient == 0 and becomes regular weight decay for unused tokens.

MTP: It's a really cool method! Did you experiment with DeepSeek-V3-style MTP modules (or this slightly simpler approach)? I would have hoped for a method that's "always on" for MTP, though this method being parameter-free and really cheap to compute is very appealing. A baseline ablation, decaying label smoothing, maybe? Or, as a poor man's approximation, fixed secondary and tertiary token targets? (The whole idea seems reminiscent of online label smoothing.)

Two questions I have are:

  1. How did you come up with the schedule?
  2. Does the noise from predicting across documents matter? (It's a small number of tokens, so probably not?)

@varunneal
Contributor Author

@shenberg I didn't see that paper, but I was definitely inspired by some of the work on MTP, especially the new MiniMax model.

For (1), I essentially wanted to follow the tripartite batch size schedule, so I had some constraints on the search space. Using the shorthand notation shows there aren't really too many variables to optimize over: I had to choose the number of tokens in each phase (e.g. 2, 3, or 4) and then what weight each token should get. For the latter, the 1, 1/2, 1/4, ... weighting was my first guess, following the ProphetNet intuition, though I ran some experiments to confirm it was better than some alternatives.

Obviously there's a lot of ways to make an MTP schedule but I like to narrow the search space as much as possible and see if I can get something working. Hopefully future optimizations can be found.

For (2), you raise a good point. We could try using a cheap mask. Maybe a future PR idea would be to mask out all train losses that follow the BOS token, as in the sketch below. Though maybe this is not useful, since learning the most common document-starting n-gram is important at val. Perhaps only apply the mask for parts of training.
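
Something like the following could work (a rough sketch; BOS_ID, shapes, and names are placeholders rather than anything in the repo):

import torch.nn.functional as F

def masked_next_token_loss(logits, inputs, targets, bos_id):
    # logits: (B, T, V); inputs: (B, T); targets: (B, T) = inputs shifted by one
    per_token = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        reduction="none",
    )
    # drop the loss wherever the input token is BOS, i.e. where the model is
    # asked to guess the first token of a fresh document
    keep = (inputs != bos_id).reshape(-1).float()
    return (per_token * keep).sum() / keep.sum()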

@ClassicLarry
Collaborator

Will merge in a couple days at 119.3s (-2.9s)

@ClassicLarry merged commit 9060ff4 into KellerJordan:master on Dec 25, 2025
@varunneal
Contributor Author

By the way, I got the following ablation results:

| Configuration | Val Losses | Mean | Std Dev | Δ from Baseline |
|---|---|---|---|---|
| baseline | 3.2781, 3.2781, 3.2785, 3.2815, 3.2781 | 3.2789 | 0.0015 | |
| no_mtp | 3.2844, 3.2858, 3.2848, 3.2870, 3.2853 | 3.2855 | 0.0010 | +0.0066 |
| no_split_embed | 3.2811, 3.2816, 3.2819, 3.2804, 3.2811 | 3.2812 | 0.0006 | +0.0023 |
| no_small_lr_bump | 3.2808, 3.2785, 3.2779, 3.2788, 3.2794 | 3.2791 | 0.0011 | +0.0002 |
| no_cwd_strict | 3.2802, 3.2784, 3.2768, 3.2790, 3.2787 | 3.2786 | 0.0012 | -0.0003 |
| no_double_attn_init | 3.2793, 3.2766, 3.2796, 3.2776, 3.2791 | 3.2784 | 0.0013 | -0.0005 |

In previous tests I saw some benefit from no_double_attn_init, but it looks like whatever benefit is there is within the noise tolerance. Basically the same holds for all three minor changes added in this record.

cc: @ClassicLarry
