-
Notifications
You must be signed in to change notification settings - Fork 3
Feature: Returning ESM-2 hidden layers #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
The ESM2Result class now includes an all_hidden attribute containing outputs from all layers, not just the final layer. The ESM2 model's __call__ method is updated to collect and return these intermediate representations, providing more detailed model output for downstream analysis.
Changed ESM2Result.all_hidden from a list to a Float array with shape (num_layers, length, embed_size). Added a test to verify the shape of all_hidden output and ensure consistency for both token and string inputs.
Introduces a test to verify the shape of the all_hidden output from the ESM2 model for both tokenized and string inputs. Ensures that all_hidden has the expected dimensions based on the number of layers, sequence length, and embedding size.
The __len__ method, which returned the number of layers, has been removed from the ESM2 class. This simplifies the class interface and removes an unused or unnecessary method.
esm2quinox/_esm2.py
Outdated
|
|
||
| hidden: Float[Array, "length embed_size"] | ||
| logits: Float[Array, "length alphabet_size"] | ||
| all_hidden: Float[Array, "num_layers length embed_size"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution!
Looking at this line, I think it would be more efficient to represent the hidden state as a list[Float[Array, "length embed_size"]].
The typical use-case in which a user needs only a few of the layers would then make it possible for the compiler to DCE the remaining elements of the list.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the quick feedback! That makes sense, will update the PR :)
Modified ESM2Result to store all_hidden as a list of arrays instead of a single stacked array. Updated tests to check for list type and correct shapes for each element, ensuring compatibility with the new structure.
|
@patrick-kidger, I’ve updated the PR to use a |
Following up on the suggestion we talked about a few months back, this pull request introduces support for retrieving all hidden layers of the ESM2 model for layer sweeps and inspections. To achieve this, this PR extends the
ESM2Resultclass and updates the model’s forward pass to gather and return these intermediate outputs. Tests have also been added to ensure the new features work as expected.Model output enhancements:
ESM2Resultclass now includes a new.all_hiddenattribute, which contains the hidden representations from all layers (including the final layer)._callmethod is updated to collect and return the outputs from all layers in the.all_hiddenfield.Testing:
test_all_hidden_output, to verify that the.all_hiddenoutput has the correct shape and is present when calling the model with both tokenized and string inputs.