
Commit d037050

add HF equivalency tests for standalone nbs
1 parent a6b883c commit d037050

File tree

8 files changed: +382 -81 lines changed


.github/workflows/basic-tests-linux-uv.yml

Lines changed: 3 additions & 1 deletion
@@ -51,7 +51,9 @@ jobs:
 pytest --ruff ch04/01_main-chapter-code/tests.py
 pytest --ruff ch04/03_kv-cache/tests.py
 pytest --ruff ch05/01_main-chapter-code/tests.py
-pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
+pytest --ruff ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py
+pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32.py
+pytest --ruff ch05/11_qwen3/tests/test_qwen3.py
 pytest --ruff ch05/12_gemma3/tests/test_gemma3.py
 pytest --ruff ch06/01_main-chapter-code/tests.py

.github/workflows/basic-tests-macos-uv.yml

Lines changed: 3 additions & 1 deletion
@@ -50,7 +50,9 @@ jobs:
 pytest --ruff setup/02_installing-python-libraries/tests.py
 pytest --ruff ch04/01_main-chapter-code/tests.py
 pytest --ruff ch05/01_main-chapter-code/tests.py
-pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
+pytest --ruff ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py
+pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32.py
+pytest --ruff ch05/11_qwen3/tests/test_qwen3.py
 pytest --ruff ch05/12_gemma3/tests/test_gemma3.py
 pytest --ruff ch06/01_main-chapter-code/tests.py

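For reference, the same checks can be run locally. A minimal sketch (assuming pytest plus the pytest-ruff plugin that these workflows rely on, executed from the repository root):

```python
# Local mirror of the updated CI steps above; a sketch, not part of the commit.
import sys

import pytest

test_files = [
    "ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py",
    "ch05/07_gpt_to_llama/tests/test_llama32.py",
    "ch05/11_qwen3/tests/test_qwen3.py",
    "ch05/12_gemma3/tests/test_gemma3.py",
]

exit_code = 0
for path in test_files:
    # --ruff additionally runs the lint checks, as in the workflows
    exit_code |= pytest.main(["--ruff", path])

sys.exit(exit_code)
```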
ch05/07_gpt_to_llama/tests/test_llama32.py

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+import importlib
+from pathlib import Path
+
+import pytest
+import torch
+
+from llms_from_scratch.utils import import_definitions_from_notebook
+
+
+transformers_installed = importlib.util.find_spec("transformers") is not None
+
+
+@pytest.fixture
+def nb_imports():
+    nb_dir = Path(__file__).resolve().parents[1]
+    mod = import_definitions_from_notebook(nb_dir, "standalone-llama32.ipynb")
+    return mod
+
+
+@pytest.fixture
+def dummy_input():
+    torch.manual_seed(123)
+    return torch.randint(0, 100, (1, 8))  # batch size 1, seq length 8
+
+
+@pytest.fixture
+def dummy_cfg_base():
+    return {
+        "vocab_size": 100,
+        "emb_dim": 32,       # hidden_size
+        "hidden_dim": 64,    # intermediate_size (FFN)
+        "n_layers": 2,
+        "n_heads": 4,
+        "head_dim": 8,
+        "n_kv_groups": 1,
+        "dtype": torch.float32,
+        "rope_base": 500_000.0,
+        "rope_freq": {
+            "factor": 8.0,
+            "low_freq_factor": 1.0,
+            "high_freq_factor": 4.0,
+            "original_context_length": 8192,
+        },
+        "context_length": 64,
+    }
+
+
+@torch.inference_mode()
+def test_dummy_llama3_forward(dummy_cfg_base, dummy_input, nb_imports):
+    torch.manual_seed(123)
+    model = nb_imports.Llama3Model(dummy_cfg_base)
+    out = model(dummy_input)
+    assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
+
+
+@torch.inference_mode()
+@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
+def test_llama3_base_equivalence_with_transformers(nb_imports):
+    from transformers.models.llama import LlamaConfig, LlamaForCausalLM
+    cfg = {
+        "vocab_size": 257,
+        "context_length": 8192,
+        "emb_dim": 32,
+        "n_heads": 4,
+        "n_layers": 2,
+        "hidden_dim": 64,
+        "n_kv_groups": 2,
+        "rope_base": 500_000.0,
+        "rope_freq": {
+            "factor": 32.0,
+            "low_freq_factor": 1.0,
+            "high_freq_factor": 4.0,
+            "original_context_length": 8192,
+        },
+        "dtype": torch.float32,
+    }
+
+    ours = nb_imports.Llama3Model(cfg)
+
+    hf_cfg = LlamaConfig(
+        vocab_size=cfg["vocab_size"],
+        hidden_size=cfg["emb_dim"],
+        num_attention_heads=cfg["n_heads"],
+        num_key_value_heads=cfg["n_kv_groups"],
+        num_hidden_layers=cfg["n_layers"],
+        intermediate_size=cfg["hidden_dim"],
+        max_position_embeddings=cfg["context_length"],
+        rms_norm_eps=1e-5,
+        attention_bias=False,
+        rope_theta=cfg["rope_base"],
+        tie_word_embeddings=False,
+        attn_implementation="eager",
+        torch_dtype=torch.float32,
+        rope_scaling={
+            "type": "llama3",
+            "factor": cfg["rope_freq"]["factor"],
+            "low_freq_factor": cfg["rope_freq"]["low_freq_factor"],
+            "high_freq_factor": cfg["rope_freq"]["high_freq_factor"],
+            "original_max_position_embeddings": cfg["rope_freq"]["original_context_length"],
+        },
+    )
+    theirs = LlamaForCausalLM(hf_cfg)
+
+    hf_state = theirs.state_dict()
+    nb_imports.load_weights_into_llama(ours, {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}, hf_state)
+
+    x = torch.randint(0, cfg["vocab_size"], (2, 8), dtype=torch.long)
+    ours_logits = ours(x)
+    theirs_logits = theirs(x).logits.to(ours_logits.dtype)
+
+    torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
ch05/07_gpt_to_llama/tests/tests.py → ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py

File renamed without changes.

ch05/11_qwen3/tests/test_qwen3.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+import importlib
+from pathlib import Path
+
+import pytest
+import torch
+
+from llms_from_scratch.utils import import_definitions_from_notebook
+
+
+transformers_installed = importlib.util.find_spec("transformers") is not None
+
+
+@pytest.fixture
+def nb_imports():
+    nb_dir = Path(__file__).resolve().parents[1]
+    mod = import_definitions_from_notebook(nb_dir, "standalone-qwen3.ipynb")
+    return mod
+
+
+@pytest.fixture
+def dummy_input():
+    torch.manual_seed(123)
+    return torch.randint(0, 100, (1, 8))  # batch size 1, seq length 8
+
+
+@pytest.fixture
+def dummy_cfg_base():
+    return {
+        "vocab_size": 100,
+        "emb_dim": 32,
+        "hidden_dim": 64,
+        "n_layers": 2,
+        "n_heads": 4,
+        "head_dim": 8,
+        "n_kv_groups": 1,
+        "qk_norm": False,
+        "dtype": torch.float32,
+        "rope_base": 10000,
+        "context_length": 64,
+        "num_experts": 0,
+    }
+
+
+@pytest.fixture
+def dummy_cfg_moe(dummy_cfg_base):
+    cfg = dummy_cfg_base.copy()
+    cfg.update({
+        "num_experts": 4,
+        "num_experts_per_tok": 2,
+        "moe_intermediate_size": 64,
+    })
+    return cfg
+
+
+@torch.inference_mode()
+def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input, nb_imports):
+    torch.manual_seed(123)
+    model = nb_imports.Qwen3Model(dummy_cfg_base)
+    out = model(dummy_input)
+    assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), \
+        f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
+
+
+@torch.inference_mode()
+@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
+def test_qwen3_base_equivalence_with_transformers(nb_imports):
+    from transformers import Qwen3Config, Qwen3ForCausalLM
+
+    # Tiny config so the test is fast
+    cfg = {
+        "vocab_size": 257,
+        "context_length": 8,
+        "emb_dim": 32,
+        "n_heads": 4,
+        "n_layers": 2,
+        "hidden_dim": 64,
+        "head_dim": 8,
+        "qk_norm": True,
+        "n_kv_groups": 2,
+        "rope_base": 1_000_000.0,
+        "rope_local_base": 10_000.0,
+        "sliding_window": 4,
+        "layer_types": ["full_attention", "full_attention"],
+        "dtype": torch.float32,
+        "query_pre_attn_scalar": 256,
+    }
+    model = nb_imports.Qwen3Model(cfg)
+
+    hf_cfg = Qwen3Config(
+        vocab_size=cfg["vocab_size"],
+        max_position_embeddings=cfg["context_length"],
+        hidden_size=cfg["emb_dim"],
+        num_attention_heads=cfg["n_heads"],
+        num_hidden_layers=cfg["n_layers"],
+        intermediate_size=cfg["hidden_dim"],
+        head_dim=cfg["head_dim"],
+        num_key_value_heads=cfg["n_kv_groups"],
+        rope_theta=cfg["rope_base"],
+        rope_local_base_freq=cfg["rope_local_base"],
+        layer_types=cfg["layer_types"],
+        sliding_window=cfg["sliding_window"],
+        tie_word_embeddings=False,
+        attn_implementation="eager",
+        torch_dtype=torch.float32,
+        query_pre_attn_scalar=cfg["query_pre_attn_scalar"],
+        rope_scaling={"rope_type": "default"},
+    )
+    hf_model = Qwen3ForCausalLM(hf_cfg)
+
+    hf_state = hf_model.state_dict()
+    param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
+    nb_imports.load_weights_into_qwen(model, param_config, hf_state)
+
+    x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
+    ours_logits = model(x)
+    theirs_logits = hf_model(x).logits
+    torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
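The new Qwen3 test module also defines a dummy_cfg_moe fixture that none of the tests in this diff use yet. A forward-pass test exercising it might look like the following sketch, which assumes the notebook's Qwen3Model switches to its mixture-of-experts feed-forward path whenever "num_experts" > 0:

```python
# Sketch only (not part of this commit): exercises the dummy_cfg_moe fixture.
# Assumes Qwen3Model from standalone-qwen3.ipynb supports num_experts > 0.
@torch.inference_mode()
def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input, nb_imports):
    torch.manual_seed(123)
    model = nb_imports.Qwen3Model(dummy_cfg_moe)
    out = model(dummy_input)
    assert out.shape == (1, dummy_input.size(1), dummy_cfg_moe["vocab_size"]), \
        f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
```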

ch05/12_gemma3/tests/test_gemma3.py

Lines changed: 13 additions & 78 deletions
@@ -4,77 +4,21 @@
 # Code: https://github.com/rasbt/LLMs-from-scratch
 
 import importlib
-import types
-import re
 from pathlib import Path
 
-import nbformat
 import pytest
 import torch
 
+from llms_from_scratch.utils import import_definitions_from_notebook
+
+
 transformers_installed = importlib.util.find_spec("transformers") is not None
 
 
-def _extract_defs_and_classes_from_code(src):
-    lines = src.splitlines()
-    kept = []
-    i = 0
-    while i < len(lines):
-        line = lines[i]
-        stripped = line.lstrip()
-        # Keep decorators attached to the next def/class
-        if stripped.startswith("@"):
-            # Look ahead: if the next non-empty line starts with def/class, keep decorator
-            j = i + 1
-            while j < len(lines) and not lines[j].strip():
-                j += 1
-            if j < len(lines) and lines[j].lstrip().startswith(("def ", "class ")):
-                kept.append(line)
-            i += 1
-            continue
-        if stripped.startswith("def ") or stripped.startswith("class "):
-            kept.append(line)
-            # capture until we leave the indentation block
-            base_indent = len(line) - len(stripped)
-            i += 1
-            while i < len(lines):
-                nxt = lines[i]
-                if nxt.strip() == "":
-                    kept.append(nxt)
-                    i += 1
-                    continue
-                indent = len(nxt) - len(nxt.lstrip())
-                if indent <= base_indent and not nxt.lstrip().startswith(("#", "@")):
-                    break
-                kept.append(nxt)
-                i += 1
-            continue
-        i += 1
-    code = "\n".join(kept)
-    code = re.sub(r"def\s+load_weights_into_gemma\s*\(\s*Gemma3Model\s*,",
-                  "def load_weights_into_gemma(model,",
-                  code)
-    return code
-
-
-def import_definitions_from_notebook(nb_dir_or_path, notebook_name):
-    nb_path = Path(nb_dir_or_path)
-    if nb_path.is_dir():
-        nb_file = nb_path / notebook_name
-    else:
-        nb_file = nb_path
-    if not nb_file.exists():
-        raise FileNotFoundError(f"Notebook not found: {nb_file}")
-
-    nb = nbformat.read(nb_file, as_version=4)
-    pieces = ["import torch", "import torch.nn as nn"]
-    for cell in nb.cells:
-        if cell.cell_type == "code":
-            pieces.append(_extract_defs_and_classes_from_code(cell.source))
-    src = "\n\n".join(pieces)
-
-    mod = types.ModuleType("gemma3_defs")
-    exec(src, mod.__dict__)
+@pytest.fixture
+def nb_imports():
+    nb_dir = Path(__file__).resolve().parents[1]
+    mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3.ipynb")
     return mod
 
 
@@ -106,25 +50,16 @@ def dummy_cfg_base():
 
 
 @torch.inference_mode()
-def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input):
-    nb_dir = Path(__file__).resolve().parents[1]
-    mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3.ipynb")
-    Gemma3Model = mod.Gemma3Model
-
+def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, nb_imports):
     torch.manual_seed(123)
-    model = Gemma3Model(dummy_cfg_base)
+    model = nb_imports.Gemma3Model(dummy_cfg_base)
     out = model(dummy_input)
-    assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
+    assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
 
 
 @torch.inference_mode()
 @pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
-def test_gemma3_base_equivalence_with_transformers():
-    nb_dir = Path(__file__).resolve().parents[1]
-    mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3.ipynb")
-    Gemma3Model = mod.Gemma3Model
-    load_weights_into_gemma = mod.load_weights_into_gemma
-
+def test_gemma3_base_equivalence_with_transformers(nb_imports):
     from transformers import Gemma3TextConfig, Gemma3ForCausalLM
 
     # Tiny config so the test is fast
@@ -145,7 +80,7 @@ def test_gemma3_base_equivalence_with_transformers():
         "dtype": torch.float32,
         "query_pre_attn_scalar": 256,
    }
-    model = Gemma3Model(cfg)
+    model = nb_imports.Gemma3Model(cfg)
 
     hf_cfg = Gemma3TextConfig(
         vocab_size=cfg["vocab_size"],
@@ -170,7 +105,7 @@ def test_gemma3_base_equivalence_with_transformers():
 
     hf_state = hf_model.state_dict()
     param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
-    load_weights_into_gemma(model, param_config, hf_state)
+    nb_imports.load_weights_into_gemma(model, param_config, hf_state)
 
     x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
     ours_logits = model(x)
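All three notebook test modules now import import_definitions_from_notebook from llms_from_scratch.utils instead of carrying a private copy like the one deleted above. For context, here is a condensed sketch of what such a helper does, reconstructed from the removed implementation; the packaged utility may differ in details (for example, it may also rewrite the load_weights_into_gemma signature, as the old regex did):

```python
# Sketch of a notebook-import helper in the spirit of the removed code above.
# Not the actual llms_from_scratch.utils implementation.
import ast
import types
from pathlib import Path

import nbformat


def _defs_only(src):
    # Keep imports plus top-level def/class statements from a code cell so that
    # executing the extracted source does not trigger downloads or training runs.
    tree = ast.parse(src)
    kept = [
        node for node in tree.body
        if isinstance(node, (ast.Import, ast.ImportFrom, ast.FunctionDef,
                             ast.AsyncFunctionDef, ast.ClassDef))
    ]
    return ast.unparse(ast.Module(body=kept, type_ignores=[]))


def import_definitions_from_notebook(nb_dir_or_path, notebook_name):
    nb_path = Path(nb_dir_or_path)
    nb_file = nb_path / notebook_name if nb_path.is_dir() else nb_path
    if not nb_file.exists():
        raise FileNotFoundError(f"Notebook not found: {nb_file}")

    nb = nbformat.read(nb_file, as_version=4)
    pieces = ["import torch", "import torch.nn as nn"]
    for cell in nb.cells:
        if cell.cell_type != "code":
            continue
        try:
            pieces.append(_defs_only(cell.source))
        except SyntaxError:
            continue  # e.g., cells containing notebook magics

    mod = types.ModuleType(notebook_name.replace(".ipynb", "").replace("-", "_"))
    exec("\n\n".join(pieces), mod.__dict__)
    return mod
```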

0 commit comments
