Add Rotary Positional Embeddings #1
@@ -0,0 +1,179 @@
# Derived from basic .gitignore template for python projects:
# https://github.com/github/gitignore/blob/main/Python.gitignore
# Please maintain the alphabetic order of the section titles
# To debug why a file is being ignored, use the command:
# git check-ignore -v $my_ignored_file

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Cython debug symbols
cython_debug/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Django stuff
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Flask stuff
instance/
.webassets-cache

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# IPython
profile_default/
ipython_config.py

# Jupyter Notebook
*.ipynb_checkpoints

# mkdocs documentation
/site

# Model saving / checkpointing
*.pt
*.pth
*.ckpt

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# PyBuilder
.pybuilder/
target/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
# Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
# poetry.lock

# PEP 582: https://peps.python.org/pep-0582/
# This PEP proposes to add to Python a mechanism to automatically recognize a __pypackages__
# directory and prefer importing packages installed in this location over user or global site-packages.
# This will avoid the steps to create, activate or deactivate virtual environments. Python will use
# the __pypackages__ from the base directory of the script when present.
__pypackages__/

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Rope project settings
.ropeproject

# SageMath parsed files
*.sage.py

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# Spyder project settings
.spyderproject
.spyproject

# System / program generated files
*.err
*.log
*.swp
.DS_Store

# Translations
*.mo
*.pot

# TorchX
*.torchxconfig

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# VSCode
.vscode/
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

[Review comment] also, maybe make a subdirectory
[Reply] Didn't quite follow - can you elaborate?

#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from torch import nn


class RotaryPositionalEmbeddings(nn.Module):
    """
    This class implements Rotary Positional Embeddings (RoPE)
    proposed in: https://arxiv.org/abs/2104.09864

    Reference implementation (used for correctness verification)
    can be found here:
    https://github.com/facebookresearch/llama/blob/main/llama/model.py#L450

    Attributes:
        dim (int): Embedding dimension for each head, computed as:
            embed_size // num_heads

[Review comment] what is num_heads here and where would it be specified, the attention block? Shall we clarify this as num_attention_heads? And does this value take on different meaning for GQA / MQA?
[Reply] Oh yeah, that's a good point.

        max_seq_len (int): Maximum expected sequence length for the
            model; if exceeded, the cached freqs will be recomputed

[Review comment] nit: add defaults in documentation?

        base (int): The base for the geometric progression used to compute
            the rotation angles

    Args:
        x (tensor): input tensor to which rope is applied

    """

    def __init__(
        self,
        dim: int,
        max_seq_len: int = 4096,
        base: int = 10_000,
    ) -> None:
        super().__init__()
        self.dim = dim

        theta = 1.0 / (
            base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
        )
        self.register_buffer("theta", theta)
        self.build_rope_cache(max_seq_len)

    def build_rope_cache(self, max_seq_len: int = 4096) -> None:
        # Create position indexes `[0, 1, ..., max_seq_len - 1]`
        seq_idx = torch.arange(
            max_seq_len, dtype=self.theta.dtype, device=self.theta.device
        )

        # Outer product of theta and position index
        idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float()

[Review comment] Curious, does this work with jit?
[Reply] Is this something we have to worry about still?
[Reply] We won't be needing jit compatibility, only torch.compile, since we'll be doing python inference

        cache = torch.stack(
            [torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1
        )
        self.register_buffer("cache", cache)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        TODO: The implementation below can be made more efficient
        for inference.
        """

[Review comment] any pointers as to how so?

        seq_len = x.size(1)
        rope_cache = self.cache[:seq_len]

[Review comment] where does the cache actually get invalidated if we exceed the seq_len and recomputed?
[Reply] We compute this with max_seq_len that the model supports and so in the current setting it wouldn't need to be invalidated. There are some corner cases for inference which I don't think I fully understand right now.

        # cast because the reference does
        xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
        rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
        x_out2 = torch.stack(
            [
                xshaped[..., 0] * rope_cache[..., 0]
                - xshaped[..., 1] * rope_cache[..., 1],
                xshaped[..., 1] * rope_cache[..., 0]
                + xshaped[..., 0] * rope_cache[..., 1],
            ],
            -1,
        )

        x_out2 = x_out2.flatten(3)
        return x_out2.type_as(x)
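For orientation, here is a minimal usage sketch (not part of the PR). The input layout [batch_size, seq_len, num_heads, head_dim] and the constructor arguments mirror the test fixtures further down; the smaller sizes and the commented torch.compile line are illustrative assumptions only.

import torch

from llm.llama2.position_embeddings import RotaryPositionalEmbeddings

# Illustrative sizes; the tests below use bsz=4, num_heads=32, head_dim=128, seq_len=4096.
bsz, seq_len, num_heads, head_dim = 2, 16, 8, 64

rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=4096)
x = torch.randn(bsz, seq_len, num_heads, head_dim)

x_rope = rope(x)
# RoPE rotates (even, odd) feature pairs in place, so the output shape matches the input.
assert x_rope.shape == x.shape

# Per the review thread, only torch.compile (not jit) needs to be supported;
# compiling the module is optional and shown here purely for illustration.
# rope_compiled = torch.compile(rope)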
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch
from llm.llama2.position_embeddings import RotaryPositionalEmbeddings
from tests.test_utils import assert_expected, set_rng_seed


@pytest.fixture(autouse=True)
def random():
    set_rng_seed(0)


class TestRotaryPositionEmbedding:
    """
    Class for testing our Rotary Positional Embeddings (RoPE)
    implementation. The expected tensors are computed from the
    reference implementation here:
    https://github.com/facebookresearch/llama/blob/main/llama/model.py#L450
    """

[Review comment] do you want to test the cache invalidation / recomputation?

    @pytest.fixture
    def input_params(self):
        bsz = 4
        num_heads = 32
        embed_dim = 4096
        head_dim = embed_dim // num_heads
        seq_len = 4096
        return bsz, num_heads, head_dim, seq_len

    @pytest.fixture
    def input(self, input_params):
        bsz, num_heads, head_dim, seq_len = input_params
        return torch.randn(bsz, seq_len, num_heads, head_dim)

    @pytest.fixture
    def rope(self, input_params):
        _, _, head_dim, seq_len = input_params
        return RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=seq_len)

    def test_forward(self, input, rope):
        x_out = rope(input)
        assert_expected(x_out.mean(), torch.tensor(-4.3060e-05), rtol=1e-05, atol=1e-8)
        assert_expected(x_out.sum(), torch.tensor(-2889.6804), rtol=1e-05, atol=1e-8)
        assert_expected(x_out.max(), torch.tensor(5.6446), rtol=1e-05, atol=1e-8)
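The review comment above asks about testing cache invalidation/recomputation. A hypothetical follow-up test, not part of this PR, could exercise the existing build_rope_cache method directly; the test name and sizes below are assumptions for illustration.

def test_rope_cache_rebuild():
    # Hypothetical sketch: rebuild the cache for a longer sequence length
    # and check its shape. Uses only the public build_rope_cache method above.
    head_dim = 128
    rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=4096)
    rope.build_rope_cache(max_seq_len=8192)
    # The cache holds a (cos, sin) pair per position and per head_dim // 2 frequency.
    assert rope.cache.shape == (8192, head_dim // 2, 2)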
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import random

import torch
from typing import Any, Optional


def set_rng_seed(seed):
    """Sets the seed for random number generators"""
    torch.manual_seed(seed)
    random.seed(seed)


[Review comment] nit: do you want to add some defaults for this for ease of use as we develop?

def assert_expected(
    actual: Any,
    expected: Any,
    rtol: float = 1e-5,
    atol: float = 1e-8,
    check_device=True,
):
    torch.testing.assert_close(
        actual,
        expected,
        rtol=rtol,
        atol=atol,
        check_device=check_device,
        msg=f"actual: {actual}, expected: {expected}",
    )
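A short illustration of how these helpers would be used in a test (not part of the PR); it assumes the functions are importable from tests.test_utils, as in the RoPE test above.

import torch

from tests.test_utils import assert_expected, set_rng_seed

set_rng_seed(0)                 # seed torch and random for reproducibility
a = torch.randn(3)
assert_expected(a, a.clone())   # passes: values match within rtol/atol on the same device
# A mismatch would raise, printing the formatted "actual: ..., expected: ..." message.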
[Review comment] do we remove these copyrights in OSS?
[Reply] I actually took this from the OSS code in Multimodal