From 8947c2bb1af9cedd978b9de147e30ccdfe3bbe62 Mon Sep 17 00:00:00 2001 From: Kartikay Khandelwal Date: Tue, 24 Oct 2023 15:26:24 -0700 Subject: [PATCH 1/2] Add Rotary Positional Embeddings Adding rotary positional embeddings and associated tests. I also abuse this commit by adding the gitignore file --- .gitignore | 179 +++++++++++++++++++ llm/llama2/__init__.py | 5 + llm/llama2/position_embeddings.py | 85 +++++++++ tests/__init__.py | 5 + tests/llm/__init__.py | 5 + tests/llm/llama2/__init__.py | 5 + tests/llm/llama2/test_position_embeddings.py | 50 ++++++ tests/test_utils.py | 33 ++++ 8 files changed, 367 insertions(+) create mode 100644 .gitignore create mode 100644 llm/llama2/__init__.py create mode 100644 llm/llama2/position_embeddings.py create mode 100644 tests/__init__.py create mode 100644 tests/llm/__init__.py create mode 100644 tests/llm/llama2/__init__.py create mode 100644 tests/llm/llama2/test_position_embeddings.py create mode 100644 tests/test_utils.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000..38ce8b4a9c --- /dev/null +++ b/.gitignore @@ -0,0 +1,179 @@ +# Derived from basic .gitignore template for python projects: +# https://github.com/github/gitignore/blob/main/Python.gitignore +# Please maintain the alphabetic order of the section titles +# To debug why a file is being ignored, use the command: +# git check-ignore -v $my_ignored_file + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Cython debug symbols +cython_debug/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Django stuff +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Flask stuff +instance/ +.webassets-cache + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# IPython +profile_default/ +ipython_config.py + +# Jupyter Notebook +*.ipynb_checkpoints + +# mkdocs documentation +/site + +# Model saving / checkpointing +*.pt +*.pth +*.ckpt + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# PyBuilder +.pybuilder/ +target/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +# Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +# poetry.lock + +# PEP 582: https://peps.python.org/pep-0582/ +# This PEP proposes to add to Python a mechanism to automatically recognize a __pypackages__ +# directory and prefer importing packages installed in this location over user or global site-packages. +# This will avoid the steps to create, activate or deactivate virtual environments. Python will use +# the __pypackages__ from the base directory of the script when present. +__pypackages__/ + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Rope project settings +.ropeproject + +# SageMath parsed files +*.sage.py + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# Spyder project settings +.spyderproject +.spyproject + +# System / program generated files +*.err +*.log +*.swp +.DS_Store + +# Translations +*.mo +*.pot + +# TorchX +*.torchxconfig + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# VSCode +.vscode/ diff --git a/llm/llama2/__init__.py b/llm/llama2/__init__.py new file mode 100644 index 0000000000..2e41cd717f --- /dev/null +++ b/llm/llama2/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/llm/llama2/position_embeddings.py b/llm/llama2/position_embeddings.py new file mode 100644 index 0000000000..ba8c254e97 --- /dev/null +++ b/llm/llama2/position_embeddings.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from torch import nn + + +class RotaryPositionalEmbeddings(nn.Module): + """ + This class implements Rotary Positional Embeddings (RoPE) + proposed in: https://arxiv.org/abs/2104.09864 + + Reference implementation (used for correctness verfication) + can be found here: + https://github.com/facebookresearch/llama/blob/main/llama/model.py#L450 + + Attributes: + dim (int): Embedding dimension for each head, computed as: + embed_size // num_heads + max_seq_len (int): Maximum expected sequence length for the + model, if exceeded the cached freqs will be recomputed + base (int): The base for the geometric progression used to compute + the rotation angles + + Args: + x (tensor): input tensor to which rope is applied + + """ + + def __init__( + self, + dim: int, + max_seq_len: int = 4096, + base: int = 10_000, + ) -> None: + super().__init__() + self.dim = dim + + theta = 1.0 / ( + base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) + self.register_buffer("theta", theta) + self.build_rope_cache(max_seq_len) + + def build_rope_cache(self, max_seq_len: int = 4096) -> None: + # Create position indexes `[0, 1, ..., max_seq_len - 1]` + seq_idx = torch.arange( + max_seq_len, dtype=self.theta.dtype, device=self.theta.device + ) + + # Outer product of theta and position index + idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float() + + cache = torch.stack( + [torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1 + ) + self.register_buffer("cache", cache) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + TODO: The implementation below can be made more efficient + for inference. + """ + seq_len = x.size(1) + rope_cache = self.cache[:seq_len] + + # cast because the reference does + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] + - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000..2e41cd717f --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/llm/__init__.py b/tests/llm/__init__.py new file mode 100644 index 0000000000..2e41cd717f --- /dev/null +++ b/tests/llm/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/llm/llama2/__init__.py b/tests/llm/llama2/__init__.py new file mode 100644 index 0000000000..2e41cd717f --- /dev/null +++ b/tests/llm/llama2/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/llm/llama2/test_position_embeddings.py b/tests/llm/llama2/test_position_embeddings.py new file mode 100644 index 0000000000..cf102b6ad9 --- /dev/null +++ b/tests/llm/llama2/test_position_embeddings.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest + +import torch +from llm.llama2.position_embeddings import RotaryPositionalEmbeddings +from tests.test_utils import assert_expected, set_rng_seed + + +@pytest.fixture(autouse=True) +def random(): + set_rng_seed(0) + + +class TestRotaryPositionEmbedding: + """ + Class for testing our Rotary Positional Embeddings (RoPE) + implementation. The expected tensors are computed from the + reference implementation here: + https://github.com/facebookresearch/llama/blob/main/llama/model.py#L450 + """ + + @pytest.fixture + def input_params(self): + bsz = 4 + num_heads = 32 + embed_dim = 4096 + head_dim = embed_dim // num_heads + seq_len = 4096 + return bsz, num_heads, head_dim, seq_len + + @pytest.fixture + def input(self, input_params): + bsz, num_heads, head_dim, seq_len = input_params + return torch.randn(bsz, seq_len, num_heads, head_dim) + + @pytest.fixture + def rope(self, input_params): + _, _, head_dim, seq_len = input_params + return RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=seq_len) + + def test_forward(self, input, rope): + x_out = rope(input) + assert_expected(x_out.mean(), torch.tensor(-4.3060e-05), rtol=1e-05, atol=1e-8) + assert_expected(x_out.sum(), torch.tensor(-2889.6804), rtol=1e-05, atol=1e-8) + assert_expected(x_out.max(), torch.tensor(5.6446), rtol=1e-05, atol=1e-8) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000..e8307abfe1 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import random + +import torch +from typing import Any, Optional + + +def set_rng_seed(seed): + """Sets the seed for random number generators""" + torch.manual_seed(seed) + random.seed(seed) + + +def assert_expected( + actual: Any, + expected: Any, + rtol: Optional[float] = None, + atol: Optional[float] = None, + check_device=True, +): + torch.testing.assert_close( + actual, + expected, + rtol=rtol, + atol=atol, + check_device=check_device, + msg=f"actual: {actual}, expected: {expected}", + ) From 127fe1e143af84459e54502ddcb25f98eea1b73e Mon Sep 17 00:00:00 2001 From: Kartikay Khandelwal Date: Wed, 25 Oct 2023 11:09:29 -0700 Subject: [PATCH 2/2] Address review comments --- tests/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index e8307abfe1..876c168bb5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -19,8 +19,8 @@ def set_rng_seed(seed): def assert_expected( actual: Any, expected: Any, - rtol: Optional[float] = None, - atol: Optional[float] = None, + rtol: float = 1e-5, + atol: float = 1e-8, check_device=True, ): torch.testing.assert_close(