
Commit b14325e

Qwen3 and Llama3 equivalency tests with HF transformers (#768)
* Qwen3 and Llama3 equivalency tests with HF transformers
* update
1 parent 5febcf8 commit b14325e

6 files changed (+199, -8 lines)


.github/workflows/basic-tests-pixi.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -28,7 +28,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-latest, macos-latest, windows-latest]
+        os: [ubuntu-latest, windows-latest]
 
     steps:
       - uses: actions/checkout@v4
```

.gitignore

Lines changed: 5 additions & 1 deletion
```diff
@@ -1,4 +1,3 @@
-
 # Configs and keys
 ch05/07_gpt_to_llama/config.json
 ch07/02_dataset-utilities/config.json
@@ -78,6 +77,11 @@ ch07/01_main-chapter-code/gpt2-medium355M-sft-standalone.pth
 ch07/01_main-chapter-code/Smalltestmodel-sft-standalone.pth
 ch07/01_main-chapter-code/gpt2/
 
+Qwen3-0.6B-Base/
+Qwen3-0.6B/
+tokenizer-base.json
+tokenizer.json
+
 # Datasets
 the-verdict.txt
```

pkg/llms_from_scratch/README.md

Lines changed: 3 additions & 1 deletion
````diff
@@ -132,7 +132,8 @@ For more information about KV caching, please see the [KV cache README](../../ch
 
 ```python
 from llms_from_scratch.llama3 import (
-    Llama3Model,
+    load_weights_into_llama,
+    Llama3Model,
     Llama3ModelFast,
     Llama3Tokenizer,
     ChatFormat,
@@ -154,6 +155,7 @@ For more information about KV caching, please see the [KV cache README](../../ch
 
 ```python
 from llms_from_scratch.qwen3 import (
+    load_weights_into_qwen
     Qwen3Model,
     Qwen3Tokenizer,
 )
````
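For orientation, here is a minimal sketch of how the newly exported `load_weights_into_qwen` might be used with `transformers` to copy pretrained weights into the from-scratch model. The Hub id `Qwen/Qwen3-0.6B-Base` and the reuse of `QWEN_CONFIG_06_B` as the loader's `param_config` are assumptions; the commit itself only adds the loaders and the tiny-model tests below.

```python
# Hypothetical usage sketch (not part of this commit).
from transformers import AutoModelForCausalLM

from llms_from_scratch.qwen3 import (
    load_weights_into_qwen,
    Qwen3Model,
    QWEN_CONFIG_06_B,
)

# Assumed checkpoint id; any Qwen3 checkpoint whose shapes match QWEN_CONFIG_06_B should work.
hf_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base")

model = Qwen3Model(QWEN_CONFIG_06_B)
# The tests below pass a small dict with "n_layers"/"hidden_dim" as param_config;
# the full package config is assumed to carry those keys as well.
load_weights_into_qwen(model, QWEN_CONFIG_06_B, hf_model.state_dict())
model.eval()
```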

pkg/llms_from_scratch/llama3.py

Lines changed: 74 additions & 0 deletions
```diff
@@ -497,3 +497,77 @@ def forward(self, in_idx):
         x = self.final_norm(x)
         logits = self.out_head(x.to(self.cfg["dtype"]))
         return logits
+
+
+def assign(left, right, tensor_name="unknown"):
+    if left.shape != right.shape:
+        raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
+
+    if isinstance(right, torch.Tensor):
+        return torch.nn.Parameter(right.clone().detach())
+    else:
+        return torch.nn.Parameter(torch.tensor(right))
+
+
+def load_weights_into_llama(model, param_config, params):
+    model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
+
+    for l in range(param_config["n_layers"]):
+
+        # Load attention weights
+        model.trf_blocks[l].att.W_query.weight = assign(
+            model.trf_blocks[l].att.W_query.weight,
+            params[f"model.layers.{l}.self_attn.q_proj.weight"],
+            f"model.layers.{l}.self_attn.q_proj.weight"
+        )
+        model.trf_blocks[l].att.W_key.weight = assign(
+            model.trf_blocks[l].att.W_key.weight,
+            params[f"model.layers.{l}.self_attn.k_proj.weight"],
+            f"model.layers.{l}.self_attn.k_proj.weight"
+        )
+        model.trf_blocks[l].att.W_value.weight = assign(
+            model.trf_blocks[l].att.W_value.weight,
+            params[f"model.layers.{l}.self_attn.v_proj.weight"],
+            f"model.layers.{l}.self_attn.v_proj.weight"
+        )
+        model.trf_blocks[l].att.out_proj.weight = assign(
+            model.trf_blocks[l].att.out_proj.weight,
+            params[f"model.layers.{l}.self_attn.o_proj.weight"],
+            f"model.layers.{l}.self_attn.o_proj.weight"
+        )
+        model.trf_blocks[l].norm1.weight = assign(
+            model.trf_blocks[l].norm1.weight,
+            params[f"model.layers.{l}.input_layernorm.weight"],
+            f"model.layers.{l}.input_layernorm.weight"
+        )
+
+        # Load FeedForward weights
+        model.trf_blocks[l].ff.fc1.weight = assign(
+            model.trf_blocks[l].ff.fc1.weight,
+            params[f"model.layers.{l}.mlp.gate_proj.weight"],
+            f"model.layers.{l}.mlp.gate_proj.weight"
+        )
+        model.trf_blocks[l].ff.fc2.weight = assign(
+            model.trf_blocks[l].ff.fc2.weight,
+            params[f"model.layers.{l}.mlp.up_proj.weight"],
+            f"model.layers.{l}.mlp.up_proj.weight"
+        )
+        model.trf_blocks[l].ff.fc3.weight = assign(
+            model.trf_blocks[l].ff.fc3.weight,
+            params[f"model.layers.{l}.mlp.down_proj.weight"],
+            f"model.layers.{l}.mlp.down_proj.weight"
+        )
+        model.trf_blocks[l].norm2.weight = assign(
+            model.trf_blocks[l].norm2.weight,
+            params[f"model.layers.{l}.post_attention_layernorm.weight"],
+            f"model.layers.{l}.post_attention_layernorm.weight"
+        )
+
+    # Load output layer weights
+    model.final_norm.weight = assign(model.final_norm.weight, params["model.norm.weight"], "model.norm.weight")
+
+    if "lm_head.weight" in params.keys():
+        model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight")
+    else:
+        model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
+        print("Model uses weight tying.")
```
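The Llama 3 loader mirrors the Qwen3 one: build the from-scratch model, take a Hugging Face state dict, and hand both to `load_weights_into_llama`. As a rough usage sketch (again, not part of this commit), the gated checkpoint id `meta-llama/Llama-3.2-1B` and the reuse of `LLAMA32_CONFIG_1B` as `param_config` are assumptions; the equivalence test below does the same transfer with a tiny randomly initialized `LlamaForCausalLM` instead.

```python
# Hypothetical usage sketch; the commit only adds the loader plus the tests below.
from transformers import AutoModelForCausalLM

from llms_from_scratch.llama3 import (
    load_weights_into_llama,
    LLAMA32_CONFIG_1B,
    Llama3Model,
)

# Assumed (gated) checkpoint id matching LLAMA32_CONFIG_1B.
hf_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

model = Llama3Model(LLAMA32_CONFIG_1B)
# assign() clones and detaches each tensor, so the loaded weights are
# independent copies of the Hugging Face parameters.
load_weights_into_llama(model, LLAMA32_CONFIG_1B, hf_model.state_dict())
model.eval()
```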

pkg/llms_from_scratch/tests/test_llama3.py

Lines changed: 61 additions & 2 deletions
```diff
@@ -5,11 +5,12 @@
 
 from llms_from_scratch.ch04 import generate_text_simple
 from llms_from_scratch.llama3 import (
-    compute_rope_params,
     apply_rope,
-    LLAMA32_CONFIG_1B,
+    compute_rope_params,
     GroupedQueryAttention,
     GroupedQueryAttentionFast,
+    load_weights_into_llama,
+    LLAMA32_CONFIG_1B,
     Llama3Model,
 )
 from llms_from_scratch.kv_cache.llama3 import Llama3Model as Llama3ModelKV
@@ -246,3 +247,61 @@ def test_rmsnorm_equivalence():
     out2 = lit_norm(x)
 
     torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
+
+
+@torch.inference_mode()
+@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
+def test_llama3_base_equivalence_with_transformers():
+    from transformers.models.llama import LlamaConfig, LlamaForCausalLM
+    cfg = {
+        "vocab_size": 257,
+        "context_length": 8192,
+        "emb_dim": 32,
+        "n_heads": 4,
+        "n_layers": 2,
+        "hidden_dim": 64,
+        "n_kv_groups": 2,
+        "rope_base": 500_000.0,
+        "rope_freq": {
+            "factor": 32.0,
+            "low_freq_factor": 1.0,
+            "high_freq_factor": 4.0,
+            "original_context_length": 8192,
+        },
+        "dtype": torch.float32,
+    }
+
+    ours = Llama3Model(cfg)
+
+    hf_cfg = LlamaConfig(
+        vocab_size=cfg["vocab_size"],
+        hidden_size=cfg["emb_dim"],
+        num_attention_heads=cfg["n_heads"],
+        num_key_value_heads=cfg["n_kv_groups"],
+        num_hidden_layers=cfg["n_layers"],
+        intermediate_size=cfg["hidden_dim"],
+        max_position_embeddings=cfg["context_length"],
+        rms_norm_eps=1e-5,
+        attention_bias=False,
+        rope_theta=cfg["rope_base"],
+        tie_word_embeddings=False,
+        attn_implementation="eager",
+        torch_dtype=torch.float32,
+        rope_scaling={
+            "type": "llama3",
+            "factor": cfg["rope_freq"]["factor"],
+            "low_freq_factor": cfg["rope_freq"]["low_freq_factor"],
+            "high_freq_factor": cfg["rope_freq"]["high_freq_factor"],
+            "original_max_position_embeddings": cfg["rope_freq"]["original_context_length"],
+        },
+    )
+    theirs = LlamaForCausalLM(hf_cfg)
+
+    hf_state = theirs.state_dict()
+    load_weights_into_llama(ours, {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}, hf_state)
+
+    x = torch.randint(0, cfg["vocab_size"], (2, 8), dtype=torch.long)
+    ours_logits = ours(x)
+    theirs_logits = theirs(x).logits.to(ours_logits.dtype)
+
+    torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
```

pkg/llms_from_scratch/tests/test_qwen3.py

Lines changed: 55 additions & 3 deletions
```diff
@@ -5,12 +5,13 @@
 
 from llms_from_scratch.ch04 import generate_text_simple
 from llms_from_scratch.qwen3 import (
-    compute_rope_params,
     apply_rope,
+    compute_rope_params,
+    load_weights_into_qwen,
     QWEN_CONFIG_06_B,
-    RMSNorm,
     Qwen3Model,
-    Qwen3Tokenizer
+    Qwen3Tokenizer,
+    RMSNorm,
 )
 from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
 from llms_from_scratch.kv_cache.utils import KVCache
@@ -87,6 +88,7 @@ def dummy_cfg_moe(dummy_cfg_base):
     return cfg
 
 
+@torch.inference_mode()
 def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input):
     torch.manual_seed(123)
     model = Qwen3Model(dummy_cfg_base)
@@ -95,6 +97,7 @@ def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input):
         f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
 
 
+@torch.inference_mode()
 def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
     torch.manual_seed(123)
     model = Qwen3Model(dummy_cfg_moe)
@@ -105,6 +108,7 @@ def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
         "Expected MoEFeedForward in at least one transformer block"
 
 
+@torch.inference_mode()
 @pytest.mark.parametrize("cfg_name", ["dummy_cfg_base", "dummy_cfg_moe"])
 def test_qwen3_kvcache_equivalence(cfg_name, request):
     cfg = request.getfixturevalue(cfg_name)
@@ -438,3 +442,51 @@ def test_tokenizer_equivalence():
     expected_pad_token = "<|endoftext|>"
     assert tokenizer.decode([tokenizer.eos_token_id]) == expected_eos_token
     assert tokenizer.decode([tokenizer.pad_token_id]) == expected_pad_token
+
+
+@torch.inference_mode()
+@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
+def test_qwen3_base_equivalence_with_transformers():
+
+    from transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM
+
+    # Tiny config so the test is fast
+    cfg = {
+        "vocab_size": 257,
+        "context_length": 8,
+        "emb_dim": 32,
+        "n_heads": 4,
+        "n_layers": 2,
+        "hidden_dim": 64,
+        "head_dim": 8,
+        "qk_norm": True,
+        "n_kv_groups": 2,
+        "rope_base": 1_000_000.0,
+        "dtype": torch.float32,
+    }
+    model = Qwen3Model(cfg)
+
+    hf_cfg = Qwen3Config(
+        vocab_size=cfg["vocab_size"],
+        max_position_embeddings=cfg["context_length"],
+        hidden_size=cfg["emb_dim"],
+        num_attention_heads=cfg["n_heads"],
+        num_hidden_layers=cfg["n_layers"],
+        intermediate_size=cfg["hidden_dim"],
+        head_dim=cfg["head_dim"],
+        num_key_value_heads=cfg["n_kv_groups"],
+        rope_theta=cfg["rope_base"],
+        tie_word_embeddings=False,
+        attn_implementation="eager",
+        torch_dtype=torch.float32,
+    )
+    hf_model = Qwen3ForCausalLM(hf_cfg)
+
+    hf_state = hf_model.state_dict()
+    param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
+    load_weights_into_qwen(model, param_config, hf_state)
+
+    x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
+    ours_logits = model(x)
+    theirs_logits = hf_model(x).logits
+    torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
```
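The new equivalence tests can be run on their own; a sketch assuming the package, `pytest`, and `transformers` are installed and the working directory is the repository root (otherwise the `skipif` markers simply skip them):

```python
# Programmatic equivalent of:
#   pytest pkg/llms_from_scratch/tests -k "equivalence_with_transformers" -v
import pytest

pytest.main([
    "pkg/llms_from_scratch/tests/test_llama3.py",
    "pkg/llms_from_scratch/tests/test_qwen3.py",
    "-k", "equivalence_with_transformers",
    "-v",
])
```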
