@@ -71,7 +71,7 @@ def dummy_cfg_base():
         "n_kv_groups": 1,
         "qk_norm": False,
         "dtype": torch.float32,
-        "rope_base": 10000,
+        "rope_base": 1000000,
         "context_length": 64,
         "num_experts": 0,
     }
@@ -143,18 +143,21 @@ def test_qwen3_kvcache_equivalence(cfg_name, request):
 
 
 @pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
-def test_rope():
+@pytest.mark.parametrize("context_len", [1024, 8192, 40960])
+def test_rope(context_len):
 
-    from transformers.models.qwen3.modeling_qwen3 import Qwen3RotaryEmbedding, apply_rotary_pos_emb
+    from transformers.models.qwen3.modeling_qwen3 import (
+        Qwen3RotaryEmbedding,
+        apply_rotary_pos_emb,
+    )
 
     # Settings
     batch_size = 1
-    context_len = 8192
     num_heads = 4
     head_dim = 16
     rope_theta = 1_000_000
 
-    # Instantiate RoPE parameters
+    # Instantiate RoPE parameters (our implementation)
     cos, sin = compute_rope_params(
         head_dim=head_dim,
         theta_base=rope_theta,
@@ -166,7 +169,7 @@ def test_rope():
     queries = torch.randn(batch_size, num_heads, context_len, head_dim)
     keys = torch.randn(batch_size, num_heads, context_len, head_dim)
 
-    # Apply rotary position embeddings
+    # Apply rotary embeddings with our implementation
     queries_rot = apply_rope(queries, cos, sin)
     keys_rot = apply_rope(keys, cos, sin)
 
@@ -176,7 +179,7 @@ class RoPEConfig:
         factor = 1.0
         dim: int = head_dim
         rope_theta = 1_000_000
-        max_position_embeddings: int = 8192
+        max_position_embeddings = context_len
         hidden_size = head_dim * num_heads
         num_attention_heads = num_heads
@@ -187,10 +190,17 @@ class RoPEConfig:
     ref_cos, ref_sin = rot_emb(queries, position_ids)
     ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
 
-    torch.testing.assert_close(sin, ref_sin.squeeze(0))
-    torch.testing.assert_close(cos, ref_cos.squeeze(0))
-    torch.testing.assert_close(keys_rot, ref_keys_rot)
-    torch.testing.assert_close(queries_rot, ref_queries_rot)
+    # torch.testing.assert_close(sin, ref_sin.squeeze(0), rtol=1e-5, atol=1e-6)
+    # torch.testing.assert_close(cos, ref_cos.squeeze(0), rtol=1e-5, atol=1e-6)
+
+    # torch.testing.assert_close(keys_rot, ref_keys_rot, rtol=1e-5, atol=1e-6)
+    # torch.testing.assert_close(queries_rot, ref_queries_rot, rtol=1e-5, atol=1e-6)
+
+    assert torch.equal(sin, ref_sin.squeeze(0))
+    assert torch.equal(cos, ref_cos.squeeze(0))
+
+    assert torch.equal(keys_rot, ref_keys_rot)
+    assert torch.equal(queries_rot, ref_queries_rot)
 
 
 @pytest.fixture(scope="session")
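For context, the test above checks the local compute_rope_params / apply_rope against Hugging Face's Qwen3RotaryEmbedding and apply_rotary_pos_emb, now with exact equality (torch.equal) instead of tolerance-based comparison. The repo's own RoPE code is not shown in this diff; the following is a minimal sketch of a standard rotate-half RoPE implementation matching the call signatures used in the test, offered purely as an illustration rather than the repo's actual code (whether it is bit-for-bit identical to the Hugging Face reference depends on matching operation order and dtypes).

import torch

def compute_rope_params(head_dim, theta_base=1_000_000, context_length=4096, dtype=torch.float32):
    # Inverse frequencies theta_base^(-2i/head_dim) for each dimension pair
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype) / head_dim))
    positions = torch.arange(context_length, dtype=dtype)
    angles = positions[:, None] * inv_freq[None, :]   # (context_length, head_dim // 2)
    angles = torch.cat([angles, angles], dim=1)       # (context_length, head_dim)
    return torch.cos(angles), torch.sin(angles)

def apply_rope(x, cos, sin):
    # x: (batch, num_heads, seq_len, head_dim)
    _, _, seq_len, head_dim = x.shape
    x1 = x[..., : head_dim // 2]
    x2 = x[..., head_dim // 2:]
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, head_dim)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
    rotated = torch.cat((-x2, x1), dim=-1)            # rotate-half
    return (x * cos) + (rotated * sin)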