
Add Rotary Positional Embeddings #1


Merged: kartikayk merged 2 commits into main from llama2_componentization on Oct 25, 2023

Conversation

kartikayk (Contributor)

Adding rotary positional embeddings and associated tests. I also abuse this commit by adding the .gitignore file.

Note: this needs pytest for the tests. I'll add a dev-requirements file in a separate PR.

For testing:

cd ~/torch_tbd/tests
pytest

Commit: Adding rotary positional embeddings and associated tests. I also abuse this commit by adding the .gitignore file
@facebook-github-bot added the CLA Signed label on Oct 25, 2023. (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.)
)

# Outer product of theta and position index
idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float()

Curious, does this work with jit?

Member
Is this something we have to worry about still?

Member
We won't need jit compatibility, only torch.compile, since we'll be doing Python inference.
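For illustration, a hedged sketch of what the torch.compile path looks like for the einsum cache build shown above (the helper name here is made up for the sketch, not code from this PR):

import torch

# Compile the cache-building step with torch.compile rather than torch.jit.script,
# in line with the comment above about only needing torch.compile support.
def build_idx_theta(seq_idx: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    # Outer product of position index and theta, as in the diff snippet above
    return torch.einsum("i, j -> ij", seq_idx, theta).float()

compiled_build = torch.compile(build_idx_theta)
idx_theta = compiled_build(torch.arange(8).float(), torch.rand(4))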

@rohan-varma (Member) left a comment
LGTM overall. Most comments are minor; my main question is around recomputation of the RoPE cache and how that works.

Thanks!

@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Member

do we remove these copyrights in OSS?

Contributor Author

I actually took this from the OSS code in Multimodal

@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
Member

also, maybe make a subdirectory /components for these?

Contributor Author

Didn't quite follow - can you elaborate?


def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
TODO: The implementation below can be made more efficient
Member

any pointers as to how so?
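One commonly referenced, more compact formulation (a hedged sketch in the style of Llama reference implementations, not the code in this PR) rotates pairs of dimensions with complex multiplication instead of an explicit cos/sin cache lookup:

import torch

def precompute_freqs_cis(dim: int, max_seq_len: int, base: float = 10000.0) -> torch.Tensor:
    # One frequency per pair of embedding dimensions
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len).float()
    freqs = torch.outer(t, freqs)  # (max_seq_len, dim // 2)
    # Unit-magnitude complex numbers encoding the per-position rotation angles
    return torch.polar(torch.ones_like(freqs), freqs)

def apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: (bsz, seq_len, num_heads, head_dim)
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    rot = freqs_cis[: x.size(1)].view(1, x.size(1), 1, -1)
    return torch.view_as_real(x_c * rot).flatten(3).type_as(x)

Whether this is actually faster depends on shapes and hardware; the main difference is doing the rotation as a single complex multiply rather than separate cos/sin products.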

set_rng_seed(0)


class TestRotaryPositionEmbedding:
Member

do you want to test the cache invalidation / recomputation?

for inference.
"""
seq_len = x.size(1)
rope_cache = self.cache[:seq_len]
Member

Where does the cache actually get invalidated and recomputed if we exceed the seq_len?

Contributor Author

We compute this with the max_seq_len that the model supports, so in the current setting it wouldn't need to be invalidated. There are some corner cases for inference that I don't think I fully understand right now.
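To make that caching behavior concrete, here is a hedged sketch of the pattern (class and attribute names are illustrative, not necessarily this PR's code): the cache is built once for max_seq_len, and forward-time use only slices it, so there is no invalidation path as long as inputs respect max_seq_len.

import torch
from torch import nn

class RoPECacheSketch(nn.Module):
    def __init__(self, dim: int, max_seq_len: int, base: float = 10000.0) -> None:
        super().__init__()
        theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        seq_idx = torch.arange(max_seq_len).float()
        # Outer product of position index and theta, as in the diff above
        idx_theta = torch.einsum("i, j -> ij", seq_idx, theta)
        # cos/sin for every position up to the model's max_seq_len
        cache = torch.stack([idx_theta.cos(), idx_theta.sin()], dim=-1)
        self.register_buffer("cache", cache)
        self.max_seq_len = max_seq_len

    def rope_cache_for(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.size(1)
        # No invalidation: inputs are assumed to satisfy seq_len <= max_seq_len,
        # so slicing the precomputed cache is always sufficient.
        assert seq_len <= self.max_seq_len
        return self.cache[:seq_len]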

Attributes:
dim (int): Embedding dimension for each head, computed as:
embed_size // num_heads
Member

What is num_heads here, and where would it be specified (the attention block)? Shall we clarify this as num_attention_heads? And does this value take on a different meaning for GQA / MQA?

Contributor Author

Oh yeah, that's a good point.
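For reference, a quick illustration of why the per-head dim is unchanged under GQA / MQA (all numbers are hypothetical, chosen only to make the arithmetic concrete):

embed_dim = 4096
num_attention_heads = 32            # query heads
num_kv_heads = 8                    # GQA uses fewer key/value heads; MQA uses 1

head_dim = embed_dim // num_attention_heads   # 128; this is the `dim` RoPE sees
# RoPE is applied per head, so the rotation dimension is head_dim for both the
# query and key projections; GQA/MQA changes how many K/V heads exist, not their width.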

Attributes:
dim (int): Embedding dimension for each head, computed as:
embed_size // num_heads
max_seq_len (int): Maximum expected sequence length for the
Member

nit: add defaults in documentation?
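For example, defaults could be surfaced directly in the docstring (class name and default values below are assumptions made for the sketch, not necessarily what this PR uses):

import torch

class RotaryPositionalEmbeddings(torch.nn.Module):
    """
    Args:
        dim (int): Embedding dimension for each head, computed as
            embed_size // num_heads.
        max_seq_len (int): Maximum expected sequence length for the model.
            Default: 4096
        base (int): Base of the geometric progression used to compute the
            rotation angles. Default: 10_000
    """

    def __init__(self, dim: int, max_seq_len: int = 4096, base: int = 10_000) -> None:
        super().__init__()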

return torch.randn(bsz, seq_len, num_heads, head_dim)

@pytest.fixture
def rope(self, input_params):
Member

A state_dict compatibility test might be useful too. For example, take a state_dict in memory and verify that it has the expected keys; this will help us ensure correctness when we load pretrained weights.
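A hedged sketch of that check, written against the rope fixture shown above (the expected key set is an assumption; it depends on how the cache buffer is registered):

def test_state_dict_keys(self, rope) -> None:
    # Assumption: the RoPE cache is registered as a persistent buffer named
    # "cache". If it is registered with persistent=False, the expected set
    # would be empty instead.
    expected_keys = {"cache"}
    assert set(rope.state_dict().keys()) == expected_keys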

random.seed(seed)


def assert_expected(
Member

nit: do you want to add some defaults for this for ease of use as we develop?
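One possible shape for those defaults (tolerance values are illustrative; torch.testing.assert_close does the comparison):

import torch

def assert_expected(actual, expected, rtol: float = 1e-5, atol: float = 1e-8) -> None:
    # Thin wrapper so tests can call assert_expected(a, b) without spelling
    # out tolerances every time.
    torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)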

@kartikayk merged commit 902ce2b into main on Oct 25, 2023
@kartikayk deleted the llama2_componentization branch on October 25, 2023 18:13
janeyx99 referenced this pull request in janeyx99/torchtune May 16, 2024
@RdoubleA mentioned this pull request on Aug 4, 2024
SLR722 added a commit to calvinpelletier/torchtune that referenced this pull request Aug 23, 2024
SLR722 added a commit to calvinpelletier/torchtune that referenced this pull request Aug 30, 2024
joecummings added a commit that referenced this pull request May 1, 2025
Remove 130 LOC from this recipe
FlamingoPg referenced this pull request in FlamingoPg/sgl-tune-eagle May 26, 2025
Adds an option to do torch profiler tracing via:
--run_profiler  (T/F)
--profile_folder (str) 

Traces are saved out with rank_X as part of the trace name.
[screenshot: rank_named_traces]

Implemented as a context wrapper around the main training loop.
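A hedged sketch of that pattern (flag and folder names come from the description above; the wiring itself is illustrative, not the referenced commit's actual code):

import contextlib
import torch

@contextlib.contextmanager
def maybe_profile(run_profiler: bool, profile_folder: str, rank: int):
    if not run_profiler:
        yield None
        return
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        # Traces are written with rank_X as part of the file name, as noted above
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            profile_folder, worker_name=f"rank_{rank}"
        ),
    ) as prof:
        yield prof

# Usage around the main training loop (illustrative):
# with maybe_profile(args.run_profiler, args.profile_folder, rank):
#     train()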