Skip to content

Commit 78e254c

Browse files
committed
fix rope unit tests
1 parent 0f77094 commit 78e254c

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

tests/torchtune/models/llama3_1/test_position_embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def test_rope_init_meta_device(self, input_params):
132132
with torch.device("meta"):
133133
meta_rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len)
134134

135-
meta_rope._rope_init()
135+
meta_rope.rope_init()
136136
for p1, p2 in zip(rope_on_device.buffers(), meta_rope.buffers()):
137137
torch.testing.assert_close(p1, p2)
138138

tests/torchtune/modules/test_position_embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def test_rope_init_meta_device(self, input_params):
128128
dim=head_dim, max_seq_len=max_seq_len
129129
)
130130

131-
meta_rope._rope_init()
131+
meta_rope.rope_init()
132132
for p1, p2 in zip(rope_on_device.buffers(), meta_rope.buffers()):
133133
torch.testing.assert_close(p1, p2)
134134

0 commit comments

Comments
 (0)