New WR 128.8s: Partial Key Offset (-2.4s, -50 steps) #169
Conversation
Looking through the literature, I see similar intuitions here: https://arxiv.org/pdf/2411.19574

Woah, really cool. Is this intuition right? The attention score for token … Really neat idea. I wonder if there is additionally room or benefit for this type of splicing-based induction.

Ya. Using half the dims lets the query decide what it wants to do. If the query wants to "find a token like me and pull in the value for the token that follows it", then it can attend to the offset dims. If the query wants to "attend to the bos_token / find a verb / etc.", then it can attend to the normal dims. Or it can do a blend of both.
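To make the "blend of both" concrete, here is a minimal sketch of how a single attention logit decomposes when half of the key dims are spliced from the previous token. The names (`q`, `k`, `k_prev`) and the 50/50 dim split are assumptions for illustration, not the PR's actual code:

```python
import torch

def score_decomposition(q: torch.Tensor, k: torch.Tensor, k_prev: torch.Tensor) -> torch.Tensor:
    # q, k, k_prev: (head_dim,) vectors for one query/key pair.
    # First half of the key = the current token's own dims ("normal" dims);
    # second half = the same dims taken from the previous token ("offset" dims).
    half = q.size(0) // 2
    match_current = q[:half] @ k[:half]        # "attend to bos_token / find a verb / ..."
    match_previous = q[half:] @ k_prev[half:]  # "find a token like me, take what follows it"
    return match_current + match_previous      # the query blends the two by how it weights each half
```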
Did you try adding k[j-1] onto some dims instead of splicing it directly?
I did not. There’s a pretty big design space of ideas not tried. |
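For reference, a hedged sketch of the additive variant asked about above (adding k[j-1] onto some dims rather than splicing it in). This was not tried in the PR; the names, the 50/50 dim split, and the `alpha` scale are purely illustrative:

```python
import torch

def additive_key_offset(k: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # k: (batch, num_heads, seq_len, head_dim). Add the previous token's key
    # into the second half of the dims instead of replacing those dims outright.
    half = k.size(-1) // 2
    k_prev = torch.cat([torch.zeros_like(k[:, :, :1]), k[:, :, :-1]], dim=2)  # k[j-1] aligned to position j
    out = k.clone()
    out[..., half:] = out[..., half:] + alpha * k_prev[..., half:]
    return out
```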
I'd shift v too, tbh. Upon taking a closer look, the splicing approach is a form of data-dependent shift, so this is already the best approach.
Could be tested for sure. But I’m not sure conceptually what the benefit of shifting v would be. |
Before you merge, it might be worth cleaning up the extra import block around line 450. I believe it was added in the previous record.
The most similar literature (2020): https://github.com/BlinkDL/minGPT-tuned/blob/81807b927f8794d3dca45eda9cd25bf3eb035568/mingpt/model.py#L67
@ClassicLarry Nice work! It's kind of remarkable to me that these cross-token tricks aren't more expensive. At a glance it sounds like a ton of memory shuffling. Maybe it is, and it's just worth it? Do induction heads typically require two layers to work? Hence the 1-layer induction head comment? Also, really like the decision to renumber the layers 😊
https://www.lesswrong.com/posts/TvrfY4c9eaGLeyDkE/induction-heads-illustrated
This post shows how induction typically works. It looks quite convoluted for such a simple idea of "repeat what followed my last occurrence".
Can I ask how you went about decreasing the step count? i.e., what indication did you get that you could afford to reduce the step count, and then is it just a bit of trial and error to see how many you can shed? |
Mostly trial and error. If I ever see a loss above 3.282, I add 10-20 steps. If I ever get a loss below 3.275, I know I need to decrease the step count.
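As a rough sketch of that rule (the loss thresholds come from the comment above; the function name and the exact step adjustments are illustrative, not part of any actual tooling):

```python
def adjust_step_count(val_loss: float, num_steps: int) -> int:
    if val_loss > 3.282:       # run missed the target validation loss
        return num_steps + 15  # add roughly 10-20 steps
    if val_loss < 3.275:       # comfortably under target
        return num_steps - 10  # room to shave steps and retest
    return num_steps
```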

New WR 128.8s: Partial Key Offset
Five updates, for roughly 2.4s improvement:
x_lambda*x + x0_lambda*x0 for layer 0 into (x_lambda+x0_lambda)*x. And clean up the code to properly represent the 11-layer model.

The partial key offset is only applied to the stationary head dims (32-64 and 96-128). This was found to perform better than applying it to all dims. This approach gives the queries more freedom to attend to multiple positions at once through a single key. I tested applying it to only a subset of heads and got worse results. The lowest loss was achieved when applying it to every layer, but the cost:speed ratio seems best when only applying it to the long windows, which are the ones primarily responsible for induction.
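A minimal PyTorch sketch of the splice described above, assuming a head dim of 128, that "stationary" means the dims the rotary embedding leaves unrotated, and an illustrative tensor layout; this is not the PR's actual implementation:

```python
import torch

STATIONARY_DIMS = (slice(32, 64), slice(96, 128))  # dims the rotary embedding leaves unrotated

def apply_partial_key_offset(k: torch.Tensor) -> torch.Tensor:
    # k: (batch, num_heads, seq_len, 128). Copy the stationary dims of k[j-1]
    # into position j, so those dims let a query ask "was the previous token
    # like me?" while the remaining dims still describe token j itself.
    k_prev = torch.cat([torch.zeros_like(k[:, :, :1]), k[:, :, :-1]], dim=2)  # keys shifted right by one position
    out = k.clone()
    for s in STATIONARY_DIMS:
        out[..., s] = k_prev[..., s]
    return out
```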
Timing and Validation
Retiming prior record: 131.2s [131.270, 131.241, 131.213]
(appears I got a slightly slower machine this time)