Add Rotary Positional Embeddings #1
Conversation
Adding rotary positional embeddings and associated tests. I also abuse this commit by adding the .gitignore file.

Note: this needs pytest for the tests. I'll add a dev-requirements file in a separate PR.

For testing:
cd ~/torch_tbd/tests
pytest
# Outer product of theta and position index
idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float()
Curious, does this work with jit?
Is this something we have to worry about still?
We won't need jit compatibility, only torch.compile, since we'll be doing Python inference.
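For reference, a minimal sketch of the cache construction going through torch.compile rather than torch.jit.script. The function name and the cos/sin cache layout below are assumptions for illustration, not the PR's exact code:

```python
import torch

def build_rope_cache(seq_idx: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    # Same outer product as in the hunk above, followed by a cos/sin cache
    idx_theta = torch.einsum("i, j -> ij", seq_idx, theta).float()
    return torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

compiled_build = torch.compile(build_rope_cache)

dim, seq_len = 64, 16
theta = 1.0 / (10_000 ** (torch.arange(0, dim, 2).float() / dim))
seq_idx = torch.arange(seq_len).float()
cache = compiled_build(seq_idx, theta)
print(cache.shape)  # torch.Size([16, 32, 2])
```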
LGTM overall, most comments are minor; my main question is around recomputation of the RoPE cache and how that works.
Thanks!
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Do we remove these copyright headers in OSS?
I actually took this from the OSS code in Multimodal
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
Also, maybe make a subdirectory /components for these?
Didn't quite follow - can you elaborate?
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    TODO: The implementation below can be made more efficient
Any pointers on how?
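One possible direction (just a sketch under my own naming, not the PR's code): fold the cos/sin bookkeeping into a single complex multiplication, Llama-style, with the cache stored as one complex tensor:

```python
import torch

def apply_rope_complex(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: [bsz, seq_len, num_heads, head_dim]; freqs_cis: complex [seq_len, head_dim // 2]
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.reshape(1, x.size(1), 1, -1)  # broadcast over batch and heads
    x_out = torch.view_as_real(x_complex * freqs_cis).flatten(-2)
    return x_out.type_as(x)

# Precompute the cache once as e^(i * idx_theta) instead of stacked cos/sin planes
dim, max_seq_len = 64, 4096
theta = 1.0 / (10_000 ** (torch.arange(0, dim, 2).float() / dim))
idx_theta = torch.einsum("i, j -> ij", torch.arange(max_seq_len).float(), theta)
freqs_cis = torch.polar(torch.ones_like(idx_theta), idx_theta)

x = torch.randn(2, 16, 8, dim)
out = apply_rope_complex(x, freqs_cis[:16])
print(out.shape)  # torch.Size([2, 16, 8, 64])
```

Whether this is actually faster depends on shapes and hardware, so it would need benchmarking against the current implementation.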
set_rng_seed(0)


class TestRotaryPositionEmbedding:
do you want to test the cache invalidation / recomputation?
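For what it's worth, a hypothetical shape of such a test. The shapes, attribute names, and the recomputation behaviour itself are all assumptions here, since the current module only slices a fixed-size cache:

```python
def test_cache_covers_input_seq_len(self, rope) -> None:
    # Hypothetical: assumes dim=64 and max_seq_len >= 16 in the fixtures above.
    x = torch.randn(2, 16, 8, 64)  # [bsz, seq_len, num_heads, head_dim]
    out = rope(x)
    assert out.shape == x.shape
    # If cache recomputation is ever added, a second call with
    # seq_len > max_seq_len could assert that rope.cache.size(0) grew.
```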
for inference.
"""
seq_len = x.size(1)
rope_cache = self.cache[:seq_len]
Where does the cache actually get invalidated and recomputed if we exceed the seq_len?
We compute this with the max_seq_len that the model supports, so in the current setting it wouldn't need to be invalidated. There are some corner cases for inference which I don't think I fully understand right now.
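To make the thread concrete, here is a minimal sketch of that behaviour, plus one way an explicit "invalidate and rebuild" could look if it were ever needed. The class and attribute names are mine, not the PR's, and the actual rotation is elided:

```python
import torch
from torch import nn

class RoPECacheSketch(nn.Module):
    def __init__(self, dim: int, max_seq_len: int = 4096, base: int = 10_000) -> None:
        super().__init__()
        self.dim, self.base = dim, base
        self._build_cache(max_seq_len)

    def _build_cache(self, max_seq_len: int) -> None:
        theta = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        seq_idx = torch.arange(max_seq_len).float()
        # Outer product of position index and theta, as in the hunk above
        idx_theta = torch.einsum("i, j -> ij", seq_idx, theta)
        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
        self.register_buffer("cache", cache, persistent=False)
        self.max_seq_len = max_seq_len

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [bsz, seq_len, num_heads, head_dim]
        seq_len = x.size(1)
        if seq_len > self.max_seq_len:
            # One possible "invalidation": rebuild the cache for the longer input.
            self._build_cache(seq_len)
        rope_cache = self.cache[:seq_len]
        # ... apply the rotation with rope_cache (elided in this sketch) ...
        return x
```

In the PR as written the cache is simply sliced, which matches the explanation above; the rebuild branch is only there to show where an invalidation hook could live.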
Attributes:
    dim (int): Embedding dimension for each head, computed as:
        embed_size // num_heads
What is num_heads here, and where would it be specified, the attention block? Shall we clarify this as num_attention_heads? And does this value take on a different meaning for GQA / MQA?
Oh yeah, that's a good point.
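To spell out the GQA/MQA part, a small illustration with made-up config values:

```python
# Hypothetical config, just to illustrate the naming question above.
embed_dim = 4096
num_attention_heads = 32   # query heads
num_kv_heads = 8           # GQA: fewer key/value heads; MQA would be 1

head_dim = embed_dim // num_attention_heads  # 128: the `dim` RoPE is built with

# Under GQA/MQA the per-head dim is unchanged, only the head count differs,
# so the same RoPE cache applies to both projections:
#   q: [bsz, seq_len, num_attention_heads, head_dim]
#   k: [bsz, seq_len, num_kv_heads,        head_dim]
```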
Attributes:
    dim (int): Embedding dimension for each head, computed as:
        embed_size // num_heads
    max_seq_len (int): Maximum expected sequence length for the
nit: add defaults in documentation?
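For example, something like this in the docstring; the `base` attribute and the default values shown are placeholders, not necessarily the PR's:

```python
"""
Attributes:
    dim (int): Embedding dimension for each head, computed as
        ``embed_size // num_attention_heads``.
    max_seq_len (int): Maximum expected sequence length for the model.
        Default: 4096
    base (int): Base for the geometric progression used to compute
        the rotation angles. Default: 10_000
"""
```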
return torch.randn(bsz, seq_len, num_heads, head_dim)

@pytest.fixture
def rope(self, input_params):
A state_dict compatibility test might be useful too. For example, take a state_dict in memory and verify that it has the expected keys, which will help us ensure correctness when we load in pretrained weights.
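A sketch of what that could look like inside TestRotaryPositionEmbedding; the expected key set is an assumption, and a non-persistent cache buffer would not show up in the state_dict at all:

```python
def test_state_dict_keys(self, rope) -> None:
    # Assumed key name; adjust to however the buffers end up being registered.
    expected_keys = {"cache"}
    assert set(rope.state_dict().keys()) == expected_keys

def test_state_dict_round_trip(self, rope) -> None:
    # Loading the module's own state_dict back should report nothing missing.
    missing, unexpected = rope.load_state_dict(rope.state_dict(), strict=True)
    assert not missing and not unexpected
```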
random.seed(seed)


def assert_expected(
nit: do you want to add some defaults for this for ease of use as we develop?
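For example (the tolerance values below are placeholders, not a recommendation):

```python
import torch

def assert_expected(
    actual: torch.Tensor,
    expected: torch.Tensor,
    rtol: float = 1e-5,
    atol: float = 1e-8,
) -> None:
    # Thin wrapper so call sites don't have to repeat tolerances everywhere.
    torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)
```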