JR-1991 commented Dec 18, 2025

Following up on the suggestion we talked about a few months back, this pull request introduces support for retrieving all hidden layers of the ESM2 model for layer sweeps and inspections. To achieve this, this PR extends the ESM2Result class and updates the model’s forward pass to gather and return these intermediate outputs. Tests have also been added to ensure the new features work as expected.

Model output enhancements:

  • The ESM2Result class now includes a new .all_hidden attribute, which contains the hidden representations from all layers (including the final layer).
  • The model's __call__ method is updated to collect and return the outputs from all layers in the .all_hidden field.
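
The collection pattern described above can be sketched roughly as follows. This is a minimal illustration, not the code from this PR: `toy_forward`, the `tanh` layers, and the weight matrices are hypothetical stand-ins for the real ESM2 blocks.

```python
import jax
import jax.numpy as jnp


def toy_forward(weights, x):
    """Run x through a stack of toy layers, keeping every intermediate output.

    `weights` is a list of (embed_size, embed_size) matrices standing in for
    transformer blocks; this is NOT the real ESM2 architecture.
    """
    all_hidden = []
    for w in weights:
        x = jnp.tanh(x @ w)   # one toy "layer"
        all_hidden.append(x)  # record this layer's output
    hidden = x                # final-layer output, identical to all_hidden[-1]
    return hidden, all_hidden


num_layers, length, embed_size = 4, 7, 8
keys = jax.random.split(jax.random.PRNGKey(0), num_layers + 1)
weights = [0.1 * jax.random.normal(k, (embed_size, embed_size)) for k in keys[:-1]]
x = jax.random.normal(keys[-1], (length, embed_size))

hidden, all_hidden = toy_forward(weights, x)
```

Here the final-layer output is also the last recorded entry, matching the "including the final layer" wording above.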

Testing:

  • Added a new test, test_all_hidden_output, to verify that the .all_hidden output has the correct shape and is present when calling the model with both tokenized and string inputs.

The ESM2Result class now includes an all_hidden attribute containing outputs from all layers, not just the final layer. The ESM2 model's __call__ method is updated to collect and return these intermediate representations, providing more detailed model output for downstream analysis.
Changed ESM2Result.all_hidden from a list to a Float array with shape (num_layers, length, embed_size). Added a test to verify the shape of all_hidden output and ensure consistency for both token and string inputs.
Introduces a test to verify the shape of the all_hidden output from the ESM2 model for both tokenized and string inputs. Ensures that all_hidden has the expected dimensions based on the number of layers, sequence length, and embedding size.
The __len__ method, which returned the number of layers, has been removed from the ESM2 class. This simplifies the class interface and removes an unused or unnecessary method.

hidden: Float[Array, "length embed_size"]
logits: Float[Array, "length alphabet_size"]
all_hidden: Float[Array, "num_layers length embed_size"]
Owner

patrick-kidger commented Dec 19, 2025

Thanks for the contribution!

Looking at this line, I think it would be more efficient to represent the hidden state as a list[Float[Array, "length embed_size"]].

In the typical use case, where a user needs only a few of the layers, the compiler can then dead-code-eliminate (DCE) the remaining elements of the list.
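
To illustrate the point, here is a minimal sketch under stated assumptions: `toy_all_hidden` and `second_layer_only` are hypothetical names, and the `tanh` layers stand in for real model blocks. Under `jax.jit`, only the selected list element feeds the output, so the later layers are candidates for elimination.

```python
import jax
import jax.numpy as jnp


def toy_all_hidden(weights, x):
    # Toy stand-in for a forward pass that returns per-layer outputs as a list.
    outs = []
    for w in weights:
        x = jnp.tanh(x @ w)
        outs.append(x)
    return outs


@jax.jit
def second_layer_only(weights, x):
    # Only element 1 of the list is returned. Layers 2 and 3 are traced but
    # do not feed the output, so the compiler is free to drop them.
    return toy_all_hidden(weights, x)[1]


keys = jax.random.split(jax.random.PRNGKey(0), 5)
weights = [0.1 * jax.random.normal(k, (8, 8)) for k in keys[:4]]
x = jax.random.normal(keys[4], (7, 8))

out = second_layer_only(weights, x)
# Reference: run only the first two layers by hand.
expected = jnp.tanh(jnp.tanh(x @ weights[0]) @ weights[1])
```

One way to check what the compiler actually kept is to inspect `second_layer_only.lower(weights, x).compile().as_text()`; the exact output depends on the backend and JAX version.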

Author

Thanks for the quick feedback! That makes sense, will update the PR :)

Modified ESM2Result to store all_hidden as a list of arrays instead of a single stacked array. Updated tests to check for list type and correct shapes for each element, ensuring compatibility with the new structure.
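
A shape check along the lines described in this commit might look like the sketch below. It is an assumption-laden illustration, not the test from the PR: `check_all_hidden` is a hypothetical helper, and the zero arrays are placeholder data rather than real model output.

```python
import jax.numpy as jnp


def check_all_hidden(all_hidden, num_layers, length, embed_size):
    """Assert that `all_hidden` is a list with one (length, embed_size) array per layer."""
    assert isinstance(all_hidden, list)
    assert len(all_hidden) == num_layers
    for layer_out in all_hidden:
        assert layer_out.shape == (length, embed_size)


# Placeholder data standing in for a real model output.
fake_all_hidden = [jnp.zeros((5, 16)) for _ in range(3)]
check_all_hidden(fake_all_hidden, num_layers=3, length=5, embed_size=16)
```

A single stacked array of shape (num_layers, length, embed_size) would fail the `isinstance` check, which is how such a test distinguishes the list form from the earlier stacked form.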
Author

JR-1991 commented Dec 19, 2025

@patrick-kidger, I’ve updated the PR to use a list type instead of a 3-dimensional tensor, as you suggested.
