@@ -20,7 +20,12 @@
 from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched
 from llms_from_scratch.kv_cache_batched.generate import generate_text_simple as generate_text_simple_batched
 
+from llms_from_scratch.utils import download_file
+
 import importlib
+import os
+import shutil
+import tempfile
 import platform
 import pytest
 import torch
@@ -465,13 +470,6 @@ def test_chat_wrap_and_equivalence(add_gen, add_think):
         add_generation_prompt=add_gen,
         enable_thinking=add_think,
     )
-    ours = qt.encode(prompt)
-    ref = hf_tok.apply_chat_template(
-        messages,
-        tokenize=True,
-        add_generation_prompt=add_gen,
-        enable_thinking=add_think,
-    )
 
     if add_gen and not add_think:
         pass  # skip edge case as this is not something we use in practice
@@ -534,6 +532,77 @@ def test_multiturn_equivalence(repo_id, tok_file, add_gen, add_think):
     assert ours_dec == ref_dec
 
 
+@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
+def test_tokenizer_equivalence():
+    from transformers import AutoTokenizer
+
+    prompt = "Give me a short introduction to large language models."
+    messages = [
+        {"role": "user", "content": prompt},
+    ]
+
+    for apply_chat_template in (True, False):
+        for s in ("-Base", ""):
+            repo_id = f"Qwen/Qwen3-0.6B{s}"
+            tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
+            tokenizer_url = f"https://huggingface.co/Qwen/Qwen3-0.6B{s}/resolve/main/tokenizer.json"
+            download_file(tokenizer_url, out_dir=".")
+
+            old_name = "tokenizer.json"
+
+            if not s:
+                new_name = "tokenizer-reasoning.json"
+            else:
+                new_name = "tokenizer-base.json"
+
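+            # Give each variant its own file name; if a direct move fails
+            # (e.g., across filesystems), fall back to copy + replace.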
+            try:
+                shutil.move(old_name, new_name)
+            except Exception:
+                with tempfile.NamedTemporaryFile(delete=False, dir=".") as tmp_file:
+                    shutil.copyfile(old_name, tmp_file.name)
+                os.replace(tmp_file.name, new_name)
+                os.remove(old_name)
+
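+            # Exercise both (add_generation_prompt, add_thinking) settings.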
+            for states in ((True, True), (False, False)):
+                tokenizer = Qwen3Tokenizer(
+                    tokenizer_file_path=new_name,
+                    repo_id=repo_id,
+                    apply_chat_template=apply_chat_template,
+                    add_generation_prompt=states[0],
+                    add_thinking=states[1]
+                )
+                input_token_ids = tokenizer.encode(prompt)
+
+                if apply_chat_template:
+                    input_token_ids_ref = tokenizer_ref.apply_chat_template(
+                        messages,
+                        tokenize=True,
+                        add_generation_prompt=states[0],
+                        enable_thinking=states[1],
+                    )
+                else:
+                    input_token_ids_ref = input_token_ids
+
+                assert input_token_ids == input_token_ids_ref, states
+
+                output_text = tokenizer.decode(input_token_ids)
+                out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
+                assert output_text == out_text_ref, states
+
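+            # Special-token strings should encode to single token IDs.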
+            assert tokenizer.encode("<|endoftext|>") == [tokenizer._special_to_id["<|endoftext|>"]]
+            assert tokenizer.encode("<|im_end|>") == [tokenizer._special_to_id["<|im_end|>"]]
+
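+            # Base uses <|endoftext|> as EOS; the reasoning variant uses <|im_end|>.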
+            expected_eos_token = "<|im_end|>" if "base" not in new_name else "<|endoftext|>"
+            expected_pad_token = "<|endoftext|>"
+            assert tokenizer.decode([tokenizer.eos_token_id]) == expected_eos_token
+            assert tokenizer.decode([tokenizer.pad_token_id]) == expected_pad_token
+
+
 @pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
 @pytest.mark.parametrize("repo_id, tok_file", [
     ("Qwen/Qwen3-0.6B", "Qwen3-0.6B/tokenizer.json"),