@@ -102,19 +102,21 @@ def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
         "Expected MoEFeedForward in at least one transformer block"
 
 
-def test_qwen3_base_kvcache_equivalence(dummy_cfg_base):
-    model_regular = Qwen3Model(dummy_cfg_base)
+@pytest.mark.parametrize("cfg_name", ["dummy_cfg_base", "dummy_cfg_moe"])
+def test_qwen3_kvcache_equivalence(cfg_name, request):
+    cfg = request.getfixturevalue(cfg_name)
+    model_regular = Qwen3Model(cfg)
     model_regular.eval()
 
-    model_kv = Qwen3ModelKV(dummy_cfg_base)
+    model_kv = Qwen3ModelKV(cfg)
     model_kv.eval()
     model_kv.load_state_dict(model_regular.state_dict())  # ensure same weights
 
     model_kv.reset_kv_cache()
-    cache = KVCache(n_layers=dummy_cfg_base["n_layers"])
+    cache = KVCache(n_layers=cfg["n_layers"])
 
     torch.manual_seed(123)
-    input_ids = torch.randint(0, dummy_cfg_base["vocab_size"], (1, 6))  # batch_size=1, seq_len=6
+    input_ids = torch.randint(0, cfg["vocab_size"], (1, 6))  # batch_size=1, seq_len=6
 
     # full-sequence output
     out_full = model_regular(input_ids)
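
For context on the pattern used in this hunk: pytest.mark.parametrize passes the fixture names through as plain strings, so the merged test has to resolve each name into the actual fixture value at runtime via request.getfixturevalue. A minimal self-contained sketch of that pattern, with made-up small_cfg/large_cfg fixtures standing in for dummy_cfg_base/dummy_cfg_moe:

import pytest


@pytest.fixture
def small_cfg():
    return {"vocab_size": 16, "n_layers": 1}


@pytest.fixture
def large_cfg():
    return {"vocab_size": 32, "n_layers": 2}


@pytest.mark.parametrize("cfg_name", ["small_cfg", "large_cfg"])
def test_cfg_is_resolved(cfg_name, request):
    # cfg_name arrives as a string; getfixturevalue turns it into the fixture's value
    cfg = request.getfixturevalue(cfg_name)
    assert isinstance(cfg, dict)
    assert cfg["vocab_size"] > 0

Note that the deleted test in the next hunk passed the raw string into the model constructor (Qwen3Model(cfg_name)), which is exactly the mistake this indirection avoids.
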
@@ -132,35 +134,6 @@ def test_qwen3_base_kvcache_equivalence(dummy_cfg_base):
     assert torch.allclose(out_full, out_kv, atol=1e-5, rtol=1e-3)
 
 
-@pytest.mark.parametrize("cfg_name", ["dummy_cfg_base", "dummy_cfg_moe"])
-def test_qwen3_moe_kvcache_equivalence(cfg_name):
-    model_regular = Qwen3Model(cfg_name)
-    model_regular.eval()
-
-    torch.manual_seed(123)
-    input_ids = torch.randint(0, cfg_name["vocab_size"], (1, 6))  # batch_size=1, seq_len=6
-
-    # No KV cache forward
-    out_full = model_regular(input_ids)
-
-    # Now with KV cache
-    model_kv = Qwen3ModelKV(cfg_name)
-    model_kv.eval()
-    model_kv.reset_kv_cache()
-    cache = KVCache(n_layers=cfg_name["n_layers"])
-
-    logits_stepwise = []
-    for t in range(input_ids.size(1)):
-        input_token = input_ids[:, t:t+1]  # shape (1, 1)
-        logits = model_kv(input_token, cache=cache)
-        logits_stepwise.append(logits)
-
-    # Concatenate all stepwise outputs
-    out_kv = torch.cat(logits_stepwise, dim=1)
-
-    assert torch.allclose(out_full, out_kv, atol=1e-5, rtol=1e-3)
-
-
 
 @pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
 def test_rope():
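
The equivalence check the surviving test performs (one full-sequence forward vs. token-by-token decoding through a cache) follows the pattern sketched below. ToyCachedLM and its plain-list cache are hypothetical stand-ins for Qwen3ModelKV and KVCache, kept only to show the shape of the comparison:

import torch
import torch.nn as nn


class ToyCachedLM(nn.Module):
    """Hypothetical stand-in: a causal model whose forward can reuse cached state."""

    def __init__(self, vocab_size=16, dim=8):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.out = nn.Linear(dim, vocab_size)

    def forward(self, input_ids, cache=None):
        x = self.emb(input_ids)
        if cache is not None:
            cache.append(x)
            x = torch.cat(cache, dim=1)  # all tokens seen so far
        # causal "mixing": running mean over the time axis
        denom = torch.arange(1, x.size(1) + 1, device=x.device).view(1, -1, 1)
        logits = self.out(x.cumsum(dim=1) / denom)
        # with a cache, return only the newest positions' logits
        return logits[:, -input_ids.size(1):, :] if cache is not None else logits


model = ToyCachedLM()
model.eval()
torch.manual_seed(123)
input_ids = torch.randint(0, 16, (1, 6))  # batch_size=1, seq_len=6

# full-sequence forward, no cache
out_full = model(input_ids)

# stepwise decoding, one token at a time, reusing the cache
cache = []  # plays the role of KVCache(n_layers=...)
steps = [model(input_ids[:, t:t + 1], cache=cache) for t in range(input_ids.size(1))]
out_kv = torch.cat(steps, dim=1)

assert torch.allclose(out_full, out_kv, atol=1e-5, rtol=1e-3)

Because the cached path recomputes exactly the same per-position running mean as the full forward, the two outputs match within the same tolerances the real test asserts.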