1+ import json
12import time
23import warnings
34from abc import ABC
@@ -241,29 +242,41 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: str | list[
241242 if isinstance (stop_strings , str ):
242243 stop_strings = [stop_strings ]
243244 self .stop_strings : tuple [str , ...] = tuple (stop_strings )
245+ self ._stop_string_matching_mode = self ._get_stop_string_matching_mode (tokenizer )
246+ self ._stop_strings_for_matching = self ._get_stop_strings_for_matching (
247+ self .stop_strings , self ._stop_string_matching_mode
248+ )
244249 vocab = tokenizer .get_vocab ()
245250 token_list , token_indices = tuple (vocab .keys ()), tuple (vocab .values ())
246251 self .embedding_vec , self .max_valid_positions , self .max_valid_end_lens = self .clean_and_embed_tokens_with_cache (
247252 token_list , token_indices , tokenizer
248253 )
249254
250- self .maximum_token_len = max (len (stop_string ) for stop_string in self .stop_strings )
255+ self .maximum_token_len = max (len (stop_string ) for stop_string in self ._stop_strings_for_matching )
251256 self .num_stop_strings = len (self .stop_strings )
252- self .target_lens = torch .tensor ([len (stop_string ) for stop_string in stop_strings ], dtype = torch .int32 )
257+ self .target_lens = torch .tensor (
258+ [len (stop_string ) for stop_string in self ._stop_strings_for_matching ], dtype = torch .int32
259+ )
253260
254261 def clean_and_embed_tokens_with_cache (self , token_list , token_indices , tokenizer ):
255262 # We don't use the tokenizer in the cache key, because I don't trust it to have well-behaved equality
256- if (token_list , token_indices , self .stop_strings ) in STOP_STRING_EMBEDDING_CACHE :
257- embedding_vec , max_valid_positions , max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE [
258- (token_list , token_indices , self .stop_strings )
259- ]
260- STOP_STRING_EMBEDDING_CACHE .move_to_end ((token_list , token_indices , self .stop_strings ))
263+ cache_key = (
264+ token_list ,
265+ token_indices ,
266+ self ._stop_strings_for_matching ,
267+ self ._stop_string_matching_mode ,
268+ )
269+ if cache_key in STOP_STRING_EMBEDDING_CACHE :
270+ embedding_vec , max_valid_positions , max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE [cache_key ]
271+ STOP_STRING_EMBEDDING_CACHE .move_to_end (cache_key )
261272 else :
262- clean_token_list , clean_token_indices = self .clean_tokenizer_vocab (tokenizer )
273+ clean_token_list , clean_token_indices = self .clean_tokenizer_vocab (
274+ tokenizer , stop_string_matching_mode = self ._stop_string_matching_mode
275+ )
263276 embedding_vec , max_valid_positions , max_valid_end_lens = self ._stop_string_create_embedding_vec (
264- clean_token_list , clean_token_indices , self .stop_strings
277+ clean_token_list , clean_token_indices , self ._stop_strings_for_matching
265278 )
266- STOP_STRING_EMBEDDING_CACHE [( token_list , token_indices , self . stop_strings ) ] = (
279+ STOP_STRING_EMBEDDING_CACHE [cache_key ] = (
267280 embedding_vec ,
268281 max_valid_positions ,
269282 max_valid_end_lens ,
@@ -273,30 +286,101 @@ def clean_and_embed_tokens_with_cache(self, token_list, token_indices, tokenizer
273286 return embedding_vec , max_valid_positions , max_valid_end_lens
274287
275288 @staticmethod
276- def clean_tokenizer_vocab (tokenizer , static_prefix = "abcdef" ):
289+ def _get_stop_string_matching_mode (tokenizer ):
290+ decoder = getattr (getattr (tokenizer , "backend_tokenizer" , None ), "decoder" , None )
291+ if decoder is None :
292+ return None
293+
294+ decoder_state = getattr (decoder , "__getstate__" , lambda : None )()
295+ if isinstance (decoder_state , str ):
296+ decoder_state = decoder_state .encode ()
297+ decoder_config = None
298+ if isinstance (decoder_state , bytes ):
299+ try :
300+ decoder_config = json .loads (decoder_state )
301+ except json .JSONDecodeError :
302+ decoder_config = None
303+
304+ # Some decoders do not expose a JSON state.
305+ if decoder .__class__ .__name__ == "ByteLevel" :
306+ return "byte_level"
307+ if decoder_config is not None :
308+ # Prefer explicit "<0xNN>" byte-fallback tokens if both markers appear.
309+ if StopStringCriteria ._decoder_has_type (decoder_config , "ByteFallback" ):
310+ return "byte_fallback"
311+ if StopStringCriteria ._decoder_has_type (decoder_config , "ByteLevel" ):
312+ return "byte_level"
313+ return None
314+
315+ @staticmethod
316+ def _decoder_has_type (decoder_config , decoder_type ):
317+ if isinstance (decoder_config , dict ):
318+ if decoder_config .get ("type" ) == decoder_type :
319+ return True
320+ return any (StopStringCriteria ._decoder_has_type (value , decoder_type ) for value in decoder_config .values ())
321+ if isinstance (decoder_config , list ):
322+ return any (StopStringCriteria ._decoder_has_type (value , decoder_type ) for value in decoder_config )
323+ return False
324+
325+ @staticmethod
326+ def _get_stop_strings_for_matching (stop_strings , matching_mode ):
327+ if matching_mode is None :
328+ return stop_strings
329+ return tuple (stop_string .encode ("utf-8" ) for stop_string in stop_strings )
330+
331+ @staticmethod
332+ def _byte_level_decoder ():
333+ from ..convert_slow_tokenizer import bytes_to_unicode
334+
335+ return {unicode_char : byte for byte , unicode_char in bytes_to_unicode ().items ()}
336+
337+ @staticmethod
338+ def _token_to_bytes (token , stop_string_matching_mode , byte_decoder ):
339+ if stop_string_matching_mode == "byte_level" :
340+ if byte_decoder is not None and all (char in byte_decoder for char in token ):
341+ return bytes (byte_decoder [char ] for char in token )
342+ return None
343+ if stop_string_matching_mode == "byte_fallback" :
344+ if (
345+ len (token ) == 6
346+ and token .startswith ("<0x" )
347+ and token .endswith (">" )
348+ and all (char in "0123456789abcdefABCDEF" for char in token [3 :5 ])
349+ ):
350+ return bytes ([int (token [3 :5 ], 16 )])
351+ return None
352+
353+ @staticmethod
354+ def clean_tokenizer_vocab (tokenizer , static_prefix = "abcdef" , stop_string_matching_mode = None ):
277355 """
278356 This method turns a tokenizer vocab into a "clean" vocab where each token represents the actual string
279357 it will yield, without any special prefixes like "##" or "Ġ". This is trickier than it looks - the method
280358 tokenizer.convert_tokens_to_string() does not always return the correct string because of issues with prefix
281359 space addition/removal. To work around this, we add a static prefix to the start of the token, then remove
282- it (and any prefix that may have been introduced with it) after calling convert_tokens_to_string().
360+ it (and any prefix that may have been introduced with it) after calling convert_tokens_to_string(). For
361+ byte-level vocabularies, incomplete UTF-8 fragments are kept as bytes until the stop string match is computed.
283362 """
284363 vocab = tokenizer .get_vocab ()
285364 clean_token_list = []
286365 clean_token_indices = []
366+ byte_decoder = StopStringCriteria ._byte_level_decoder () if stop_string_matching_mode == "byte_level" else None
287367 sentence_base = tokenizer (static_prefix , add_special_tokens = False )["input_ids" ]
288368 tokens_base = [tokenizer ._convert_id_to_token (tok ) for tok in sentence_base ]
289369 for token , token_idx in vocab .items ():
290- token_string = tokenizer .convert_tokens_to_string (tokens_base + [token ])
291- token_string = token_string [token_string .index (static_prefix ) + len (static_prefix ) :]
370+ token_string = StopStringCriteria ._token_to_bytes (token , stop_string_matching_mode , byte_decoder )
371+ if token_string is None :
372+ token_string = tokenizer .convert_tokens_to_string (tokens_base + [token ])
373+ token_string = token_string [token_string .index (static_prefix ) + len (static_prefix ) :]
374+ if stop_string_matching_mode is not None :
375+ token_string = token_string .encode ("utf-8" )
292376 clean_token_list .append (token_string )
293377 clean_token_indices .append (token_idx )
294378 return tuple (clean_token_list ), tuple (clean_token_indices )
295379
296380 @staticmethod
297381 def _stop_string_get_matching_positions (
298382 token_list , token_indices , stop_strings
299- ) -> tuple [dict [str , dict [str , list [int ]]], dict [str , dict [str , list [int ]]]]:
383+ ) -> tuple [dict [str | bytes , dict [str , list [int ]]], dict [str | bytes , dict [str , list [int ]]]]:
300384 """This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can
301385 validly appear in the stop strings. For each token, it computes a list of positions in the stop string where the
302386 token appears, as well as a list of the possible "end overlaps" for that token - that is, the number of characters
0 commit comments