
Commit 55d66e2

refinements
1 parent af84d7b commit 55d66e2

7 files changed: +199 -27 lines changed

ch05/11_qwen3/README.md

Lines changed: 54 additions & 17 deletions
@@ -1,12 +1,18 @@
 # Qwen3 From Scratch
 
-This [standalone-qwen3.ipynb](standalone-qwen3.ipynb) Jupyter notebook in this folder contains a from-scratch implementation of Qwen3 0.6B, 1.7B, 4B, 8B, and 32 B.
+This [standalone-qwen3.ipynb](standalone-qwen3.ipynb) Jupyter notebook in this folder contains a from-scratch implementation of Qwen3 0.6B, 1.7B, 4B, 8B, and 32B.
 
 <img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/qwen/qwen-overview.webp">
 
 
+The [standalone-qwen3-moe.ipynb](standalone-qwen3-moe.ipynb) and [standalone-qwen3-moe-plus-kvcache.ipynb](standalone-qwen3-moe-plus-kvcache.ipynb) Jupyter notebooks in this folder contain a from-scratch implementation of the Qwen3 30B-A3B Mixture-of-Experts (MoE) model, including the Thinking, Instruct, and Coder model variants.
+
+<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/qwen/qwen3-coder-flash-overview.webp?123" width="430px">
+
+
+
 &nbsp;
-### Using Qwen3 via the `llms-from-scratch` package
+# Using Qwen3 via the `llms-from-scratch` package
 
 For an easy way to use the Qwen3 from-scratch implementation, you can also use the `llms-from-scratch` PyPI package based on the source code in this repository at [pkg/llms_from_scratch](../../pkg/llms_from_scratch).
 
@@ -23,8 +29,9 @@ pip install llms_from_scratch tokenizers
 Specify which model to use:
 
 ```python
-USE_REASONING_MODEL = True   # The "thinking" model
 USE_REASONING_MODEL = False  # The base model
+USE_REASONING_MODEL = True   # The "thinking" model
+
 
 # Use
 # USE_REASONING_MODEL = True
@@ -130,22 +137,22 @@ from llms_from_scratch.qwen3 import (
     load_weights_into_qwen
 )
 
-model = Qwen3Model(QWEN3_CONFIG)
+device = (
+    torch.device("cuda") if torch.cuda.is_available() else
+    torch.device("mps") if torch.backends.mps.is_available() else
+    torch.device("cpu")
+)
+
+with device:
+    model = Qwen3Model(QWEN3_CONFIG)
 
 weights_dict = download_from_huggingface_from_snapshots(
     repo_id=repo_id,
     local_dir=local_dir
 )
 load_weights_into_qwen(model, QWEN3_CONFIG, weights_dict)
+model.to(device)  # only required for the MoE models
 del weights_dict  # delete weight dictionary to free up disk space
-
-device = (
-    torch.device("cuda") if torch.cuda.is_available() else
-    torch.device("mps") if torch.backends.mps.is_available() else
-    torch.device("cpu")
-)
-
-model.to(device);
 ```
 
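The `with device:` pattern in the hunk above relies on PyTorch's device context manager (available since PyTorch 2.0), which makes factory calls and module constructors allocate parameters directly on the target device instead of creating them on the CPU and moving them afterwards. A minimal sketch of the effect, illustrative only (`nn.Linear` stands in here for the much larger `Qwen3Model`):

```python
import torch
import torch.nn as nn

device = (
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("mps") if torch.backends.mps.is_available() else
    torch.device("cpu")
)

# Entering a torch.device context makes newly created tensors and module
# parameters land on that device by default.
with device:
    layer = nn.Linear(4, 4)

print(layer.weight.device)  # e.g. cuda:0, mps:0, or cpu
```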

@@ -236,6 +243,33 @@ Large language models (LLMs) are advanced artificial intelligence systems design
 
 
 
+For the larger models, you may prefer the streaming variant, which prints each token as soon as it's generated:
+
+```python
+from llms_from_scratch.generate import generate_text_simple_stream
+
+input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)
+
+for token in generate_text_simple_stream(
+    model=model,
+    token_ids=input_token_ids_tensor,
+    max_new_tokens=150,
+    eos_token_id=tokenizer.eos_token_id
+):
+    token_id = token.squeeze(0).tolist()
+    print(
+        tokenizer.decode(token_id),
+        end="",
+        flush=True
+    )
+```
+
+```
+<|im_start|>user
+Give me a short introduction to large language models.<|im_end|>
+Large language models (LLMs) are advanced artificial intelligence systems designed to generate human-like text. They are trained on vast amounts of text data, allowing them to understand and generate coherent, contextually relevant responses. LLMs are used in a variety of applications, including chatbots, virtual assistants, content generation, and more. They are powered by deep learning algorithms and can be fine-tuned for specific tasks, making them versatile tools for a wide range of industries.<|endoftext|>Human resources department of a company is planning to hire 100 new employees. The company has a budget of $100,000 for the recruitment process. The company has a minimum wage of $10 per hour. The company has a total of...
+```
+
 
 
 &nbsp;
@@ -252,18 +286,19 @@ model.to(device)
 with
 
 ```python
-model = torch.compile(model)
 model.to(device)
+model = torch.compile(model)
 ```
 
 Note: There is a significant multi-minute upfront cost when compiling, and the speed-up takes effect after the first `generate` call.
 
 The following table shows a performance comparison on an A100 for consequent `generate` calls:
 
-|                          | Tokens/sec | Memory  |
-| ------------------------ | ---------- | ------- |
-| Qwen3Model 0.6B          | 25         | 1.49 GB |
-| Qwen3Model 0.6B compiled | 107        | 1.99 GB |
+|                          | Hardware        | Tokens/sec | Memory  |
+| ------------------------ | --------------- | ---------- | ------- |
+| Qwen3Model 0.6B          | Nvidia A100 GPU | 25         | 1.49 GB |
+| Qwen3Model 0.6B compiled | Nvidia A100 GPU | 107        | 1.99 GB |
+
 
 &nbsp;
 #### Pro tip 2: speed up inference with KV cache
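The tokens/sec figures in the table above can be approximated with a simple timing loop. A rough, illustrative sketch only (not the repository's benchmark script); it assumes `model`, `device`, `input_token_ids`, and the `generate_text_simple` variant imported earlier are already set up as in the preceding snippets:

```python
import time
import torch

input_ids = torch.tensor(input_token_ids, device=device).unsqueeze(0)

# Synchronize so the timer measures actual GPU work, not just kernel launches
if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.time()

output_ids = generate_text_simple(
    model=model,
    idx=input_ids,
    max_new_tokens=150,
)

if torch.cuda.is_available():
    torch.cuda.synchronize()
elapsed = time.time() - start

num_new_tokens = output_ids.shape[1] - input_ids.shape[1]
print(f"{num_new_tokens / elapsed:.1f} tokens/sec")
if torch.cuda.is_available():
    print(f"{torch.cuda.max_memory_allocated() / 1e9:.2f} GB peak memory")
```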
@@ -305,6 +340,8 @@ Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is
 
 Note that all settings above have been tested to produce the same text outputs.
 
+
+
 &nbsp;
 
 #### Pro tip 3: batched inference

ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@
 "\n",
 "<br>\n",
 "\n",
-"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/qwen/qwen3-coder-flash-overview.webp?123\" width=\"700px\">\n",
+"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/qwen/qwen3-coder-flash-overview.webp?123\" width=\"600px\">\n",
 "\n",
 "<br>\n",
 " \n",

ch05/11_qwen3/standalone-qwen3-moe.ipynb

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@
 "\n",
 "<br>\n",
 "\n",
-"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/qwen/qwen3-coder-flash-overview.webp?123\" width=\"700px\">\n",
+"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/qwen/qwen3-coder-flash-overview.webp?123\" width=\"600px\">\n",
 "\n",
 "<br>\n",
 " \n",

ch05/11_qwen3/standalone-qwen3.ipynb

Lines changed: 0 additions & 6 deletions
@@ -925,12 +925,6 @@
 "        self.add_thinking = add_thinking\n",
 "\n",
 "        tok_file = Path(tokenizer_file_path)\n",
-"        if not tok_file.is_file() and repo_id:\n",
-"            download_from_huggingface(\n",
-"                repo_id=repo_id,\n",
-"                filename=tok_file.name,\n",
-"                local_dir=str(tok_file.parent),\n",
-"            )\n",
 "        self._tok = Tokenizer.from_file(str(tok_file))\n",
 "        self._special_to_id = {t: self._tok.token_to_id(t) for t in self._SPECIALS}\n",
 "\n",

pkg/llms_from_scratch/README.md

Lines changed: 8 additions & 2 deletions
@@ -160,10 +160,16 @@ from llms_from_scratch.qwen3 import (
 
 # KV cache drop-in replacements
 from llms_from_scratch.kv_cache.qwen3 import Qwen3Model
-from llms_from_scratch.kv_cache.generate import generate_text_simple
+from llms_from_scratch.kv_cache.generate import (
+    generate_text_simple,
+    generate_text_simple_stream
+)
 
 # KV cache drop-in replacements with batched inference support
-from llms_from_scratch.kv_cache_batched.generate import generate_text_simple
+from llms_from_scratch.kv_cache_batched.generate import (
+    generate_text_simple,
+    generate_text_simple_stream
+)
 from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model
 ```
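Since the KV-cache and batched modules export the same function and class names, importing both sets into a single script shadows the earlier import. A small sketch of one way to keep them apart; the aliases below are illustrative, not part of the package API:

```python
from llms_from_scratch.kv_cache.generate import (
    generate_text_simple as generate_kv,
    generate_text_simple_stream as generate_stream_kv,
)
from llms_from_scratch.kv_cache_batched.generate import (
    generate_text_simple as generate_kv_batched,
    generate_text_simple_stream as generate_stream_kv_batched,
)
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched
```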

pkg/llms_from_scratch/kv_cache/generate.py

Lines changed: 24 additions & 0 deletions
@@ -28,3 +28,27 @@ def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cach
         idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
+
+
+def generate_text_simple_stream(model, token_ids, max_new_tokens, eos_token_id=None, context_size=None):
+    model.eval()
+
+    with torch.no_grad():
+        cache = KVCache(n_layers=model.cfg["n_layers"])
+        model.reset_kv_cache()
+
+        # Prime the cache with the initial context
+        logits = model(token_ids, cache=cache)
+
+        for _ in range(max_new_tokens):
+            next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)
+
+            if eos_token_id is not None and torch.all(next_token == eos_token_id):
+                break
+
+            yield next_token
+
+            token_ids = torch.cat([token_ids, next_token], dim=1)
+
+            # Feed only the new token to the model; cache handles history
+            logits = model(next_token, cache=cache)
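A minimal consumption sketch for the new generator. Assumptions: `model` is a KV-cache `Qwen3Model`, `device` is set as in the README, and `tokenizer` exposes `encode`/`decode`/`eos_token_id` as used elsewhere in this repository; the prompt string is only an example:

```python
import torch
from llms_from_scratch.kv_cache.generate import generate_text_simple_stream

prompt_ids = torch.tensor(
    tokenizer.encode("Give me a short introduction to large language models."),
    device=device,
).unsqueeze(0)

collected = []
for next_token in generate_text_simple_stream(
    model=model,
    token_ids=prompt_ids,
    max_new_tokens=50,
    eos_token_id=tokenizer.eos_token_id,
):
    collected.append(next_token)
    # Decode and print each token as soon as it is yielded
    print(tokenizer.decode(next_token.squeeze(0).tolist()), end="", flush=True)

# Reassemble the full sequence, mirroring what generate_text_simple returns
full_ids = torch.cat([prompt_ids, *collected], dim=1)
```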

pkg/llms_from_scratch/tests/test_qwen3.py

Lines changed: 111 additions & 0 deletions
@@ -13,6 +13,7 @@
     Qwen3Tokenizer
 )
 from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
+from llms_from_scratch.kv_cache.utils import KVCache
 from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
 
 from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched
@@ -50,6 +51,116 @@ def extra_repr(self):
 transformers_installed = importlib.util.find_spec("transformers") is not None
 
 
+@pytest.fixture
+def dummy_input():
+    torch.manual_seed(123)
+    return torch.randint(0, 100, (1, 8))  # batch size 1, seq length 8
+
+
+@pytest.fixture
+def dummy_cfg_base():
+    return {
+        "vocab_size": 100,
+        "emb_dim": 32,
+        "hidden_dim": 64,
+        "n_layers": 2,
+        "n_heads": 4,
+        "head_dim": 8,
+        "n_kv_groups": 1,
+        "qk_norm": False,
+        "dtype": torch.float32,
+        "rope_base": 10000,
+        "context_length": 64,
+        "num_experts": 0,
+    }
+
+
+@pytest.fixture
+def dummy_cfg_moe(dummy_cfg_base):
+    cfg = dummy_cfg_base.copy()
+    cfg.update({
+        "num_experts": 4,
+        "num_experts_per_tok": 2,
+        "moe_intermediate_size": 64,
+    })
+    return cfg
+
+
+def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input):
+    model = Qwen3Model(dummy_cfg_base)
+    out = model(dummy_input)
+    assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), \
+        f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
+
+
+def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
+    model = Qwen3Model(dummy_cfg_moe)
+    out = model(dummy_input)
+    assert out.shape == (1, dummy_input.size(1), dummy_cfg_moe["vocab_size"]), \
+        f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
+    assert any(hasattr(block.ff, 'gate') for block in model.trf_blocks), \
+        "Expected MoEFeedForward in at least one transformer block"
+
+
+def test_qwen3_base_kvcache_equivalence(dummy_cfg_base):
+    model_regular = Qwen3Model(dummy_cfg_base)
+    model_regular.eval()
+
+    model_kv = Qwen3ModelKV(dummy_cfg_base)
+    model_kv.eval()
+    model_kv.load_state_dict(model_regular.state_dict())  # ensure same weights
+
+    model_kv.reset_kv_cache()
+    cache = KVCache(n_layers=dummy_cfg_base["n_layers"])
+
+    torch.manual_seed(123)
+    input_ids = torch.randint(0, dummy_cfg_base["vocab_size"], (1, 6))  # batch_size=1, seq_len=6
+
+    # full-sequence output
+    out_full = model_regular(input_ids)
+
+    # stepwise with KV cache
+    logits_stepwise = []
+    for t in range(input_ids.size(1)):
+        input_token = input_ids[:, t:t + 1]  # shape (1,1)
+        logits = model_kv(input_token, cache=cache)
+        logits_stepwise.append(logits)
+
+    out_kv = torch.cat(logits_stepwise, dim=1)
+
+    assert out_full.shape == out_kv.shape, f"Shape mismatch: {out_full.shape} vs {out_kv.shape}"
+    assert torch.allclose(out_full, out_kv, atol=1e-5, rtol=1e-3)
+
+
+@pytest.mark.parametrize("cfg_name", ["dummy_cfg_base", "dummy_cfg_moe"])
+def test_qwen3_moe_kvcache_equivalence(cfg_name, request):
+    cfg = request.getfixturevalue(cfg_name)  # resolve the fixture name to its config dict
+    model_regular = Qwen3Model(cfg)
+    model_regular.eval()
+    torch.manual_seed(123)
+    input_ids = torch.randint(0, cfg["vocab_size"], (1, 6))  # batch_size=1, seq_len=6
+
+    # No KV cache forward
+    out_full = model_regular(input_ids)
+
+    # Now with KV cache
+    model_kv = Qwen3ModelKV(cfg)
+    model_kv.eval()
+    model_kv.reset_kv_cache()
+    cache = KVCache(n_layers=cfg["n_layers"])
+
+    logits_stepwise = []
+    for t in range(input_ids.size(1)):
+        input_token = input_ids[:, t:t+1]  # shape (1, 1)
+        logits = model_kv(input_token, cache=cache)
+        logits_stepwise.append(logits)
+
+    # Concatenate all stepwise outputs
+    out_kv = torch.cat(logits_stepwise, dim=1)
+
+    assert torch.allclose(out_full, out_kv, atol=1e-5, rtol=1e-3)
+
+
 @pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
 def test_rope():
 