diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py
index e9401c6f55be..3eccfe68ae21 100644
--- a/src/transformers/convert_slow_tokenizer.py
+++ b/src/transformers/convert_slow_tokenizer.py
@@ -20,12 +20,10 @@
import warnings
from collections.abc import Collection
-from functools import lru_cache
from packaging import version
from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE, Unigram, WordPiece
-from tqdm import tqdm
from .utils import is_protobuf_available, is_sentencepiece_available, logging, requires_backends
from .utils.import_utils import PROTOBUF_IMPORT_ERROR
@@ -1973,86 +1971,6 @@ def converted(self) -> Tokenizer:
return tokenizer
-class MistralConverter:
- def __init__(
- self,
- vocab_file=None,
- pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
- add_prefix_space=False,
- additional_special_tokens=None,
- **kwargs,
- ):
- self.vocab_file = vocab_file
- self.pattern = pattern
- self.add_prefix_space = add_prefix_space
- self.additional_special_tokens = (
- additional_special_tokens.keys()
- if isinstance(additional_special_tokens, dict)
- else additional_special_tokens
- )
-
- def extract_vocab_merges_from_model(self, tiktoken_url: str):
- import base64
- import json
-
- with open(self.vocab_file, "r", encoding="utf-8") as f:
- untyped = json.load(f)
- self.pattern = untyped["config"]["pattern"]
- self.additional_special_tokens = [
- AddedToken(k["token_str"], special=k["is_control"]) for k in untyped["special_tokens"]
- ]
- bpe_ranks = untyped["vocab"]
- byte_encoder = bytes_to_unicode()
-
- @lru_cache
- def token_bytes_to_string(b):
- return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
-
- merges = []
- vocab = {}
- for idx, token in enumerate(self.additional_special_tokens):
- vocab[token.content] = idx
- bpe_ranks = [base64.b64decode(k["token_bytes"]) for k in bpe_ranks]
- rank_set = set(bpe_ranks)
- token_to_rank = {token: rank for rank, token in enumerate(bpe_ranks)}
- for rank, token in enumerate(tqdm(bpe_ranks, desc="Converting tekken.json to tokenizer.json")):
- vocab[token_bytes_to_string(token)] = rank
- if len(token) == 1:
- continue
- local = []
- for index in range(1, len(token)):
- piece_l, piece_r = token[:index], token[index:]
- if piece_l in rank_set and piece_r in rank_set and (piece_l + piece_r) in rank_set:
- local.append((piece_l, piece_r, rank))
- local = sorted(local, key=lambda x: (token_to_rank[x[0]], token_to_rank[x[1]]), reverse=False)
- merges.extend(local)
- merges = sorted(merges, key=lambda val: val[2], reverse=False)
- merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
- return vocab, merges
-
- def tokenizer(self):
- vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file)
- tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
- if hasattr(tokenizer.model, "ignore_merges"):
- tokenizer.model.ignore_merges = True
- return tokenizer
-
- def converted(self) -> Tokenizer:
- tokenizer = self.tokenizer()
- tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
- [
- pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
- pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
- ]
- )
- tokenizer.decoder = decoders.ByteLevel()
-
- tokenizer.add_tokens(self.additional_special_tokens)
- tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
-
- return tokenizer
-
-
SLOW_TO_FAST_CONVERTERS = {
"AlbertTokenizer": AlbertConverter,
"BartTokenizer": RobertaConverter,
@@ -2133,9 +2051,11 @@ def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokeni
converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
return converter_class(transformer_tokenizer).converted()
elif transformer_tokenizer.vocab_file.endswith("tekken.json"):
+ from .integrations.mistral.tokenizer import MistralConverter
+
transformer_tokenizer.original_tokenizer = transformer_tokenizer
logger.info("Converting from Mistral tekken.json")
- return MistralConverter(transformer_tokenizer.vocab_file).converted()
+ return MistralConverter.from_tekken_file(transformer_tokenizer.vocab_file).converted()
else:
try:
logger.info("Converting from Tiktoken")
diff --git a/src/transformers/integrations/mistral.py b/src/transformers/integrations/mistral.py
deleted file mode 100644
index 3256c9839acd..000000000000
--- a/src/transformers/integrations/mistral.py
+++ /dev/null
@@ -1,121 +0,0 @@
-from tokenizers import Regex, Tokenizer, decoders, pre_tokenizers, processors
-from tokenizers.models import BPE
-
-from transformers.convert_slow_tokenizer import bytes_to_unicode
-from transformers.tokenization_utils_tokenizers import PreTrainedTokenizerFast
-
-
-class MistralConverter:
- """
- A general tiktoken converter.
- """
-
- def __init__(
- self,
- vocab=None,
- pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
- add_prefix_space=False,
- additional_special_tokens=None,
- **kwargs,
- ):
- self.vocab = vocab
- self.pattern = pattern
- self.add_prefix_space = add_prefix_space
- self.additional_special_tokens = additional_special_tokens
-
- def extract_vocab_merges_from_model(self, vocab: str):
- bpe_ranks = vocab
- byte_encoder = bytes_to_unicode()
-
- def token_bytes_to_string(b):
- return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
-
- merges = []
- vocab = {}
- for idx, (token, rank) in enumerate(bpe_ranks.items()):
- if token not in self.additional_special_tokens:
- vocab[token_bytes_to_string(token)] = idx
- if len(token) == 1:
- continue
- local = []
- for index in range(1, len(token)):
- piece_l, piece_r = token[:index], token[index:]
- if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
- local.append((piece_l, piece_r, rank))
- local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
- merges.extend(local)
- else:
- vocab[token] = idx
- merges = sorted(merges, key=lambda val: val[2], reverse=False)
- merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
- return vocab, merges
-
- def tokenizer(self):
- vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab)
- tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
- if hasattr(tokenizer.model, "ignore_merges"):
- tokenizer.model.ignore_merges = True
- return tokenizer
-
- def converted(self) -> Tokenizer:
- tokenizer = self.tokenizer()
- tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
- [
- pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
- pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
- ]
- )
- tokenizer.decoder = decoders.ByteLevel()
- tokenizer.add_special_tokens(self.additional_special_tokens)
-
- tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
-
- return tokenizer
-
-
-def convert_tekken_tokenizer(tokenizer_file: str):
- """Convert a "tekken" tokenizer to a fast Tokenizer."""
- # Tekken format -- need to use the Converter
-
- from mistral_common.tokens.tokenizers.base import SpecialTokens
- from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
-
- # Load directly using their lib
- mistral_tokenizer = MistralTokenizer.from_file(tokenizer_file)
-
- # Extract vocab and special tokens
- vocab = mistral_tokenizer.instruct_tokenizer.tokenizer._tekken_token2id_nospecial
- sorted_tokens = sorted(mistral_tokenizer.instruct_tokenizer.tokenizer._all_special_tokens, key=lambda x: x["rank"])
- all_special = [token["token_str"] for token in sorted_tokens]
-
- specials_tokens = {token: idx for idx, token in enumerate(all_special)}
-
- specials_tokens.update(vocab)
- vocab = specials_tokens
-
- # TODO(juliendenize): expose this in mistral-common to avoid accessing private attributes
- # and improve maintainability
- pattern = mistral_tokenizer.instruct_tokenizer.tokenizer._model._pat_str
-
- # Convert
- tokenizer = PreTrainedTokenizerFast(
- tokenizer_object=MistralConverter(
- vocab=vocab, additional_special_tokens=all_special, pattern=pattern
- ).converted()
- )
-
- # Post-process
- tokenizer.add_special_tokens({"additional_special_tokens": all_special})
-
- MAP_SPECAL = {
- "bos_token": SpecialTokens.bos.value,
- "eos_token": SpecialTokens.eos.value,
- "pad_token": SpecialTokens.pad.value,
- "unk_token": SpecialTokens.unk.value,
- }
-
- for special_key, special_token in MAP_SPECAL.items():
- if special_token in all_special:
- tokenizer.add_special_tokens({special_key: special_token})
-
- return tokenizer
diff --git a/src/transformers/integrations/mistral/__init__.py b/src/transformers/integrations/mistral/__init__.py
new file mode 100644
index 000000000000..824cf230bdae
--- /dev/null
+++ b/src/transformers/integrations/mistral/__init__.py
@@ -0,0 +1,17 @@
+"""Mistral native format integration: tokenizer conversion utilities."""
+
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+
+
+_import_structure = {
+ "tokenizer": ["MistralConverter"],
+}
+
+if TYPE_CHECKING:
+ from .tokenizer import MistralConverter
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/integrations/mistral/tokenizer.py b/src/transformers/integrations/mistral/tokenizer.py
new file mode 100644
index 000000000000..fb53b7ed23cf
--- /dev/null
+++ b/src/transformers/integrations/mistral/tokenizer.py
@@ -0,0 +1,201 @@
+# Copyright 2026 Mistral AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Conversion between Mistral tekken tokenizers and HuggingFace tokenizer formats."""
+
+import base64
+import json
+from functools import lru_cache
+from typing import Any
+
+from tokenizers import AddedToken, Regex, Tokenizer, decoders, pre_tokenizers, processors
+from tokenizers.models import BPE
+
+from ...convert_slow_tokenizer import bytes_to_unicode
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+_MAP_SPECIALS = {
+ "bos_token": "",
+ "eos_token": "",
+ "pad_token": "",
+ "unk_token": "",
+}
+
+
+class MistralConverter:
+ """Converter from Mistral tekken BPE vocab to a HuggingFace `tokenizers.Tokenizer`.
+
+ Construct via `from_tekken_file()`, which parses a raw `tekken.json` file and
+ pre-computes the vocab and merges.
+ """
+
+ def __init__(
+ self,
+ pattern: str = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
+ add_prefix_space: bool = False,
+ additional_special_tokens: list[AddedToken] | None = None,
+ **kwargs,
+ ):
+ """Initialize a MistralConverter.
+
+ Args:
+ pattern (`str`): Regex pattern for pre-tokenization.
+ add_prefix_space (`bool`): Whether to add a leading space.
+ additional_special_tokens (`list` or `None`): Extra special tokens.
+ """
+ self.pattern = pattern
+ self.add_prefix_space = add_prefix_space
+ self.additional_special_tokens = additional_special_tokens
+ self._precomputed_vocab: dict[str, int] | None = None
+ self._precomputed_merges: list[tuple[str, str]] | None = None
+ self._tekken_metadata: dict[str, Any] | None = None
+
+ @property
+ def tekken_metadata(self) -> dict[str, Any]:
+ """Non-vocabulary metadata from the original ``tekken.json`` file.
+
+ Raises:
+ AttributeError: If the instance was not created via ``from_tekken_file``.
+ """
+ if self._tekken_metadata is None:
+ raise AttributeError(
+ "`tekken_metadata` is only accessible when instance is created by `from_tekken_file` method."
+ )
+ return self._tekken_metadata
+
+ @classmethod
+ def from_tekken_file(
+ cls,
+ vocab_file: str,
+ add_prefix_space: bool = False,
+ ) -> "MistralConverter":
+ """Parse a raw `tekken.json` file and return a ready-to-use converter.
+
+ Reads the file, extracts the regex pattern and special tokens, then
+ pre-computes `vocab` and `merges` with correct index offsets (special
+ tokens occupy the first indices).
+
+ Args:
+ vocab_file (`str`): Path to a `tekken.json` file.
+ add_prefix_space (`bool`): Whether to add a prefix space during tokenization.
+
+ Returns:
+ `MistralConverter`: A ready-to-use converter with pre-computed vocab and merges.
+ """
+ with open(vocab_file, encoding="utf-8") as f:
+ untyped = json.load(f)
+
+ pattern = untyped["config"]["pattern"]
+
+ additional_special_tokens = [AddedToken(k["token_str"], special=True) for k in untyped["special_tokens"]]
+ bpe_ranks_raw = untyped["vocab"]
+ num_special = len(additional_special_tokens)
+
+ bpe_ranks = [base64.b64decode(k["token_bytes"]) for k in bpe_ranks_raw]
+ bpe_ranks_dict = {token: rank for rank, token in enumerate(bpe_ranks)}
+
+ vocab, merges = cls._extract_merges(bpe_ranks_dict)
+
+ # Offset vocab indices to account for special tokens occupying the first slots
+ vocab = {k: v + num_special for k, v in vocab.items()}
+ # Use each special token's explicit `rank` as its id so list order is irrelevant.
+ for entry in untyped["special_tokens"]:
+ vocab[entry["token_str"]] = entry["rank"]
+
+ instance = cls(
+ pattern=pattern,
+ add_prefix_space=add_prefix_space,
+ additional_special_tokens=additional_special_tokens,
+ )
+ # Store pre-computed vocab and merges so tokenizer() can use them directly
+ instance._precomputed_vocab = vocab
+ instance._precomputed_merges = merges
+
+ # Preserve tekken.json metadata so it can be reconstructed on save.
+ instance._tekken_metadata = {k: v for k, v in untyped.items() if k != "vocab"}
+ # Store which vocab entries had token_str=null so save_as_tekken can
+ # restore them instead of unconditionally decoding from bytes.
+ instance._tekken_metadata["_null_token_str_bytes"] = [
+ entry["token_bytes"] for entry in bpe_ranks_raw if entry.get("token_str") is None
+ ]
+
+ return instance
+
+ @staticmethod
+ def _extract_merges(bpe_ranks: dict[bytes, int]) -> tuple[dict[str, int], list[tuple[str, str]]]:
+ """Extract a unicode vocab and BPE merge list from byte-level BPE ranks.
+
+ For each multi-byte token, tries all binary splits ``(token[:i], token[i:])``
+ and keeps those where both halves exist in the vocabulary. Splits are sorted
+ locally by ``(rank_left, rank_right)`` and globally by merged-token rank.
+
+ Args:
+ bpe_ranks (`dict[bytes, int]`): Mapping of byte-level tokens to their
+ integer ranks in the BPE vocabulary.
+
+ Returns:
+ `tuple[dict[str, int], list[tuple[str, str]]]`: A pair of
+ ``(vocab, merges)`` where vocab maps unicode token strings to ranks
+ and merges is an ordered list of BPE merge pairs.
+ """
+ byte_encoder = bytes_to_unicode()
+
+ @lru_cache
+ def token_bytes_to_string(b: bytes) -> str:
+ return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
+
+ vocab: dict[str, int] = {}
+ all_merges: list[tuple[bytes, bytes, int]] = []
+
+ for token, rank in bpe_ranks.items():
+ vocab[token_bytes_to_string(token)] = rank
+ if len(token) == 1:
+ continue
+ local = []
+ for index in range(1, len(token)):
+ piece_l, piece_r = token[:index], token[index:]
+ if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
+ local.append((piece_l, piece_r, rank))
+ local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]))
+ all_merges.extend(local)
+
+ all_merges = sorted(all_merges, key=lambda val: val[2])
+ merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in all_merges]
+ return vocab, merges
+
+ def tokenizer(self) -> Tokenizer:
+ """Build a raw `tokenizers.Tokenizer` with BPE model (no pre/post-processing)."""
+ tokenizer = Tokenizer(BPE(self._precomputed_vocab, self._precomputed_merges, fuse_unk=False))
+ if hasattr(tokenizer.model, "ignore_merges"):
+ tokenizer.model.ignore_merges = True
+ return tokenizer
+
+ def converted(self) -> Tokenizer:
+ """Build a fully configured `tokenizers.Tokenizer` with pre-tokenizer and decoder."""
+ tokenizer = self.tokenizer()
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
+ [
+ pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
+ pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
+ ]
+ )
+ tokenizer.decoder = decoders.ByteLevel()
+ tokenizer.add_special_tokens(self.additional_special_tokens)
+
+ tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
+
+ return tokenizer
diff --git a/src/transformers/tokenization_utils_tokenizers.py b/src/transformers/tokenization_utils_tokenizers.py
index d7c9744e5e11..3d13cca8b4e1 100644
--- a/src/transformers/tokenization_utils_tokenizers.py
+++ b/src/transformers/tokenization_utils_tokenizers.py
@@ -201,11 +201,11 @@ def convert_to_native_format(cls, trust_remote_code=False, **kwargs):
# Tekken converter (Mistral)
if isinstance(vocab_file, str) and vocab_file.endswith("tekken.json") and os.path.isfile(vocab_file):
- from .convert_slow_tokenizer import MistralConverter
+ from .integrations.mistral.tokenizer import MistralConverter
- local_kwargs["vocab"], local_kwargs["merges"] = MistralConverter(
- vocab_file=vocab_file
- ).extract_vocab_merges_from_model(vocab_file)
+ converter = MistralConverter.from_tekken_file(vocab_file)
+ local_kwargs["tokenizer_object"] = converter.converted()
+ local_kwargs["tekken_metadata"] = converter.tekken_metadata
return local_kwargs
# SentencePiece model (with TikToken fallback)
diff --git a/tests/integrations/__init__.py b/tests/integrations/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/integrations/mistral/__init__.py b/tests/integrations/mistral/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/integrations/mistral/test_tokenizer.py b/tests/integrations/mistral/test_tokenizer.py
new file mode 100644
index 000000000000..c3e2f00f7fef
--- /dev/null
+++ b/tests/integrations/mistral/test_tokenizer.py
@@ -0,0 +1,381 @@
+# Copyright 2026 Mistral AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for MistralConverter: tekken.json parsing and HuggingFace tokenizer conversion."""
+
+import base64
+import json
+import tempfile
+import unittest
+from pathlib import Path
+
+from huggingface_hub import hf_hub_download
+
+from transformers.integrations.mistral import MistralConverter
+from transformers.testing_utils import require_mistral_common, slow
+from transformers.utils.import_utils import is_mistral_common_available
+
+
+if is_mistral_common_available():
+ from transformers.tokenization_mistral_common import MistralCommonBackend
+
+
+_NUM_SPECIAL_TOKENS = 20
+
+_FAKE_TEKKEN_SPECIAL_TOKENS = [
+ {"rank": 0, "token_str": "", "is_control": True},
+ {"rank": 1, "token_str": "", "is_control": True},
+ {"rank": 2, "token_str": "", "is_control": True},
+ {"rank": 3, "token_str": "[INST]", "is_control": True},
+ {"rank": 4, "token_str": "[/INST]", "is_control": True},
+ {"rank": 5, "token_str": "[AVAILABLE_TOOLS]", "is_control": True},
+ {"rank": 6, "token_str": "[/AVAILABLE_TOOLS]", "is_control": True},
+ {"rank": 7, "token_str": "[TOOL_RESULTS]", "is_control": True},
+ {"rank": 8, "token_str": "[/TOOL_RESULTS]", "is_control": True},
+ {"rank": 9, "token_str": "[TOOL_CALLS]", "is_control": True},
+ {"rank": 10, "token_str": "[IMG]", "is_control": True},
+ {"rank": 11, "token_str": "", "is_control": True},
+ {"rank": 12, "token_str": "[IMG_BREAK]", "is_control": True},
+ {"rank": 13, "token_str": "[IMG_END]", "is_control": True},
+ {"rank": 14, "token_str": "[PREFIX]", "is_control": True},
+ {"rank": 15, "token_str": "[MIDDLE]", "is_control": True},
+ {"rank": 16, "token_str": "[SUFFIX]", "is_control": True},
+ {"rank": 17, "token_str": "[SYSTEM_PROMPT]", "is_control": True},
+ {"rank": 18, "token_str": "[/SYSTEM_PROMPT]", "is_control": True},
+ {"rank": 19, "token_str": "[TOOL_CONTENT]", "is_control": True},
+]
+
+_FAKE_TEKKEN_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+# 256 byte-level BPE tokens + 20 special tokens = full single-byte coverage.
+_FULL_BYTE_VOCAB = 256 + _NUM_SPECIAL_TOKENS
+
+# Diverse test strings used across all test classes.
+_TEST_STRINGS = [
+ "Hello, world!",
+ "Bonjour le monde!",
+ "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)",
+ "The quick brown fox jumps over the lazy dog.",
+ "๐ Unicode: cafรฉ, naรฏve, rรฉsumรฉ",
+ " Multiple spaces and\ttabs\nand\nnewlines",
+ "12345 + 67890 = 80235",
+ "!@#$%^&*()",
+ " leading and trailing ",
+ "MiXeD CaSe TeXt",
+ "a",
+]
+
+_MINISTRAL_REPO = "mistralai/Ministral-3-3B-Instruct-2512"
+
+
+def _build_fake_tekken_json(
+ directory: Path,
+ vocab_size: int = _FULL_BYTE_VOCAB,
+ image_config: dict | None = None,
+) -> Path:
+ """Build a minimal tekken.json for testing."""
+ num_bpe = vocab_size - _NUM_SPECIAL_TOKENS
+
+ vocab_list: list[dict] = []
+ for rank in range(num_bpe):
+ raw_byte = bytes([rank % 256])
+ vocab_list.append(
+ {
+ "rank": rank,
+ "token_bytes": base64.b64encode(raw_byte).decode("ascii"),
+ "token_str": None,
+ }
+ )
+
+ tekken_data: dict = {
+ "vocab": vocab_list,
+ "special_tokens": _FAKE_TEKKEN_SPECIAL_TOKENS,
+ "config": {
+ "pattern": _FAKE_TEKKEN_PATTERN,
+ "num_vocab_tokens": num_bpe,
+ "default_vocab_size": vocab_size,
+ "default_num_special_tokens": _NUM_SPECIAL_TOKENS,
+ "version": "v3",
+ },
+ "version": 1,
+ "type": "tekken",
+ }
+
+ if image_config is not None:
+ tekken_data["image"] = image_config
+
+ output_path = directory / "tekken.json"
+ with open(output_path, "w", encoding="utf-8") as f:
+ json.dump(tekken_data, f, ensure_ascii=False)
+
+ return output_path
+
+
+class TestMistralConverter(unittest.TestCase):
+ """Unit tests for MistralConverter using a synthetic tekken.json."""
+
+ @classmethod
+ def setUpClass(cls) -> None:
+ cls._tmp_dir = tempfile.TemporaryDirectory()
+ cls._tekken_path = _build_fake_tekken_json(Path(cls._tmp_dir.name))
+ cls._converter = MistralConverter.from_tekken_file(str(cls._tekken_path))
+ cls._tokenizer = cls._converter.converted()
+
+ @classmethod
+ def tearDownClass(cls) -> None:
+ cls._tmp_dir.cleanup()
+
+ def test_from_tekken_file_sets_precomputed_fields(self):
+ self.assertIsNotNone(self._converter._precomputed_vocab)
+ self.assertIsNotNone(self._converter._precomputed_merges)
+ self.assertIsNotNone(self._converter._tekken_metadata)
+
+ def test_converted_produces_working_tokenizer(self):
+ ids = self._tokenizer.encode("a b c").ids
+ self.assertIsInstance(ids, list)
+ self.assertGreater(len(ids), 0)
+
+ def test_roundtrip_encode_decode(self):
+ for text in ["hello world", "abc 123", "test"]:
+ encoded = self._tokenizer.encode(text)
+ decoded = self._tokenizer.decode(encoded.ids)
+ self.assertEqual(decoded, text, f"Roundtrip failed for {text!r}")
+
+ def test_special_tokens_in_vocab(self):
+ vocab = self._tokenizer.get_vocab()
+ for entry in _FAKE_TEKKEN_SPECIAL_TOKENS:
+ self.assertIn(entry["token_str"], vocab, f"Special token {entry['token_str']!r} missing")
+
+ def test_vocab_size(self):
+ self.assertEqual(self._tokenizer.get_vocab_size(), _FULL_BYTE_VOCAB)
+
+ def test_tekken_metadata_content(self):
+ metadata = self._converter.tekken_metadata
+ self.assertEqual(metadata["version"], 1)
+ self.assertEqual(metadata["type"], "tekken")
+ self.assertIn("config", metadata)
+ self.assertEqual(metadata["config"]["version"], "v3")
+ self.assertNotIn("vocab", metadata, "Metadata should not contain the vocab itself")
+
+ def test_tekken_metadata_raises_without_from_file(self):
+ converter = MistralConverter()
+ with self.assertRaises(AttributeError):
+ _ = converter.tekken_metadata
+
+ def test_special_tokens_assigned_by_rank_not_list_order(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tmp_path = Path(tmp_dir)
+ shuffled_specials = list(reversed(_FAKE_TEKKEN_SPECIAL_TOKENS))
+ num_bpe = _FULL_BYTE_VOCAB - _NUM_SPECIAL_TOKENS
+ vocab_list = [
+ {
+ "rank": rank,
+ "token_bytes": base64.b64encode(bytes([rank % 256])).decode("ascii"),
+ "token_str": None,
+ }
+ for rank in range(num_bpe)
+ ]
+ tekken_data = {
+ "vocab": vocab_list,
+ "special_tokens": shuffled_specials,
+ "config": {"pattern": _FAKE_TEKKEN_PATTERN},
+ "version": 1,
+ "type": "tekken",
+ }
+ tekken_path = tmp_path / "tekken.json"
+ with open(tekken_path, "w", encoding="utf-8") as f:
+ json.dump(tekken_data, f, ensure_ascii=False)
+
+ converter = MistralConverter.from_tekken_file(str(tekken_path))
+
+ for entry in _FAKE_TEKKEN_SPECIAL_TOKENS:
+ self.assertEqual(
+ converter._precomputed_vocab[entry["token_str"]],
+ entry["rank"],
+ f"Special token {entry['token_str']!r} got wrong id",
+ )
+
+
+@require_mistral_common
+class TestMistralConverterVsCommonBackend(unittest.TestCase):
+ """Compare MistralConverter raw encoding/decoding with MistralCommonBackend on a synthetic tekken.json.
+
+ MistralConverter.converted() does NOT add BOS/EOS โ that is the wrapper's job.
+ All comparisons use add_special_tokens=False on MistralCommonBackend.
+ """
+
+ @classmethod
+ def setUpClass(cls) -> None:
+ cls._tmp_dir = tempfile.TemporaryDirectory()
+ tekken_path = _build_fake_tekken_json(Path(cls._tmp_dir.name))
+
+ converter = MistralConverter.from_tekken_file(str(tekken_path))
+ cls.hf_tokenizer = converter.converted()
+ cls.mc_tokenizer = MistralCommonBackend(tokenizer_path=str(tekken_path))
+
+ @classmethod
+ def tearDownClass(cls) -> None:
+ cls._tmp_dir.cleanup()
+
+ def test_encode_matches(self) -> None:
+ for text in _TEST_STRINGS:
+ hf_ids = self.hf_tokenizer.encode(text).ids
+ mc_ids = self.mc_tokenizer.encode(text, add_special_tokens=False)
+ self.assertEqual(hf_ids, mc_ids, f"Encoding mismatch for {text!r}")
+
+ def test_decode_matches(self) -> None:
+ for text in _TEST_STRINGS:
+ ids = self.mc_tokenizer.encode(text, add_special_tokens=False)
+ hf_decoded = self.hf_tokenizer.decode(ids)
+ mc_decoded = self.mc_tokenizer.decode(ids, skip_special_tokens=True)
+ self.assertEqual(hf_decoded, mc_decoded, f"Decode mismatch for {text!r}")
+
+ def test_vocab_size(self) -> None:
+ self.assertEqual(self.hf_tokenizer.get_vocab_size(), self.mc_tokenizer.vocab_size)
+
+
+@require_mistral_common
+@slow
+class TestMistralConverterIntegration(unittest.TestCase):
+ """Integration tests with real tekken.json from mistralai/Ministral-3-3B-Instruct-2512.
+
+ MistralConverter.converted() returns a raw tokenizers.Tokenizer without
+ BOS/EOS injection. All encoding comparisons use add_special_tokens=False
+ on MistralCommonBackend to compare at the same abstraction level.
+ """
+
+ @classmethod
+ def setUpClass(cls) -> None:
+ tekken_path = hf_hub_download(_MINISTRAL_REPO, "tekken.json")
+
+ converter = MistralConverter.from_tekken_file(tekken_path)
+ cls.hf_tokenizer = converter.converted()
+ cls.mc_tokenizer = MistralCommonBackend(tokenizer_path=tekken_path)
+
+ # โโ Vocabulary โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
+
+ def test_vocab_size(self) -> None:
+ self.assertEqual(self.hf_tokenizer.get_vocab_size(), self.mc_tokenizer.vocab_size)
+
+ def test_full_vocab_decode_single_token_matches(self) -> None:
+ """Decoding every single token ID (skip_special_tokens=True) produces the same string."""
+ mismatches = []
+ for token_id in range(self.mc_tokenizer.vocab_size):
+ hf_decoded = self.hf_tokenizer.decode([token_id], skip_special_tokens=True)
+ mc_decoded = self.mc_tokenizer.decode([token_id], skip_special_tokens=True)
+ if hf_decoded != mc_decoded:
+ mismatches.append((token_id, hf_decoded, mc_decoded))
+ self.assertEqual(mismatches, [], f"Found {len(mismatches)} decode mismatches (first 10): {mismatches[:10]}")
+
+ def test_special_tokens_ids(self) -> None:
+ for token_str, attr in {"": "bos", "": "eos", "": "unk", "": "pad"}.items():
+ hf_id = self.hf_tokenizer.token_to_id(token_str)
+ mc_id = getattr(self.mc_tokenizer, f"{attr}_token_id")
+ self.assertIsNotNone(hf_id, f"HF tokenizer missing {token_str}")
+ self.assertIsNotNone(mc_id, f"MC tokenizer missing {attr}_token_id")
+ self.assertEqual(hf_id, mc_id, f"{token_str} ID mismatch: HF={hf_id} MC={mc_id}")
+
+ # โโ Encode โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
+
+ def test_encode(self) -> None:
+ for text in _TEST_STRINGS:
+ hf_ids = self.hf_tokenizer.encode(text).ids
+ mc_ids = self.mc_tokenizer.encode(text, add_special_tokens=False)
+ self.assertEqual(hf_ids, mc_ids, f"Encoding mismatch for {text!r}")
+
+ def test_encode_long_text(self) -> None:
+ long_text = "The quick brown fox jumps over the lazy dog. " * 100
+ hf_ids = self.hf_tokenizer.encode(long_text).ids
+ mc_ids = self.mc_tokenizer.encode(long_text, add_special_tokens=False)
+ self.assertEqual(hf_ids, mc_ids)
+ self.assertGreater(len(hf_ids), 100, "Long text should produce many tokens")
+
+ def test_encode_multilingual(self) -> None:
+ texts = [
+ "ๆฅๆฌ่ชใฎใในใ", # Japanese
+ "ะัะธะฒะตั ะผะธั", # Russian
+ "ู
ุฑุญุจุง ุจุงูุนุงูู
", # Arabic
+ "ไฝ ๅฅฝไธ็", # Chinese
+ "ํ๊ตญ์ด ํ
์คํธ", # Korean
+ "รoรฑo espaรฑol", # Spanish with diacritics
+ "ฮฮปฮปฮทฮฝฮนฮบฮฌ", # Greek
+ ]
+ for text in texts:
+ hf_ids = self.hf_tokenizer.encode(text).ids
+ mc_ids = self.mc_tokenizer.encode(text, add_special_tokens=False)
+ self.assertEqual(hf_ids, mc_ids, f"Multilingual encoding mismatch for {text!r}")
+
+ def test_encode_code_snippets(self) -> None:
+ snippets = [
+ "import torch\nmodel = torch.nn.Linear(10, 20)",
+ "for i in range(100):\n print(f'{i=}')",
+ "class Foo:\n def __init__(self):\n self.x = 42",
+ "// C++ comment\nint main() { return 0; }",
+ "SELECT * FROM users WHERE id = 1;",
+ '{"key": "value", "nested": {"a": [1, 2, 3]}}',
+ ]
+ for text in snippets:
+ hf_ids = self.hf_tokenizer.encode(text).ids
+ mc_ids = self.mc_tokenizer.encode(text, add_special_tokens=False)
+ self.assertEqual(hf_ids, mc_ids, f"Code encoding mismatch for {text!r}")
+
+ # โโ Decode โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
+
+ def test_decode(self) -> None:
+ """Decode token IDs (no special tokens) โ both backends produce the same string."""
+ for text in _TEST_STRINGS:
+ ids = self.mc_tokenizer.encode(text, add_special_tokens=False)
+ hf_decoded = self.hf_tokenizer.decode(ids)
+ mc_decoded = self.mc_tokenizer.decode(ids, skip_special_tokens=True)
+ self.assertEqual(hf_decoded, mc_decoded, f"Decode mismatch for {text!r}")
+
+ def test_decode_with_special_token_ids(self) -> None:
+ """Decode sequences that contain BOS/EOS IDs โ skip_special_tokens strips them equally."""
+ bos_id = self.hf_tokenizer.token_to_id("")
+ eos_id = self.hf_tokenizer.token_to_id("")
+ for text in _TEST_STRINGS:
+ ids = self.mc_tokenizer.encode(text, add_special_tokens=False)
+ ids_with_special = [bos_id] + ids + [eos_id]
+
+ hf_decoded = self.hf_tokenizer.decode(ids_with_special, skip_special_tokens=True)
+ mc_decoded = self.mc_tokenizer.decode(ids_with_special, skip_special_tokens=True)
+ self.assertEqual(hf_decoded, mc_decoded, f"Decode skip BOS+EOS mismatch for {text!r}")
+
+ def test_encode_decode_roundtrip(self) -> None:
+ """Encode then decode should recover the original text in both backends."""
+ for text in _TEST_STRINGS:
+ if not text:
+ continue
+ hf_ids = self.hf_tokenizer.encode(text).ids
+ hf_roundtrip = self.hf_tokenizer.decode(hf_ids)
+ mc_roundtrip = self.mc_tokenizer.decode(hf_ids, skip_special_tokens=True)
+ self.assertEqual(hf_roundtrip, text, f"HF roundtrip failed for {text!r}")
+ self.assertEqual(mc_roundtrip, text, f"MC roundtrip failed for {text!r}")
+
+ # โโ Token-level โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
+
+ def test_per_token_decode_matches(self) -> None:
+ """Decoding each token individually should produce the same string in both backends."""
+ for text in _TEST_STRINGS:
+ ids = self.mc_tokenizer.encode(text, add_special_tokens=False)
+ if not ids:
+ continue
+ for token_id in ids:
+ hf_decoded = self.hf_tokenizer.decode([token_id], skip_special_tokens=True)
+ mc_decoded = self.mc_tokenizer.decode([token_id], skip_special_tokens=True)
+ self.assertEqual(hf_decoded, mc_decoded, f"Per-token decode mismatch for id={token_id} in {text!r}")
+
+
+if __name__ == "__main__":
+ unittest.main()