Conversation

@ClassicLarry (Collaborator) commented Dec 14, 2025

New WR 128.8s: Partial Key Offset

Five updates, for roughly 2.4s improvement:

  1. Implement a partial key offset for the long sliding windows (this change enables the step count decrease that is responsible for the majority of the improvement). Maybe this is novel?
  2. Merge the x_lambda*x + x0_lambda*x0 for layer 0 into (x_lambda+x0_lambda)*x (see the sketch after this list), and clean up the code to properly represent the 11-layer model.
  3. Drop 50 steps (~60ms per step).
  4. Align the batch size schedule to update exactly when the window size updates (0.33 != 1/3).
  5. Zero out the initial value embeddings. Very minor impact, but it may give lower variance, and zero init is preferred as the lowest-assumption config.
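As a quick sanity check on item 2: at layer 0 the residual stream is still equal to the initial embedding x0, so the two scaled terms collapse into one multiply. A minimal sketch of the identity (shapes and lambda values are illustrative, not taken from the repo):

import torch

x0 = torch.randn(4, 8, 768)        # initial token embeddings
x = x0                             # at layer 0 the residual stream is still x0
x_lambda, x0_lambda = 0.7, 0.3     # illustrative learned scalars

assert torch.allclose(x_lambda * x + x0_lambda * x0, (x_lambda + x0_lambda) * x)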

The partial key offset is only applied to the stationary head dims (32-64 and 96-128). This was found to perform better than applying it to all dims: it gives the queries more freedom to attend to multiple positions at once through a single key. I tested applying it to only a subset of heads and got worse results. The lowest loss was achieved when applying it to every layer, but the tradeoff between loss and speed seems best when only applying it to the long windows, which are the ones primarily responsible for induction.

# Shift keys forward by one position for the stationary head dims
# (dims 32-64 and 96-128 when head_dim=128). Enables 1-layer induction.
k[:, 1:, :, self.head_dim//4:self.head_dim//2] = k[:, :-1, :, self.head_dim//4:self.head_dim//2]
k[:, 1:, :, self.head_dim//4+self.head_dim//2:] = k[:, :-1, :, self.head_dim//4+self.head_dim//2:]
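For intuition: with the spliced key, the attention logit for query position i against key position j splits into a "normal" term against k[j] and an "induction" term against k[j-1], so the query can weight either behavior by where it puts its mass. A minimal sketch of that decomposition (assumes head_dim=128 and builds a boolean mask for the offset dims 32-64 and 96-128; the names here are illustrative, not from the repo):

import torch

head_dim = 128
offset = torch.zeros(head_dim, dtype=torch.bool)
offset[head_dim//4:head_dim//2] = True            # dims 32-64
offset[head_dim//4 + head_dim//2:] = True         # dims 96-128

q_i, k_j, k_jm1 = torch.randn(3, head_dim)        # query at i, keys at j and j-1

# spliced key: the offset dims come from the previous position's key
k_spliced = torch.where(offset, k_jm1, k_j)

# the logit is the sum of a "normal" term and an "induction" term
logit = q_i @ k_spliced
normal_term = q_i[~offset] @ k_j[~offset]
induction_term = q_i[offset] @ k_jm1[offset]
assert torch.allclose(logit, normal_term + induction_term)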

Timing and Validation

import scipy.stats
import torch

losses = [3.2788,3.2774,3.2786,3.2792,3.2762,3.2769,3.2781,3.2778,3.2761,3.2783,3.2809]
times = [128.892,128.907,128.912,128.844,128.822,128.869,128.818,128.882,128.886,128.95,128.946]

print("p=%.4f" % scipy.stats.ttest_1samp(losses, 3.28, alternative="less").pvalue)
# p=0.0004

print("losses:", torch.std_mean(torch.tensor(losses)))
# losses: (tensor(0.0014), tensor(3.2780))

print("time:", torch.std_mean(torch.tensor(times)))
# time: (tensor(0.0443), tensor(128.8844))

Retiming the prior record gives 131.2s: [131.270, 131.241, 131.213].
(It appears I got a slightly slower machine this time.)

@ClassicLarry (Collaborator, Author)

Looking through the literature, I see similar intuitions here: https://arxiv.org/pdf/2411.19574

@varunneal (Contributor) commented Dec 14, 2025

Woah really cool. Is this intuition right?

The attention score for token $i$ attending to token $j$ is typically score(i -> j) = softmax(q[i] * k[j]). What the induction trick is doing is making score(i -> j) instead a function of q[i], k[j], and k[j-1]. The way it does that is by splicing half the dims (the second and fourth quarters) of k[j-1] into k[j].

Really neat idea. I wonder if there is additional room or benefit for more of this type of splicing-based induction.

@ClassicLarry (Collaborator, Author)

> Woah really cool. Is this intuition right? […]

Ya. Using half the dims lets the query decide what it wants to do. If the query wants to "find a token like me and pull in the value for the token that follows it", it can attend to the offset dims. If it wants to "attend to the bos_token / find a verb / etc.", it can attend to the normal dims. Or it can do a blend of both.

@varunneal (Contributor)

Did you try adding k[j-1] onto some dims instead of splicing it directly?

@ClassicLarry (Collaborator, Author)

> Did you try adding k[j-1] onto some dims instead of splicing it directly?

I did not. There’s a pretty big design space of ideas not tried.

@linux-leo commented Dec 15, 2025

I'd shift v too, tbh, and experiment with data-dependent vs data-independent shift; both have been tried before in this speedrun.

Upon taking a closer look, the splicing approach is a form of data-dependent shift, so this is already the best approach.

@ClassicLarry (Collaborator, Author)

> I'd shift v too, tbh, and experiment with data-dependent vs data-independent shift; both have been tried before in this speedrun.
>
> Upon taking a closer look, the splicing approach is a form of data-dependent shift, so this is already the best approach.

Could be tested for sure. But I’m not sure conceptually what the benefit of shifting v would be.

@varunneal (Contributor) commented Dec 15, 2025

Before you merge, it might be worth cleaning up the extra import block around line 450. I believe it was added in the previous record:

import torch
from torch import Tensor
import torch.distributed as dist
from collections import defaultdict

ClassicLarry merged commit 28fda1e into KellerJordan:master on Dec 16, 2025
@YouJiacheng (Contributor)

The most similar literature (2020): https://github.com/BlinkDL/minGPT-tuned/blob/81807b927f8794d3dca45eda9cd25bf3eb035568/mingpt/model.py#L67
Interestingly, Bo also used half the dims lol.
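The trick there, roughly, mixes half of each token's channels with the previous token's channels before the attention projections. A rough sketch of that kind of time shift (a sketch only, not the exact minGPT-tuned code; the function name and layout are assumptions):

import torch
import torch.nn.functional as F

def time_shift_half(x):
    # x: (batch, seq, channels); the first half of the channels comes from the previous token
    B, T, C = x.shape
    prev = F.pad(x, (0, 0, 1, 0))[:, :-1]      # x shifted right by one token, zeros at t=0
    return torch.cat([prev[..., :C // 2], x[..., C // 2:]], dim=-1)

print(time_shift_half(torch.randn(2, 16, 64)).shape)   # torch.Size([2, 16, 64])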

thib-s added a commit to thib-s/modded-nanogpt-turbomuon that referenced this pull request Dec 16, 2025
@chrisjmccormick (Contributor)

@ClassicLarry Nice work! It's kind of remarkable to me that these cross-token tricks aren't more expensive. At a glance it sounds like a ton of memory shuffling. Maybe it is and it's just worth it?

Do induction heads typically require two layers to work? Hence the 1-layer induction comment?

Also, really like the decision to renumber the layers 😊

@ClassicLarry (Collaborator, Author)

> Do induction heads typically require two layers to work? Hence the 1-layer induction comment?

https://www.lesswrong.com/posts/TvrfY4c9eaGLeyDkE/induction-heads-illustrated

This post shows how induction typically works. It looks quite convoluted for such a simple idea of "repeat what followed my last occurrence". The classic circuit needs two layers: a previous-token head in an earlier layer copies information about token j-1 into position j, and a later head matches queries against it. Shifting the key dims bakes that first step directly into the key, so one layer suffices.

@chrisjmccormick (Contributor)

Can I ask how you went about decreasing the step count? i.e., what indication did you get that you could afford to reduce the step count, and then is it just a bit of trial and error to see how many you can shed?

@ClassicLarry (Collaborator, Author)

> Can I ask how you went about decreasing the step count? […]

Mostly trial and error. If I ever see a loss above 3.282 I will add 10-20 steps. If I ever get a loss below 3.275 I know I need to decrease step count.
