Skip to content

Commit 9cf6417

Browse files
authored
Update Qwen3 tokenizer test (#727)
* Update Qwen3 tokenizer test
* Add tokenizers to dev dependencies
1 parent 83c7689 commit 9cf6417

File tree

3 files changed

+20
-12
lines changed

3 files changed

+20
-12
lines changed

.github/workflows/check-links.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ jobs:
2323
- name: Install dependencies
2424
run: |
2525
curl -LsSf https://astral.sh/uv/install.sh | sh
26+
uv sync --dev
2627
uv add pytest-ruff pytest-check-links
2728
2829
- name: Check links

pkg/llms_from_scratch/tests/test_qwen3.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
import importlib
2020
import pytest
21-
import tiktoken
2221
import torch
2322
import torch.nn as nn
2423

@@ -102,8 +101,8 @@ class RoPEConfig:
102101

103102
@pytest.fixture(scope="session")
104103
def qwen3_weights_path(tmp_path_factory):
105-
"""Creates and saves a deterministic Llama3 model for testing."""
106-
path = tmp_path_factory.mktemp("models") / "llama3_test_weights.pt"
104+
"""Creates and saves a deterministic model for testing."""
105+
path = tmp_path_factory.mktemp("models") / "qwen3_test_weights.pt"
107106

108107
if not path.exists():
109108
torch.manual_seed(123)
@@ -122,26 +121,33 @@ def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
122121
model.load_state_dict(torch.load(qwen3_weights_path))
123122
model.eval()
124123

125-
start_context = "Llamas eat"
124+
tokenizer = Qwen3Tokenizer(
125+
tokenizer_file_path="tokenizer-base.json",
126+
repo_id="rasbt/qwen3-from-scratch",
127+
add_generation_prompt=False,
128+
add_thinking=False
129+
)
126130

127-
tokenizer = tiktoken.get_encoding("gpt2")
128-
encoded = tokenizer.encode(start_context)
129-
encoded_tensor = torch.tensor(encoded).unsqueeze(0)
131+
prompt = "Give me a short introduction to large language models."
132+
input_token_ids = tokenizer.encode(prompt)
133+
input_token_ids = torch.tensor([input_token_ids])
130134

131135
print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
132-
print("\nInput text:", start_context)
133-
print("Encoded input text:", encoded)
134-
print("encoded_tensor.shape:", encoded_tensor.shape)
136+
print("\nInput text:", prompt)
137+
print("Encoded input text:", input_token_ids)
138+
print("encoded_tensor.shape:", input_token_ids.shape)
135139

136140
out = generate_text_simple(
137141
model=model,
138-
idx=encoded_tensor,
142+
idx=input_token_ids,
139143
max_new_tokens=5,
140144
context_size=QWEN_CONFIG_06_B["context_length"]
141145
)
142146
print("Encoded output text:", out)
143147
expect = torch.tensor([
144-
[43, 2543, 292, 4483, 115206, 459, 43010, 104223, 55553]
148+
[151644, 872, 198, 35127, 752, 264, 2805, 16800, 311,
149+
3460, 4128, 4119, 13, 151645, 198, 112120, 83942, 60483,
150+
102652, 7414]
145151
])
146152
assert torch.equal(expect, out)
147153

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ dev = [
2929
"build>=1.2.2.post1",
3030
"llms-from-scratch",
3131
"twine>=6.1.0",
32+
"tokenizers>=0.21.1",
3233
]
3334

3435
[tool.ruff]

0 commit comments

Comments (0)