Commit 83557aa

Add test for Llama3.1 ScaledRoPE (#1216)
1 parent 403c7f3 commit 83557aa

3 files changed (+146, -12 lines)
Lines changed: 140 additions & 0 deletions (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from tests.test_utils import assert_expected
from torch import tensor

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

from torchtune.utils.seed import set_seed


@pytest.fixture(autouse=True)
def random():
    set_seed(0)


class TestLlama3ScaledRoPE:
    """
    Class for testing our Scaled RoPE implementation for Llama3.1.
    The expected tensors are computed from the reference implementation here:
    https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py#L272

    The expected values are computed using the following code:
    https://gist.github.com/joecummings/4f1331a9c1e5aa15bad1641acb74fe0e
    """

    EXPECTED_FREQS_CIS_MEAN = tensor(0.1738)
    EXPECTED_FREQS_CIS_SUM = tensor(91141.7656)
    EXPECTED_FREQS_CIS_MAX = tensor(1.0)

    EXPECTED_X_OUT_MEAN = tensor(-2.4781e-06)
    EXPECTED_X_OUT_SUM = tensor(-83.1523)
    EXPECTED_X_OUT_MAX = tensor(5.4625)

    @pytest.fixture
    def input_params(self):
        bsz = 4
        num_heads = 32
        embed_dim = 4096
        head_dim = embed_dim // num_heads
        seq_len = 2048
        max_seq_len = 4096
        return bsz, num_heads, head_dim, seq_len, max_seq_len

    @pytest.fixture
    def input(self, input_params) -> tensor:
        bsz, num_heads, head_dim, seq_len, _ = input_params
        return torch.randn(bsz, seq_len, num_heads, head_dim)

    @pytest.fixture
    def rope(self, input_params) -> Llama3ScaledRoPE:
        _, _, head_dim, _, max_seq_len = input_params
        return Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len)

    def test_cache_equality(self, input, rope) -> None:
        # Have to explicitly call _rope_init() to initialize theta matrix
        rope._rope_init()
        cache = rope.cache

        assert_expected(cache.mean(), self.EXPECTED_FREQS_CIS_MEAN, atol=1e-4)
        assert_expected(cache.sum(), self.EXPECTED_FREQS_CIS_SUM, atol=1e-4)
        assert_expected(cache.max(), self.EXPECTED_FREQS_CIS_MAX)

    def test_forward(self, input, rope) -> None:
        x_out = rope(input)

        # check the numerics of the computed tensor
        assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN)
        assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM)
        assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX)

        # check shapes
        assert_expected(x_out.shape, input.shape)

    def test_forward_with_curr_pos(self, input, rope) -> None:
        (
            _,
            seq_len,
            _,
            _,
        ) = input.shape
        x_out = rope(input, input_pos=torch.arange(seq_len))

        # These values should be exactly the same as test_forward
        # since in this case input_pos covers the entire input
        # sequence. This tests that input_pos works as expected, i.e.
        # extracts the embeddings for the relevant positions.
        assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4)
        assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM)
        assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX)

        # check shapes
        assert_expected(x_out.shape, input.shape)

    def test_forward_with_2d_pos_ids(self, input, rope) -> None:
        """
        Use input_pos to indicate positions of each token relative to its sequence
        when sample is packed.
        """
        (
            bsz,
            seq_len,
            _,
            _,
        ) = input.shape
        x_out = rope(
            input, input_pos=torch.arange(seq_len).unsqueeze(0).expand(bsz, seq_len)
        )

        # These values should be exactly the same as test_forward
        # AND test_forward_with_curr_pos. In this case input_pos
        # covers the entire batch dim and is defined for each sample separately.
        # This tests that input_pos works as expected, i.e.
        # extracts the embeddings for the relevant positions for each sample.
        assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4)
        assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM)
        assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX)

        # check shapes
        assert_expected(x_out.shape, input.shape)

    def test_rope_init_meta_device(self, input_params):
        _, _, head_dim, _, max_seq_len = input_params
        rope_on_device = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len)
        with torch.device("meta"):
            meta_rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len)

        meta_rope._rope_init()
        for p1, p2 in zip(rope_on_device.buffers(), meta_rope.buffers()):
            torch.testing.assert_close(p1, p2)

        # Assert meta_rope cache is no longer on meta device
        assert meta_rope.cache.device != torch.device("meta")
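
For reference, the expected scalar constants above were generated from the gist linked in the class docstring, which is not reproduced here. The snippet below is only a rough, unofficial sketch of how comparable statistics could be recomputed locally, assuming the same seed and fixture shapes as the test; the canonical values come from the Meta reference implementation, so treat this purely as a sanity check.

import torch

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
from torchtune.utils.seed import set_seed

set_seed(0)

head_dim = 4096 // 32  # embed_dim // num_heads, matching the input_params fixture
rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=4096)
rope._rope_init()  # build the cache eagerly, as test_cache_equality does

x = torch.randn(4, 2048, 32, head_dim)  # [bsz, seq_len, num_heads, head_dim]
x_out = rope(x)

# Summary statistics checked by the test above
print(rope.cache.mean(), rope.cache.sum(), rope.cache.max())
print(x_out.mean(), x_out.sum(), x_out.max())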

torchtune/models/llama3_1/_component_builders.py

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@
 from torch import nn

 from torchtune.models.llama3._model_utils import scale_hidden_dim_for_mlp
-from torchtune.models.llama3_1._position_embeddings import Llama31RotaryPositionalEmbeddings
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

 from torchtune.modules import (
     CausalSelfAttention,
@@ -81,7 +81,7 @@ def llama3_1(
     """
     head_dim = embed_dim // num_heads
     num_kv_heads = num_kv_heads if num_kv_heads else num_heads
-    rope = Llama31RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
+    rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
     self_attn = CausalSelfAttention(
         embed_dim=embed_dim,
         num_heads=num_heads,
@@ -358,7 +358,7 @@ def lora_llama3_1_self_attention(
         if "output_proj" in lora_modules
         else nn.Linear(embed_dim, embed_dim, bias=False)
     )
-    rope = Llama31RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
+    rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
     self_attn = CausalSelfAttention(
         embed_dim=embed_dim,
         num_heads=num_heads,

torchtune/models/llama3_1/_position_embeddings.py

Lines changed: 3 additions & 9 deletions
@@ -12,14 +12,11 @@
 from torch import nn, Tensor


-class Llama31RotaryPositionalEmbeddings(nn.Module):
+class Llama3ScaledRoPE(nn.Module):
     """
     This class implements Rotary Positional Embeddings (RoPE)
-    proposed in https://arxiv.org/abs/2104.09864.
-
-    Reference implementation (used for correctness verfication)
-    can be found here:
-    https://github.com/meta-llama/llama/blob/main/llama/model.py#L80
+    proposed in https://arxiv.org/abs/2104.09864 with additional
+    scaling from https://github.com/meta-llama/llama-models/blob/dc42f22a3b05502e7296402b019a51f57fa045c9/models/llama3_1.

     In this implementation we cache the embeddings for each position upto
     ``max_seq_len`` by computing this during init.
@@ -120,9 +117,6 @@ def forward(self, x: Tensor, *, input_pos: Optional[Tensor] = None) -> Tensor:
             - s: sequence length
             - n_h: num heads
             - h_d: head dim
-
-        TODO: The implementation below can be made more efficient
-        for inference.
         """
         # TODO: Remove this hack for handling scaling for Meta device
         if not self.is_cache_built:
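
For readers unfamiliar with the scaling mentioned in the updated docstring: the llama-models reference linked above rescales the RoPE base frequencies before the embedding cache is built. Below is a minimal sketch of that scaling; the function name apply_scaling and the constants (scale factor 8, low/high frequency factors 1 and 4, original context length 8192) are taken from the external reference and are assumptions for illustration, not values asserted by this commit.

import math

import torch


def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
    # Constants assumed from the llama-models reference implementation.
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original Llama3 context length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # High-frequency components are left untouched.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Low-frequency components are scaled down by scale_factor.
            new_freqs.append(freq / scale_factor)
        else:
            # Smoothly interpolate between scaled and unscaled frequencies.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)

Because the scaled frequencies feed the cache that Llama3ScaledRoPE builds at init, the new test only needs to check summary statistics of rope.cache and of the forward output against the reference values.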
