Add Rotary Positional Embeddings #1

Merged (2 commits) on Oct 25, 2023
179 changes: 179 additions & 0 deletions .gitignore
@@ -0,0 +1,179 @@
# Derived from basic .gitignore template for python projects:
# https://github.com/github/gitignore/blob/main/Python.gitignore
# Please maintain the alphabetic order of the section titles
# To debug why a file is being ignored, use the command:
# git check-ignore -v $my_ignored_file

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Cython debug symbols
cython_debug/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Django stuff
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Flask stuff
instance/
.webassets-cache

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# IPython
profile_default/
ipython_config.py

# Jupyter Notebook
*.ipynb_checkpoints

# mkdocs documentation
/site

# Model saving / checkpointing
*.pt
*.pth
*.ckpt

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# PyBuilder
.pybuilder/
target/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
# Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
# poetry.lock

# PEP 582: https://peps.python.org/pep-0582/
# This PEP proposes to add to Python a mechanism to automatically recognize a __pypackages__
# directory and prefer importing packages installed in this location over user or global site-packages.
# This will avoid the steps to create, activate or deactivate virtual environments. Python will use
# the __pypackages__ from the base directory of the script when present.
__pypackages__/

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Rope project settings
.ropeproject

# SageMath parsed files
*.sage.py

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# Spyder project settings
.spyderproject
.spyproject

# System / program generated files
*.err
*.log
*.swp
.DS_Store

# Translations
*.mo
*.pot

# TorchX
*.torchxconfig

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# VSCode
.vscode/
5 changes: 5 additions & 0 deletions llm/llama2/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
85 changes: 85 additions & 0 deletions llm/llama2/position_embeddings.py
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Member: do we remove these copyrights in OSS?

Contributor Author: I actually took this from the OSS code in Multimodal.

# All rights reserved.
Member: also, maybe make a subdirectory /components for these?

Contributor Author: Didn't quite follow - can you elaborate?

#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from torch import nn


class RotaryPositionalEmbeddings(nn.Module):
"""
This class implements Rotary Positional Embeddings (RoPE)
proposed in: https://arxiv.org/abs/2104.09864

Reference implementation (used for correctness verfication)
can be found here:
https://github.com/facebookresearch/llama/blob/main/llama/model.py#L450

Attributes:
dim (int): Embedding dimension for each head, computed as:
embed_size // num_heads
Member: what is num_heads here and where would it be specified, the attention block? Shall we clarify this as num_attention_heads? And does this value take on different meaning for GQA / MQA?

Contributor Author: Oh yeah, that's a good point.

        max_seq_len (int): Maximum expected sequence length for the
Member: nit: add defaults in documentation?

            model; if exceeded, the cached freqs will be recomputed
        base (int): The base for the geometric progression used to compute
            the rotation angles

    Args:
        x (tensor): input tensor to which rope is applied

    """

    def __init__(
        self,
        dim: int,
        max_seq_len: int = 4096,
        base: int = 10_000,
    ) -> None:
        super().__init__()
        self.dim = dim

        theta = 1.0 / (
            base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
        )
        self.register_buffer("theta", theta)
        self.build_rope_cache(max_seq_len)

    def build_rope_cache(self, max_seq_len: int = 4096) -> None:
        # Create position indexes `[0, 1, ..., max_seq_len - 1]`
        seq_idx = torch.arange(
            max_seq_len, dtype=self.theta.dtype, device=self.theta.device
        )

        # Outer product of theta and position index
        idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float()

Curious, does this work with jit?

Member: Is this something we have to worry about still?

Member: We won't be needing jit compatibility, only torch.compile, since we'll be doing Python inference.


        cache = torch.stack(
            [torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1
        )
        self.register_buffer("cache", cache)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        TODO: The implementation below can be made more efficient
Member: any pointers as to how so?

        for inference.
        """
        seq_len = x.size(1)
        rope_cache = self.cache[:seq_len]
Member: where does the cache actually get invalidated if we exceed the seq_len and recomputed?

Contributor Author: We compute this with the max_seq_len that the model supports, so in the current setting it wouldn't need to be invalidated. There are some corner cases for inference which I don't think I fully understand right now.
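One possible shape of the recomputation discussed above, sketched for illustration only (the _maybe_extend_cache helper is hypothetical, not part of this PR; it would live on RotaryPositionalEmbeddings):

    # Hypothetical helper: rebuild the cos/sin cache lazily when an input
    # longer than the currently cached length is seen. forward() would call
    # self._maybe_extend_cache(x.size(1)) before slicing self.cache.
    def _maybe_extend_cache(self, seq_len: int) -> None:
        if seq_len > self.cache.size(0):
            self.build_rope_cache(seq_len)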


        # cast because the reference does
        xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
        rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
        x_out2 = torch.stack(
            [
                xshaped[..., 0] * rope_cache[..., 0]
                - xshaped[..., 1] * rope_cache[..., 1],
                xshaped[..., 1] * rope_cache[..., 0]
                + xshaped[..., 0] * rope_cache[..., 1],
            ],
            -1,
        )

        x_out2 = x_out2.flatten(3)
        return x_out2.type_as(x)
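For orientation, a minimal usage sketch of the module above (editorial addition; the concrete shapes and the idea of applying the same module to both queries and keys are assumptions based on standard RoPE usage, not code from this PR):

import torch
from llm.llama2.position_embeddings import RotaryPositionalEmbeddings

# Inputs are shaped [batch, seq_len, num_heads, head_dim]; in an attention
# block the same module would typically be applied to both queries and keys.
rope = RotaryPositionalEmbeddings(dim=64, max_seq_len=2048)
q = torch.randn(2, 128, 8, 64)
k = torch.randn(2, 128, 8, 64)
q_rot, k_rot = rope(q), rope(k)  # output shapes match the inputs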
5 changes: 5 additions & 0 deletions tests/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
5 changes: 5 additions & 0 deletions tests/llm/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
5 changes: 5 additions & 0 deletions tests/llm/llama2/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
50 changes: 50 additions & 0 deletions tests/llm/llama2/test_position_embeddings.py
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch
from llm.llama2.position_embeddings import RotaryPositionalEmbeddings
from tests.test_utils import assert_expected, set_rng_seed


@pytest.fixture(autouse=True)
def random():
    set_rng_seed(0)


class TestRotaryPositionEmbedding:
Member: do you want to test the cache invalidation / recomputation?

"""
Class for testing our Rotary Positional Embeddings (RoPE)
implementation. The expected tensors are computed from the
reference implementation here:
https://github.com/facebookresearch/llama/blob/main/llama/model.py#L450
"""

@pytest.fixture
def input_params(self):
bsz = 4
num_heads = 32
embed_dim = 4096
head_dim = embed_dim // num_heads
seq_len = 4096
return bsz, num_heads, head_dim, seq_len

@pytest.fixture
def input(self, input_params):
bsz, num_heads, head_dim, seq_len = input_params
return torch.randn(bsz, seq_len, num_heads, head_dim)

@pytest.fixture
def rope(self, input_params):
Member: a state_dict compatibility test might be useful too. For example, take a state_dict in memory and verify that it has the expected keys, which will help us ensure correctness when we load in pretrained weights.

        _, _, head_dim, seq_len = input_params
        return RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=seq_len)

    def test_forward(self, input, rope):
        x_out = rope(input)
        assert_expected(x_out.mean(), torch.tensor(-4.3060e-05), rtol=1e-05, atol=1e-8)
        assert_expected(x_out.sum(), torch.tensor(-2889.6804), rtol=1e-05, atol=1e-8)
        assert_expected(x_out.max(), torch.tensor(5.6446), rtol=1e-05, atol=1e-8)
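Sketches of additional checks in the spirit of the review comments above, for illustration only (test names are hypothetical and not part of this PR; the expected buffer names follow from the register_buffer calls in the module):

    def test_state_dict_keys(self, rope):
        # The module registers exactly two buffers, so they should be the
        # only entries in its state_dict.
        assert set(rope.state_dict().keys()) == {"theta", "cache"}

    def test_cache_shape(self, rope, input_params):
        # The cache holds a (cos, sin) pair for every position and rotation angle.
        _, _, head_dim, seq_len = input_params
        assert rope.cache.shape == (seq_len, head_dim // 2, 2)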
33 changes: 33 additions & 0 deletions tests/test_utils.py
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import random

import torch
from typing import Any, Optional


def set_rng_seed(seed):
"""Sets the seed for random number generators"""
torch.manual_seed(seed)
random.seed(seed)


def assert_expected(
Member: nit: do you want to add some defaults for this for ease of use as we develop?

    actual: Any,
    expected: Any,
    rtol: float = 1e-5,
    atol: float = 1e-8,
    check_device=True,
):
    torch.testing.assert_close(
        actual,
        expected,
        rtol=rtol,
        atol=atol,
        check_device=check_device,
        msg=f"actual: {actual}, expected: {expected}",
    )