Skip to content

Commit 80d4732

Browse files
authored
add HF equivalency tests for standalone nbs (#774)
* add HF equivalency tests for standalone nbs * update * update * update * update
1 parent a6b883c commit 80d4732

15 files changed

+389
-91
lines changed

.github/workflows/basic-tests-linux-uv.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@ jobs:
5151
pytest --ruff ch04/01_main-chapter-code/tests.py
5252
pytest --ruff ch04/03_kv-cache/tests.py
5353
pytest --ruff ch05/01_main-chapter-code/tests.py
54-
pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
55-
pytest --ruff ch05/12_gemma3/tests/test_gemma3.py
54+
pytest --ruff ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py
55+
pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32_nb.py
56+
pytest --ruff ch05/11_qwen3/tests/test_qwen3_nb.py
57+
pytest --ruff ch05/12_gemma3/tests/test_gemma3_nb.py
5658
pytest --ruff ch06/01_main-chapter-code/tests.py
5759
5860
- name: Validate Selected Jupyter Notebooks (uv)

.github/workflows/basic-tests-macos-uv.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ jobs:
5050
pytest --ruff setup/02_installing-python-libraries/tests.py
5151
pytest --ruff ch04/01_main-chapter-code/tests.py
5252
pytest --ruff ch05/01_main-chapter-code/tests.py
53-
pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
54-
pytest --ruff ch05/12_gemma3/tests/test_gemma3.py
53+
pytest --ruff ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py
54+
pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32_nb.py
55+
pytest --ruff ch05/11_qwen3/tests/test_qwen3_nb.py
56+
pytest --ruff ch05/12_gemma3/tests/test_gemma3_nb.py
5557
pytest --ruff ch06/01_main-chapter-code/tests.py
5658
5759
- name: Validate Selected Jupyter Notebooks (uv)

.github/workflows/basic-tests-old-pytorch.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ jobs:
4747
pytest --ruff setup/02_installing-python-libraries/tests.py
4848
pytest --ruff ch04/01_main-chapter-code/tests.py
4949
pytest --ruff ch05/01_main-chapter-code/tests.py
50-
pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
5150
pytest --ruff ch06/01_main-chapter-code/tests.py
5251
5352
- name: Validate Selected Jupyter Notebooks

.github/workflows/basic-tests-pip.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ jobs:
4141
source .venv/bin/activate
4242
pip install --upgrade pip
4343
pip install -r requirements.txt
44-
pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
4544
pip install pytest pytest-ruff nbval
4645
4746
- name: Test Selected Python Scripts
@@ -50,7 +49,6 @@ jobs:
5049
pytest --ruff setup/02_installing-python-libraries/tests.py
5150
pytest --ruff ch04/01_main-chapter-code/tests.py
5251
pytest --ruff ch05/01_main-chapter-code/tests.py
53-
pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
5452
pytest --ruff ch06/01_main-chapter-code/tests.py
5553
5654
- name: Validate Selected Jupyter Notebooks

.github/workflows/basic-tests-pixi.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ jobs:
5050
pytest --ruff setup/02_installing-python-libraries/tests.py
5151
pytest --ruff ch04/01_main-chapter-code/tests.py
5252
pytest --ruff ch05/01_main-chapter-code/tests.py
53-
pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
5453
pytest --ruff ch06/01_main-chapter-code/tests.py
5554
5655
- name: Validate Selected Jupyter Notebooks

.github/workflows/basic-tests-pytorch-rc.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ jobs:
3333
run: |
3434
curl -LsSf https://astral.sh/uv/install.sh | sh
3535
uv sync --dev --python=3.10 # tests for backwards compatibility
36-
uv pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
3736
uv add pytest-ruff nbval
3837
uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
3938
@@ -43,7 +42,6 @@ jobs:
4342
pytest --ruff setup/02_installing-python-libraries/tests.py
4443
pytest --ruff ch04/01_main-chapter-code/tests.py
4544
pytest --ruff ch05/01_main-chapter-code/tests.py
46-
pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
4745
pytest --ruff ch06/01_main-chapter-code/tests.py
4846
4947
- name: Validate Selected Jupyter Notebooks

.github/workflows/basic-tests-windows-uv-pip.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ jobs:
4343
pip install tensorflow-io-gcs-filesystem==0.31.0 # Explicit for Windows
4444
pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
4545
pip install pytest-ruff nbval
46+
pip install -e .
4647
4748
- name: Run Python Tests
4849
shell: bash
@@ -51,7 +52,9 @@ jobs:
5152
pytest --ruff setup/02_installing-python-libraries/tests.py
5253
pytest --ruff ch04/01_main-chapter-code/tests.py
5354
pytest --ruff ch05/01_main-chapter-code/tests.py
54-
pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
55+
pytest --ruff ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py
56+
pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32_nb.py
57+
pytest --ruff ch05/11_qwen3/tests/test_qwen3_nb.py
5558
pytest --ruff ch06/01_main-chapter-code/tests.py
5659
5760
- name: Run Jupyter Notebook Tests

.github/workflows/basic-tests-windows-uv.yml.disabled

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ jobs:
5151
pytest --ruff setup/02_installing-python-libraries/tests.py
5252
pytest --ruff ch04/01_main-chapter-code/tests.py
5353
pytest --ruff ch05/01_main-chapter-code/tests.py
54-
pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
5554
pytest --ruff ch06/01_main-chapter-code/tests.py
5655

5756
- name: Run Jupyter Notebook Tests
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
2+
# Source for "Build a Large Language Model From Scratch"
3+
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
4+
# Code: https://github.com/rasbt/LLMs-from-scratch
5+
6+
import importlib
7+
from pathlib import Path
8+
9+
import pytest
10+
import torch
11+
12+
from llms_from_scratch.utils import import_definitions_from_notebook
13+
14+
15+
transformers_installed = importlib.util.find_spec("transformers") is not None
16+
17+
18+
@pytest.fixture
19+
def nb_imports():
20+
nb_dir = Path(__file__).resolve().parents[1]
21+
mod = import_definitions_from_notebook(nb_dir, "standalone-llama32.ipynb")
22+
return mod
23+
24+
25+
@pytest.fixture
26+
def dummy_input():
27+
torch.manual_seed(123)
28+
return torch.randint(0, 100, (1, 8)) # batch size 1, seq length 8
29+
30+
31+
@pytest.fixture
32+
def dummy_cfg_base():
33+
return {
34+
"vocab_size": 100,
35+
"emb_dim": 32, # hidden_size
36+
"hidden_dim": 64, # intermediate_size (FFN)
37+
"n_layers": 2,
38+
"n_heads": 4,
39+
"head_dim": 8,
40+
"n_kv_groups": 1,
41+
"dtype": torch.float32,
42+
"rope_base": 500_000.0,
43+
"rope_freq": {
44+
"factor": 8.0,
45+
"low_freq_factor": 1.0,
46+
"high_freq_factor": 4.0,
47+
"original_context_length": 8192,
48+
},
49+
"context_length": 64,
50+
}
51+
52+
53+
@torch.inference_mode()
54+
def test_dummy_llama3_forward(dummy_cfg_base, dummy_input, nb_imports):
55+
torch.manual_seed(123)
56+
model = nb_imports.Llama3Model(dummy_cfg_base)
57+
out = model(dummy_input)
58+
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
59+
60+
61+
@torch.inference_mode()
62+
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
63+
def test_llama3_base_equivalence_with_transformers(nb_imports):
64+
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
65+
cfg = {
66+
"vocab_size": 257,
67+
"context_length": 8192,
68+
"emb_dim": 32,
69+
"n_heads": 4,
70+
"n_layers": 2,
71+
"hidden_dim": 64,
72+
"n_kv_groups": 2,
73+
"rope_base": 500_000.0,
74+
"rope_freq": {
75+
"factor": 32.0,
76+
"low_freq_factor": 1.0,
77+
"high_freq_factor": 4.0,
78+
"original_context_length": 8192,
79+
},
80+
"dtype": torch.float32,
81+
}
82+
83+
ours = nb_imports.Llama3Model(cfg)
84+
85+
hf_cfg = LlamaConfig(
86+
vocab_size=cfg["vocab_size"],
87+
hidden_size=cfg["emb_dim"],
88+
num_attention_heads=cfg["n_heads"],
89+
num_key_value_heads=cfg["n_kv_groups"],
90+
num_hidden_layers=cfg["n_layers"],
91+
intermediate_size=cfg["hidden_dim"],
92+
max_position_embeddings=cfg["context_length"],
93+
rms_norm_eps=1e-5,
94+
attention_bias=False,
95+
rope_theta=cfg["rope_base"],
96+
tie_word_embeddings=False,
97+
attn_implementation="eager",
98+
torch_dtype=torch.float32,
99+
rope_scaling={
100+
"type": "llama3",
101+
"factor": cfg["rope_freq"]["factor"],
102+
"low_freq_factor": cfg["rope_freq"]["low_freq_factor"],
103+
"high_freq_factor": cfg["rope_freq"]["high_freq_factor"],
104+
"original_max_position_embeddings": cfg["rope_freq"]["original_context_length"],
105+
},
106+
)
107+
theirs = LlamaForCausalLM(hf_cfg)
108+
109+
hf_state = theirs.state_dict()
110+
nb_imports.load_weights_into_llama(ours, {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}, hf_state)
111+
112+
x = torch.randint(0, cfg["vocab_size"], (2, 8), dtype=torch.long)
113+
ours_logits = ours(x)
114+
theirs_logits = theirs(x).logits.to(ours_logits.dtype)
115+
116+
torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
File renamed without changes.

0 commit comments

Comments
 (0)