Skip to content

Commit 12169a3

Browse files
committed
updates
1 parent 55d66e2 commit 12169a3

File tree

2 files changed

+8
-35
lines changed

2 files changed

+8
-35
lines changed

.github/workflows/basic-tests-windows-uv-pip.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
shell: bash
3636
run: |
3737
export PATH="$HOME/.local/bin:$PATH"
38-
pip install --upgrade pip
38+
python -m pip install --upgrade pip
3939
pip install uv
4040
uv venv --python=python3.11
4141
source .venv/Scripts/activate

pkg/llms_from_scratch/tests/test_qwen3.py

Lines changed: 7 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -102,19 +102,21 @@ def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
102102
"Expected MoEFeedForward in at least one transformer block"
103103

104104

105-
def test_qwen3_base_kvcache_equivalence(dummy_cfg_base):
106-
model_regular = Qwen3Model(dummy_cfg_base)
105+
@pytest.mark.parametrize("cfg_name", ["dummy_cfg_base", "dummy_cfg_moe"])
106+
def test_qwen3_kvcache_equivalence(cfg_name, request):
107+
cfg = request.getfixturevalue(cfg_name)
108+
model_regular = Qwen3Model(cfg)
107109
model_regular.eval()
108110

109-
model_kv = Qwen3ModelKV(dummy_cfg_base)
111+
model_kv = Qwen3ModelKV(cfg)
110112
model_kv.eval()
111113
model_kv.load_state_dict(model_regular.state_dict()) # ensure same weights
112114

113115
model_kv.reset_kv_cache()
114-
cache = KVCache(n_layers=dummy_cfg_base["n_layers"])
116+
cache = KVCache(n_layers=cfg["n_layers"])
115117

116118
torch.manual_seed(123)
117-
input_ids = torch.randint(0, dummy_cfg_base["vocab_size"], (1, 6)) # batch_size=1, seq_len=6
119+
input_ids = torch.randint(0, cfg["vocab_size"], (1, 6)) # batch_size=1, seq_len=6
118120

119121
# full-sequence output
120122
out_full = model_regular(input_ids)
@@ -132,35 +134,6 @@ def test_qwen3_base_kvcache_equivalence(dummy_cfg_base):
132134
assert torch.allclose(out_full, out_kv, atol=1e-5, rtol=1e-3)
133135

134136

135-
@pytest.mark.parametrize("cfg_name", ["dummy_cfg_base", "dummy_cfg_moe"])
136-
def test_qwen3_moe_kvcache_equivalence(cfg_name):
137-
model_regular = Qwen3Model(cfg_name)
138-
model_regular.eval()
139-
140-
torch.manual_seed(123)
141-
input_ids = torch.randint(0, cfg_name["vocab_size"], (1, 6)) # batch_size=1, seq_len=6
142-
143-
# No KV cache forward
144-
out_full = model_regular(input_ids)
145-
146-
# Now with KV cache
147-
model_kv = Qwen3ModelKV(cfg_name)
148-
model_kv.eval()
149-
model_kv.reset_kv_cache()
150-
cache = KVCache(n_layers=cfg_name["n_layers"])
151-
152-
logits_stepwise = []
153-
for t in range(input_ids.size(1)):
154-
input_token = input_ids[:, t:t+1] # shape (1, 1)
155-
logits = model_kv(input_token, cache=cache)
156-
logits_stepwise.append(logits)
157-
158-
# Concatenate all stepwise outputs
159-
out_kv = torch.cat(logits_stepwise, dim=1)
160-
161-
assert torch.allclose(out_full, out_kv, atol=1e-5, rtol=1e-3)
162-
163-
164137
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
165138
def test_rope():
166139

0 commit comments

Comments (0)