Commit 70edd53

Improve RoPE (#799)
1 parent d87d91b commit 70edd53

File tree

1 file changed
+21 -11 lines changed

pkg/llms_from_scratch/tests/test_qwen3.py

Lines changed: 21 additions & 11 deletions
@@ -71,7 +71,7 @@ def dummy_cfg_base():
         "n_kv_groups": 1,
         "qk_norm": False,
         "dtype": torch.float32,
-        "rope_base": 10000,
+        "rope_base": 1000000,
         "context_length": 64,
         "num_experts": 0,
     }
@@ -143,18 +143,21 @@ def test_qwen3_kvcache_equivalence(cfg_name, request):
 
 
 @pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
-def test_rope():
+@pytest.mark.parametrize("context_len", [1024, 8192, 40960])
+def test_rope(context_len):
 
-    from transformers.models.qwen3.modeling_qwen3 import Qwen3RotaryEmbedding, apply_rotary_pos_emb
+    from transformers.models.qwen3.modeling_qwen3 import (
+        Qwen3RotaryEmbedding,
+        apply_rotary_pos_emb,
+    )
 
     # Settings
     batch_size = 1
-    context_len = 8192
     num_heads = 4
     head_dim = 16
     rope_theta = 1_000_000
 
-    # Instantiate RoPE parameters
+    # Instantiate RoPE parameters (our implementation)
     cos, sin = compute_rope_params(
         head_dim=head_dim,
         theta_base=rope_theta,
@@ -166,7 +169,7 @@ def test_rope():
     queries = torch.randn(batch_size, num_heads, context_len, head_dim)
     keys = torch.randn(batch_size, num_heads, context_len, head_dim)
 
-    # Apply rotary position embeddings
+    # Apply rotary embeddings with our implementation
     queries_rot = apply_rope(queries, cos, sin)
     keys_rot = apply_rope(keys, cos, sin)
 
@@ -176,7 +179,7 @@ class RoPEConfig:
         factor = 1.0
         dim: int = head_dim
         rope_theta = 1_000_000
-        max_position_embeddings: int = 8192
+        max_position_embeddings = context_len
         hidden_size = head_dim * num_heads
         num_attention_heads = num_heads
 
@@ -187,10 +190,17 @@ class RoPEConfig:
     ref_cos, ref_sin = rot_emb(queries, position_ids)
     ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
 
-    torch.testing.assert_close(sin, ref_sin.squeeze(0))
-    torch.testing.assert_close(cos, ref_cos.squeeze(0))
-    torch.testing.assert_close(keys_rot, ref_keys_rot)
-    torch.testing.assert_close(queries_rot, ref_queries_rot)
+    # torch.testing.assert_close(sin, ref_sin.squeeze(0), rtol=1e-5, atol=1e-6)
+    # torch.testing.assert_close(cos, ref_cos.squeeze(0), rtol=1e-5, atol=1e-6)
+
+    # torch.testing.assert_close(keys_rot, ref_keys_rot, rtol=1e-5, atol=1e-6)
+    # torch.testing.assert_close(queries_rot, ref_queries_rot, rtol=1e-5, atol=1e-6)
+
+    assert torch.equal(sin, ref_sin.squeeze(0))
+    assert torch.equal(cos, ref_cos.squeeze(0))
+
+    assert torch.equal(keys_rot, ref_keys_rot)
+    assert torch.equal(queries_rot, ref_queries_rot)
 
 
 @pytest.fixture(scope="session")
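
For orientation: the test compares the repository's own compute_rope_params/apply_rope against Hugging Face's Qwen3RotaryEmbedding and apply_rotary_pos_emb, now at parametrized context lengths of 1024, 8192, and 40960, and expects exact equality (torch.equal) rather than tolerance-based closeness. The sketch below outlines a standard rotate-half RoPE precompute/apply pair to show the kind of computation being checked; rope_params_sketch and apply_rope_sketch are placeholder names, and the actual functions in pkg/llms_from_scratch may differ in signature and detail.

import torch

def rope_params_sketch(head_dim, theta_base=1_000_000, context_length=8192, dtype=torch.float32):
    # One inverse frequency per 2-dimensional rotation plane
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype) / head_dim))
    positions = torch.arange(context_length, dtype=dtype)
    angles = positions[:, None] * inv_freq[None, :]      # (context_length, head_dim // 2)
    angles = torch.cat([angles, angles], dim=1)          # (context_length, head_dim)
    return torch.cos(angles), torch.sin(angles)

def apply_rope_sketch(x, cos, sin):
    # x: (batch, num_heads, seq_len, head_dim); rotate-half convention as in HF's apply_rotary_pos_emb
    seq_len, head_dim = x.shape[2], x.shape[3]
    x1, x2 = x[..., :head_dim // 2], x[..., head_dim // 2:]
    rotated = torch.cat([-x2, x1], dim=-1)
    cos = cos[:seq_len].unsqueeze(0).unsqueeze(0)        # broadcast to (1, 1, seq_len, head_dim)
    sin = sin[:seq_len].unsqueeze(0).unsqueeze(0)
    return x * cos + rotated * sin

To run just this test across all three parametrized context lengths, something like pytest pkg/llms_from_scratch/tests/test_qwen3.py -k test_rope should work, assuming transformers is installed (the test is skipped otherwise).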
