
Commit e89cb13

Added logs and readme
1 parent 86ab0b2 commit e89cb13

11 files changed (+37944, -0 lines)

records/track_1_short/2025-12-10_SALambdaOnWeights/15ef5eaf-56e1-40e1-9ddf-af010027c9dd.txt

Lines changed: 3787 additions & 0 deletions

records/track_1_short/2025-12-10_SALambdaOnWeights/167f03c6-4035-4d50-b2bc-1671c420e250.txt

Lines changed: 3787 additions & 0 deletions

records/track_1_short/2025-12-10_SALambdaOnWeights/54df456d-31a2-4fd5-9738-9008df956409.txt

Lines changed: 3787 additions & 0 deletions

records/track_1_short/2025-12-10_SALambdaOnWeights/5c0991a6-4ba8-4511-b600-0b26597f91d7.txt

Lines changed: 3787 additions & 0 deletions

records/track_1_short/2025-12-10_SALambdaOnWeights/717bc1fb-66cd-43dc-ad1a-b46f946ccb84.txt

Lines changed: 3787 additions & 0 deletions

records/track_1_short/2025-12-10_SALambdaOnWeights/7770e75a-c7ef-4c7c-bd77-94bff0152546.txt

Lines changed: 3787 additions & 0 deletions

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@

This record changes where the self-attention lambdas are multiplied in: they now pre-scale the QKV weight matrix instead of directly scaling V. The warm-up process is also fixed so that all code paths are correctly pre-compiled.

## Timing and Validation

The compile fix removed a fixed overhead; the rest of the gain was an increase in speed per time-step.

```
import scipy.stats
import torch

losses = [3.2762, 3.2785, 3.2789, 3.2774, 3.2769, 3.2775, 3.275, 3.2769, 3.2808, 3.2797]
times = [131.233, 131.239, 131.225, 131.284, 131.273, 131.227, 131.325, 131.047, 131.236, 131.145]

print("p=%.4f" % scipy.stats.ttest_1samp(losses, 3.28, alternative="less").pvalue)
# p=0.0014

print("losses:", torch.std_mean(torch.tensor(losses)))
# losses: (tensor(0.0017), tensor(3.2778))

print("time:", torch.std_mean(torch.tensor(times)))
# time: (tensor(0.0776), tensor(131.2234))
```

Previous record (timed on the same machine):

```
import scipy.stats
import torch

baseline_times = [131.795, 131.707, 131.753]
print("time:", torch.std_mean(torch.tensor(baseline_times)))
# time: (tensor(0.0440), tensor(131.7517))
```

This shows an improvement of $ \approx 0.5 $ seconds.

## Changing Self-Attention Lambdas

One of the architecture modifications made to GPT-2 here is having multiple separate token embeddings, called value embeddings,
that are mixed directly into the self-attention V projections with learned weights, the lambdas.
This is a straightforward `v = v * sa_lambda[0] + ve * sa_lambda[1]`. The V projections are big (32768 by 768 elements on average),
but to scale them we can intervene in several places where much less work is needed.
The $ W_V $ weight matrix is only 768 by 768, so by rewriting $ \lambda (W_V x) $ as $ (\lambda W_V) x $ we can save work.
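
A minimal numerical sketch of that rewrite (shapes and variable names are illustrative, not the repo's actual code; the real V activation is roughly 32768 by 768):

```
import torch

torch.manual_seed(0)
n, d = 4096, 768                    # smaller n than the real ~32768, for a quick check
x = torch.randn(n, d)               # token activations entering the attention block
ve = torch.randn(n, d)              # value embeddings for the same tokens
W_V = torch.randn(d, d) / d ** 0.5  # value projection weight
sa_lambda = torch.tensor([0.7, 0.3])

# before: project, then scale the big (n, d) V activation
v_old = (x @ W_V.T) * sa_lambda[0] + ve * sa_lambda[1]

# after: fold the scalar into the small (d, d) weight, then project
v_new = x @ (sa_lambda[0] * W_V).T + ve * sa_lambda[1]

print(torch.allclose(v_old, v_new, atol=1e-4))  # True
```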

It's actually still not obvious that this gains us anything, since we are running under `torch.compile()`,
which could, in theory, fuse the scalar multiply into the $ W_V x $ matrix-multiply kernel by itself, hiding the
cost of the scalar multiplication behind the memory accesses of the matrix-multiply. This doesn't happen, though, probably
because for efficiency reasons we keep a single weight matrix for QKVO and do the QKV projections as a single matrix-multiply.
This means we actually pre-multiply the entire QKV weight block by our scalar, which wastes some work but still works because
the Q and K projections are immediately normed (thanks, QK-norm!).
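
A toy check of why pre-scaling the whole fused QKV weight is harmless under QK-norm (illustrative names again, with QK-norm modeled as an RMS norm; the real model fuses QKVO into one weight and runs under `torch.compile()`):

```
import torch

torch.manual_seed(0)
n, d = 256, 768
x = torch.randn(n, d)
W_qkv = torch.randn(3 * d, d) / d ** 0.5   # fused Q, K, V projection weights
lam = 1.3                                  # a (positive) self-attention lambda

def rms_norm(t):
    return t / t.pow(2).mean(-1, keepdim=True).sqrt()

def attend(q, k, v):
    att = (q @ k.T / q.shape[-1] ** 0.5).softmax(dim=-1)
    return att @ v

# reference: scale only V after the projection
q, k, v = (x @ W_qkv.T).chunk(3, dim=-1)
ref = attend(rms_norm(q), rms_norm(k), lam * v)

# this record: pre-scale the whole fused weight; the extra factor on Q and K
# is erased by the RMS norm, so only V effectively keeps the scale
q2, k2, v2 = (x @ (lam * W_qkv).T).chunk(3, dim=-1)
out = attend(rms_norm(q2), rms_norm(k2), v2)

print(torch.allclose(ref, out, atol=1e-4))  # True
```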

**Note**: even in layers without value embeddings, we still scale the V projections by a lambda, and this is probably beneficial for the model:
we RMS-norm the input to the SA block, while the residual stream magnitude increases as we reach deeper layers in the network, so
the lambda lets the model scale the output of the SA block appropriately. In trained models, the lambdas generally grow with network depth.

### Minor Variations

All of the manipulations done on the V projections are linear - we only do additions and multiplications by scalars -
so there are multiple spots where we could apply the scale. It might require re-scaling the value-embeddings lambda initialization
a bit to make the algebra work out the same, but it "*shouldn't*" matter. Variations tried (see the sketch after this list):

* Multiplying the $ W_O $ matrix, which increased the loss for reasons I did not understand (my successor did, however!)
* Multiplying the output of the sparse-attention gates, which is actually the most efficient option in terms of total operation count; this also seems to have increased the loss.
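
The linearity argument in toy form (illustrative names; this ignores the value-embedding term and the gates, which is exactly why the lambda initialization might need re-scaling in practice):

```
import torch

torch.manual_seed(0)
n, d = 256, 768
A = torch.softmax(torch.randn(n, n), dim=-1)   # attention weights, already past the softmax
v = torch.randn(n, d)                          # V projections
W_O = torch.randn(d, d) / d ** 0.5             # output projection weight
lam = 0.8

out_scale_v  = (A @ (lam * v)) @ W_O.T   # scale V (this record folds it further back, into the QKV weight)
out_scale_wo = (A @ v) @ (lam * W_O).T   # scale the W_O matrix instead
out_scale_y  = lam * ((A @ v) @ W_O.T)   # scale the output as a whole

print(torch.allclose(out_scale_v, out_scale_wo, atol=1e-4),
      torch.allclose(out_scale_v, out_scale_y, atol=1e-4))   # True True
```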

## Warm-up Bug-fix

A previous record which improved the DistAdam compute-communication overlap did not update the `torch.compile()` warmup phase to cover
all of the newly-added code paths, which caused a slow-down in the second iteration because a recompile had to happen.
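
A self-contained toy of the failure mode (not the repo's code; `use_overlap` is just a stand-in for a newly added code path): if warm-up only exercises one branch of a compiled function, the first timed call that takes the other branch pays for a recompile.

```
import time
import torch

@torch.compile
def step(x, use_overlap: bool):
    # the flag stands in for a code path added after the warm-up logic was written
    return x * 2 if use_overlap else x + 1

x = torch.randn(1024)
step(x, use_overlap=True)       # warm-up only compiles the `True` branch

t0 = time.perf_counter()
step(x, use_overlap=False)      # first call on the other branch triggers a recompile
print("first call on new path:  %.2fs" % (time.perf_counter() - t0))

t0 = time.perf_counter()
step(x, use_overlap=False)      # now cached, so this call is fast
print("second call on new path: %.4fs" % (time.perf_counter() - t0))
```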

The bug was discovered accidentally while looking for a fix to a problem where one in every few hundred training steps would become intermittently
slow (the cause is not a recompile, nor is it an NCCL issue, though I looked exhaustively). The problem only disappeared once a `gc.collect()` was
added before starting the timer for the run (a run with `gc.set_debug(gc.DEBUG_STATS)` showed that there is a memory leak somewhere, and an occasional
generation-2 collection, which stops the world, roughly fit the timeline).
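
For reference, the diagnosis and workaround amount to something like this sketch (the actual timed training loop is elided):

```
import gc
import time

gc.set_debug(gc.DEBUG_STATS)   # log every collection; this is how the gen-2 collections showed up
gc.collect()                   # force a full collection before the timer starts

t0 = time.perf_counter()
# ... timed training loop would run here ...
print("elapsed: %.3fs" % (time.perf_counter() - t0))
```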
