New WR 134.9s: Refine Skip Architecture, Better Lambda Init #159
Three updates, for roughly a 2s improvement:
I am not building on top of #155 because I have not yet been able to get that PR to run on 8xH100. Once it is compatible, it can likely be merged in on top of this.
Details on Lambda Update
The block lambda structure, x = a*x + b*x0, can be unrolled across layers. Critically, this lambda is applied not to the layer input but to the actual residual stream.
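To make the structure concrete, here is a minimal PyTorch sketch of a block with this mixing. The class, module, and parameter names are illustrative rather than the repo's exact code, norms are omitted for brevity, and the init value of b is an assumption:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Decoder block with the x = a*x + b*x0 residual mixing (illustrative sketch)."""
    def __init__(self, attn: nn.Module, mlp: nn.Module, lambda_init: float = 1.1):
        super().__init__()
        self.attn = attn  # any module mapping (B, T, D) -> (B, T, D)
        self.mlp = mlp
        # lambdas[0] = a scales the incoming residual stream,
        # lambdas[1] = b scales the initial embedding x0 (init of 1.0 is an assumption)
        self.lambdas = nn.Parameter(torch.tensor([lambda_init, 1.0]))

    def forward(self, x: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
        # The rescaling hits the residual stream itself: everything accumulated
        # by earlier layers is multiplied by a before this block contributes.
        x = self.lambdas[0] * x + self.lambdas[1] * x0
        x = x + self.attn(x)  # this block's attention output enters at full weight...
        x = x + self.mlp(x)   # ...so a < 1 in later blocks mutes only prior layers
        return x
```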
Let R_i be the output of layer i. The contribution of layer i to the prediction is then proportional to a^(num_layers - i). Initializing the lambda to 1.1 gives:

prediction = 1.1^10 * R_1 + 1.1^9 * R_2 + ... + R_11 ≈ 2.59 * R_1 + 2.36 * R_2 + 2.14 * R_3 + ... + R_11

Given this init, the first layer has roughly 2.6x the impact on the prediction compared to the last layer. The initialization of this lambda can be thought of as a lever to bias the network's early training towards the earlier or later layers.
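As a quick sanity check of those init-time weights (plain Python, using the same a = 1.1 and the indexing of the expansion above):

```python
# Effective weight on each layer's output at init, per the unrolling above:
# contribution(layer i) is proportional to a ** (num_layers - i)
a = 1.1
num_layers = 11  # matches the R_1 ... R_11 expansion above
for i in range(1, num_layers + 1):
    print(f"layer {i:2d}: {a ** (num_layers - i):.2f}")
# layer  1: 2.59, layer  2: 2.36, layer  3: 2.14, ..., layer 11: 1.00
```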
Here are the final weights after training for a and b in x = a*x + b*x0 across the 12 layers:
From layers 2 to 8, the first lambda ends up around 0.5. This means that the final contribution of the 1st layer's output to the prediction is muted by roughly 0.5^7 (≈ 0.008) compared to the 8th layer's. The residual stream is applying a sort of exponential smoothing over layer contributions.
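To make the exponential-smoothing reading concrete, here is a toy calculation assuming a flat a ≈ 0.5 across layers 2-8 and ignoring the b*x0 term and per-layer variation:

```python
a = 0.5  # approximate trained value of the first lambda over layers 2-8
# Relative contribution of layer i's output to the prediction, compared to
# layer 8's, after the residual stream is rescaled by a at layers i+1 ... 8:
for i in range(1, 9):
    print(f"layer {i}: {a ** (8 - i):.4f}")
# layer 1: 0.0078, layer 2: 0.0156, ..., layer 7: 0.5000, layer 8: 1.0000
```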
Similar to the backout_lambda, this gives further evidence that the first 3/4 of the layers function as context builders and the last 1/4 function as predictors. The lambda enables each layer to use the context output from nearby layers, which then gets washed out as 0.5 is repeatedly applied to the residual stream.
A secondary effect of this lambda ending up < 1 is that each MLP pays extra attention to its own attention layer's output, because the deweighting of the residual stream occurs before the attention output is added.
I previously tested an architecture where each module of each layer could dynamically set a unique weight on the contribution of every prior module from every prior layer. In that setup, every MLP module consistently gave a large preferential weight to its own attention layer. At its final values, this lambda accomplishes a similar objective.
The implementation of this lambda in the repo is computationally efficient but conceptually misleading. The comments above do not give intuition as to why initializing to greater than 1 would be preferred, just that this parameter is perhaps a meaningful lever to bias training towards earlier or later portions of the network.
Timing and Validation
Retiming of the prior record: 137.3s, from runs [137.083, 137.374, 137.546].
(It appears I got a slightly slower machine this time.)