Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 3 additions & 83 deletions src/transformers/convert_slow_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@

import warnings
from collections.abc import Collection
from functools import lru_cache

from packaging import version
from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE, Unigram, WordPiece
from tqdm import tqdm

from .utils import is_protobuf_available, is_sentencepiece_available, logging, requires_backends
from .utils.import_utils import PROTOBUF_IMPORT_ERROR
Expand Down Expand Up @@ -1973,86 +1971,6 @@ def converted(self) -> Tokenizer:
return tokenizer


class MistralConverter:
def __init__(
self,
vocab_file=None,
pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
add_prefix_space=False,
additional_special_tokens=None,
**kwargs,
):
self.vocab_file = vocab_file
self.pattern = pattern
self.add_prefix_space = add_prefix_space
self.additional_special_tokens = (
additional_special_tokens.keys()
if isinstance(additional_special_tokens, dict)
else additional_special_tokens
)

def extract_vocab_merges_from_model(self, tiktoken_url: str):
import base64
import json

with open(self.vocab_file, "r", encoding="utf-8") as f:
untyped = json.load(f)
self.pattern = untyped["config"]["pattern"]
self.additional_special_tokens = [
AddedToken(k["token_str"], special=k["is_control"]) for k in untyped["special_tokens"]
]
bpe_ranks = untyped["vocab"]
byte_encoder = bytes_to_unicode()

@lru_cache
def token_bytes_to_string(b):
return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])

merges = []
vocab = {}
for idx, token in enumerate(self.additional_special_tokens):
vocab[token.content] = idx
bpe_ranks = [base64.b64decode(k["token_bytes"]) for k in bpe_ranks]
rank_set = set(bpe_ranks)
token_to_rank = {token: rank for rank, token in enumerate(bpe_ranks)}
for rank, token in enumerate(tqdm(bpe_ranks, desc="Converting tekken.json to tokenizer.json")):
vocab[token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
local = []
for index in range(1, len(token)):
piece_l, piece_r = token[:index], token[index:]
if piece_l in rank_set and piece_r in rank_set and (piece_l + piece_r) in rank_set:
local.append((piece_l, piece_r, rank))
local = sorted(local, key=lambda x: (token_to_rank[x[0]], token_to_rank[x[1]]), reverse=False)
merges.extend(local)
merges = sorted(merges, key=lambda val: val[2], reverse=False)
merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
return vocab, merges

def tokenizer(self):
vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file)
tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
if hasattr(tokenizer.model, "ignore_merges"):
tokenizer.model.ignore_merges = True
return tokenizer

def converted(self) -> Tokenizer:
tokenizer = self.tokenizer()
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
[
pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
]
)
tokenizer.decoder = decoders.ByteLevel()

tokenizer.add_tokens(self.additional_special_tokens)
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

return tokenizer


SLOW_TO_FAST_CONVERTERS = {
"AlbertTokenizer": AlbertConverter,
"BartTokenizer": RobertaConverter,
Expand Down Expand Up @@ -2133,9 +2051,11 @@ def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokeni
converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
return converter_class(transformer_tokenizer).converted()
elif transformer_tokenizer.vocab_file.endswith("tekken.json"):
from .integrations.mistral.tokenizer import MistralConverter

transformer_tokenizer.original_tokenizer = transformer_tokenizer
logger.info("Converting from Mistral tekken.json")
return MistralConverter(transformer_tokenizer.vocab_file).converted()
return MistralConverter.from_tekken_file(transformer_tokenizer.vocab_file).converted()
else:
try:
logger.info("Converting from Tiktoken")
Expand Down
121 changes: 0 additions & 121 deletions src/transformers/integrations/mistral.py

This file was deleted.

17 changes: 17 additions & 0 deletions src/transformers/integrations/mistral/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Mistral native format integration: tokenizer conversion utilities."""

from typing import TYPE_CHECKING

from ...utils import _LazyModule


_import_structure = {
"tokenizer": ["MistralConverter"],
}

if TYPE_CHECKING:
from .tokenizer import MistralConverter
else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
Loading
Loading