20
20
from llms_from_scratch .kv_cache_batched .generate import generate_text_simple as generate_text_simple_batched
21
21
22
22
import importlib
23
+ import platform
23
24
import pytest
24
25
import torch
25
26
import torch .nn as nn
@@ -107,6 +108,10 @@ def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
107
108
@pytest .mark .parametrize ("cfg_name" , ["dummy_cfg_base" , "dummy_cfg_moe" ])
108
109
def test_qwen3_kvcache_equivalence (cfg_name , request ):
109
110
cfg = request .getfixturevalue (cfg_name )
111
+
112
+ if cfg ["num_experts" ] > 0 and platform .system () == "Linux" :
113
+ pytest .skip ("Skipping MoE KV equivalence test on Linux due to nondeterministic expert routing" )
114
+
110
115
torch .manual_seed (123 )
111
116
model_regular = Qwen3Model (cfg )
112
117
model_regular .eval ()
@@ -130,13 +135,7 @@ def test_qwen3_kvcache_equivalence(cfg_name, request):
130
135
out_kv = torch .cat (logits_stepwise , dim = 1 )
131
136
132
137
assert out_full .shape == out_kv .shape , f"Shape mismatch: { out_full .shape } vs { out_kv .shape } "
133
-
134
- if cfg ["num_experts" ] > 0 :
135
- # MoE models are not bit-identical due to non-deterministic topk on Linux (works fine on macOS)
136
- cosine_sim = torch .nn .functional .cosine_similarity (out_full .flatten (), out_kv .flatten (), dim = 0 )
137
- assert cosine_sim > 0.99 , f"Low cosine similarity for MoE model: { cosine_sim .item ()} "
138
- else :
139
- assert torch .allclose (out_full , out_kv , atol = 1e-5 , rtol = 1e-3 )
138
+ assert torch .allclose (out_full , out_kv , atol = 1e-5 , rtol = 1e-3 )
140
139
141
140
142
141
@pytest .mark .skipif (not transformers_installed , reason = "transformers not installed" )
0 commit comments