Skip to content

Commit c27e608

Browse files
IncheonkirinRocketknight1
authored andcommitted
Fix stop string matching for byte-fragment tokens
1 parent 6b86215 commit c27e608

2 files changed

Lines changed: 220 additions & 15 deletions

File tree

src/transformers/generation/stopping_criteria.py

Lines changed: 99 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import time
23
import warnings
34
from abc import ABC
@@ -241,29 +242,41 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: str | list[
241242
if isinstance(stop_strings, str):
242243
stop_strings = [stop_strings]
243244
self.stop_strings: tuple[str, ...] = tuple(stop_strings)
245+
self._stop_string_matching_mode = self._get_stop_string_matching_mode(tokenizer)
246+
self._stop_strings_for_matching = self._get_stop_strings_for_matching(
247+
self.stop_strings, self._stop_string_matching_mode
248+
)
244249
vocab = tokenizer.get_vocab()
245250
token_list, token_indices = tuple(vocab.keys()), tuple(vocab.values())
246251
self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = self.clean_and_embed_tokens_with_cache(
247252
token_list, token_indices, tokenizer
248253
)
249254

250-
self.maximum_token_len = max(len(stop_string) for stop_string in self.stop_strings)
255+
self.maximum_token_len = max(len(stop_string) for stop_string in self._stop_strings_for_matching)
251256
self.num_stop_strings = len(self.stop_strings)
252-
self.target_lens = torch.tensor([len(stop_string) for stop_string in stop_strings], dtype=torch.int32)
257+
self.target_lens = torch.tensor(
258+
[len(stop_string) for stop_string in self._stop_strings_for_matching], dtype=torch.int32
259+
)
253260

254261
def clean_and_embed_tokens_with_cache(self, token_list, token_indices, tokenizer):
255262
# We don't use the tokenizer in the cache key, because I don't trust it to have well-behaved equality
256-
if (token_list, token_indices, self.stop_strings) in STOP_STRING_EMBEDDING_CACHE:
257-
embedding_vec, max_valid_positions, max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE[
258-
(token_list, token_indices, self.stop_strings)
259-
]
260-
STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, self.stop_strings))
263+
cache_key = (
264+
token_list,
265+
token_indices,
266+
self._stop_strings_for_matching,
267+
self._stop_string_matching_mode,
268+
)
269+
if cache_key in STOP_STRING_EMBEDDING_CACHE:
270+
embedding_vec, max_valid_positions, max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE[cache_key]
271+
STOP_STRING_EMBEDDING_CACHE.move_to_end(cache_key)
261272
else:
262-
clean_token_list, clean_token_indices = self.clean_tokenizer_vocab(tokenizer)
273+
clean_token_list, clean_token_indices = self.clean_tokenizer_vocab(
274+
tokenizer, stop_string_matching_mode=self._stop_string_matching_mode
275+
)
263276
embedding_vec, max_valid_positions, max_valid_end_lens = self._stop_string_create_embedding_vec(
264-
clean_token_list, clean_token_indices, self.stop_strings
277+
clean_token_list, clean_token_indices, self._stop_strings_for_matching
265278
)
266-
STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, self.stop_strings)] = (
279+
STOP_STRING_EMBEDDING_CACHE[cache_key] = (
267280
embedding_vec,
268281
max_valid_positions,
269282
max_valid_end_lens,
@@ -273,30 +286,101 @@ def clean_and_embed_tokens_with_cache(self, token_list, token_indices, tokenizer
273286
return embedding_vec, max_valid_positions, max_valid_end_lens
274287

275288
@staticmethod
276-
def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef"):
289+
def _get_stop_string_matching_mode(tokenizer):
290+
decoder = getattr(getattr(tokenizer, "backend_tokenizer", None), "decoder", None)
291+
if decoder is None:
292+
return None
293+
294+
decoder_state = getattr(decoder, "__getstate__", lambda: None)()
295+
if isinstance(decoder_state, str):
296+
decoder_state = decoder_state.encode()
297+
decoder_config = None
298+
if isinstance(decoder_state, bytes):
299+
try:
300+
decoder_config = json.loads(decoder_state)
301+
except json.JSONDecodeError:
302+
decoder_config = None
303+
304+
# Some decoders do not expose a JSON state.
305+
if decoder.__class__.__name__ == "ByteLevel":
306+
return "byte_level"
307+
if decoder_config is not None:
308+
# Prefer explicit "<0xNN>" byte-fallback tokens if both markers appear.
309+
if StopStringCriteria._decoder_has_type(decoder_config, "ByteFallback"):
310+
return "byte_fallback"
311+
if StopStringCriteria._decoder_has_type(decoder_config, "ByteLevel"):
312+
return "byte_level"
313+
return None
314+
315+
@staticmethod
316+
def _decoder_has_type(decoder_config, decoder_type):
317+
if isinstance(decoder_config, dict):
318+
if decoder_config.get("type") == decoder_type:
319+
return True
320+
return any(StopStringCriteria._decoder_has_type(value, decoder_type) for value in decoder_config.values())
321+
if isinstance(decoder_config, list):
322+
return any(StopStringCriteria._decoder_has_type(value, decoder_type) for value in decoder_config)
323+
return False
324+
325+
@staticmethod
326+
def _get_stop_strings_for_matching(stop_strings, matching_mode):
327+
if matching_mode is None:
328+
return stop_strings
329+
return tuple(stop_string.encode("utf-8") for stop_string in stop_strings)
330+
331+
@staticmethod
332+
def _byte_level_decoder():
333+
from ..convert_slow_tokenizer import bytes_to_unicode
334+
335+
return {unicode_char: byte for byte, unicode_char in bytes_to_unicode().items()}
336+
337+
@staticmethod
338+
def _token_to_bytes(token, stop_string_matching_mode, byte_decoder):
339+
if stop_string_matching_mode == "byte_level":
340+
if byte_decoder is not None and all(char in byte_decoder for char in token):
341+
return bytes(byte_decoder[char] for char in token)
342+
return None
343+
if stop_string_matching_mode == "byte_fallback":
344+
if (
345+
len(token) == 6
346+
and token.startswith("<0x")
347+
and token.endswith(">")
348+
and all(char in "0123456789abcdefABCDEF" for char in token[3:5])
349+
):
350+
return bytes([int(token[3:5], 16)])
351+
return None
352+
353+
@staticmethod
354+
def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef", stop_string_matching_mode=None):
277355
"""
278356
This method turns a tokenizer vocab into a "clean" vocab where each token represents the actual string
279357
it will yield, without any special prefixes like "##" or "Ġ". This is trickier than it looks - the method
280358
tokenizer.convert_tokens_to_string() does not always return the correct string because of issues with prefix
281359
space addition/removal. To work around this, we add a static prefix to the start of the token, then remove
282-
it (and any prefix that may have been introduced with it) after calling convert_tokens_to_string().
360+
it (and any prefix that may have been introduced with it) after calling convert_tokens_to_string(). For
361+
byte-level vocabularies, incomplete UTF-8 fragments are kept as bytes until the stop string match is computed.
283362
"""
284363
vocab = tokenizer.get_vocab()
285364
clean_token_list = []
286365
clean_token_indices = []
366+
byte_decoder = StopStringCriteria._byte_level_decoder() if stop_string_matching_mode == "byte_level" else None
287367
sentence_base = tokenizer(static_prefix, add_special_tokens=False)["input_ids"]
288368
tokens_base = [tokenizer._convert_id_to_token(tok) for tok in sentence_base]
289369
for token, token_idx in vocab.items():
290-
token_string = tokenizer.convert_tokens_to_string(tokens_base + [token])
291-
token_string = token_string[token_string.index(static_prefix) + len(static_prefix) :]
370+
token_string = StopStringCriteria._token_to_bytes(token, stop_string_matching_mode, byte_decoder)
371+
if token_string is None:
372+
token_string = tokenizer.convert_tokens_to_string(tokens_base + [token])
373+
token_string = token_string[token_string.index(static_prefix) + len(static_prefix) :]
374+
if stop_string_matching_mode is not None:
375+
token_string = token_string.encode("utf-8")
292376
clean_token_list.append(token_string)
293377
clean_token_indices.append(token_idx)
294378
return tuple(clean_token_list), tuple(clean_token_indices)
295379

296380
@staticmethod
297381
def _stop_string_get_matching_positions(
298382
token_list, token_indices, stop_strings
299-
) -> tuple[dict[str, dict[str, list[int]]], dict[str, dict[str, list[int]]]]:
383+
) -> tuple[dict[str | bytes, dict[str, list[int]]], dict[str | bytes, dict[str, list[int]]]]:
300384
"""This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can
301385
validly appear in the stop strings. For each token, it computes a list of positions in the stop string where the
302386
token appears, as well as a list of the possible "end overlaps" for that token - that is, the number of characters

tests/generation/test_stopping_criteria.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
StopStringCriteria,
3434
validate_stopping_criteria,
3535
)
36+
from transformers.generation.stopping_criteria import STOP_STRING_EMBEDDING_CACHE
3637

3738

3839
@require_torch
@@ -45,6 +46,14 @@ def _get_tensors(self, length):
4546
scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length
4647
return input_ids, scores
4748

49+
def _assert_isolated_token_decode_loses_stop_string(self, tokenizer, text, stop_string):
50+
input_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
51+
tokens = tokenizer.convert_ids_to_tokens(input_ids)
52+
isolated_text = "".join(tokenizer.convert_tokens_to_string([token]) for token in tokens)
53+
54+
self.assertTrue(tokenizer.decode(input_ids, skip_special_tokens=False).endswith(stop_string))
55+
self.assertNotIn(stop_string, isolated_text)
56+
4857
def test_list_criteria(self):
4958
input_ids, scores = self._get_tensors(5)
5059

@@ -175,6 +184,118 @@ def test_stop_string_criteria(self):
175184
for i in range(len(false_strings)):
176185
self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores))
177186

187+
def test_stop_string_criteria_byte_fragments(self):
188+
STOP_STRING_EMBEDDING_CACHE.clear()
189+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
190+
self.assertEqual(StopStringCriteria._get_stop_string_matching_mode(tokenizer), "byte_level")
191+
self._assert_isolated_token_decode_loses_stop_string(tokenizer, "대화 끝", "끝")
192+
self._assert_isolated_token_decode_loses_stop_string(tokenizer, "작업 완료", "완료")
193+
194+
cases = [
195+
("대화 끝", "끝", True),
196+
("작업 완료", "완료", True),
197+
("대화 끝 다음", "끝", False),
198+
]
199+
200+
for text, stop_string, expected in cases:
201+
input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"]
202+
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=[stop_string])
203+
self.assertEqual(bool(criteria(input_ids, scores=None)[0]), expected)
204+
205+
def test_stop_string_criteria_byte_fallback_fragments(self):
206+
STOP_STRING_EMBEDDING_CACHE.clear()
207+
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer", use_fast=True)
208+
self.assertEqual(StopStringCriteria._get_stop_string_matching_mode(tokenizer), "byte_fallback")
209+
self._assert_isolated_token_decode_loses_stop_string(tokenizer, "대화 끝", "끝")
210+
self._assert_isolated_token_decode_loses_stop_string(tokenizer, "abc 끝!", "끝!")
211+
212+
cases = [
213+
("대화 끝", "끝", True),
214+
("abc 끝!", "끝!", True),
215+
("대화 끝 다음", "끝", False),
216+
("완료 후속", "완료", False),
217+
]
218+
219+
for text, stop_string, expected in cases:
220+
input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"]
221+
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=[stop_string])
222+
self.assertEqual(bool(criteria(input_ids, scores=None)[0]), expected)
223+
224+
def test_stop_string_criteria_byte_fragment_compile(self):
225+
if not hasattr(torch, "compile"):
226+
self.skipTest("torch.compile is not available")
227+
228+
STOP_STRING_EMBEDDING_CACHE.clear()
229+
cases = [
230+
("Qwen/Qwen2-0.5B-Instruct", "대화 끝", "끝"),
231+
("hf-internal-testing/llama-tokenizer", "abc 끝!", "끝!"),
232+
]
233+
for tokenizer_name, text, stop_string in cases:
234+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
235+
input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"]
236+
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=[stop_string])
237+
compiled_criteria = torch.compile(criteria, backend="eager", fullgraph=True)
238+
self.assertTrue(bool(compiled_criteria(input_ids, scores=None)[0]))
239+
240+
def test_stop_string_criteria_byte_level_ascii(self):
241+
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
242+
self.assertEqual(StopStringCriteria._get_stop_string_matching_mode(tokenizer), "byte_level")
243+
244+
true_input_ids = tokenizer("the end", return_tensors="pt", add_special_tokens=False)["input_ids"]
245+
false_input_ids = tokenizer("end of", return_tensors="pt", add_special_tokens=False)["input_ids"]
246+
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["end"])
247+
self.assertTrue(bool(criteria(true_input_ids, scores=None)[0]))
248+
self.assertFalse(bool(criteria(false_input_ids, scores=None)[0]))
249+
250+
def test_stop_string_criteria_non_byte_level_tokenizer(self):
251+
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
252+
self.assertIsNone(StopStringCriteria._get_stop_string_matching_mode(tokenizer))
253+
254+
true_input_ids = tokenizer("the end", return_tensors="pt", add_special_tokens=False)["input_ids"]
255+
false_input_ids = tokenizer("end of", return_tensors="pt", add_special_tokens=False)["input_ids"]
256+
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["end"])
257+
self.assertTrue(bool(criteria(true_input_ids, scores=None)[0]))
258+
self.assertFalse(bool(criteria(false_input_ids, scores=None)[0]))
259+
260+
def test_stop_string_matching_mode_helpers(self):
261+
class Decoder:
262+
def __init__(self, state):
263+
self.state = state
264+
265+
def __getstate__(self):
266+
return self.state
267+
268+
class BackendTokenizer:
269+
def __init__(self, decoder):
270+
self.decoder = decoder
271+
272+
class Tokenizer:
273+
def __init__(self, decoder):
274+
self.backend_tokenizer = BackendTokenizer(decoder)
275+
276+
self.assertEqual(
277+
StopStringCriteria._get_stop_string_matching_mode(
278+
Tokenizer(Decoder(b'{"type":"Sequence","decoders":[{"type":"ByteLevel"}]}'))
279+
),
280+
"byte_level",
281+
)
282+
self.assertEqual(
283+
StopStringCriteria._get_stop_string_matching_mode(
284+
Tokenizer(Decoder(b'{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"ByteLevel"}]}'))
285+
),
286+
"byte_fallback",
287+
)
288+
self.assertIsNone(
289+
StopStringCriteria._get_stop_string_matching_mode(
290+
Tokenizer(Decoder(b'{"type":"Replace","content":"ByteFallback"}'))
291+
)
292+
)
293+
294+
self.assertEqual(StopStringCriteria._token_to_bytes("<0xEB>", "byte_fallback", None), b"\xeb")
295+
self.assertEqual(StopStringCriteria._token_to_bytes("<0xeb>", "byte_fallback", None), b"\xeb")
296+
for token in ["<0x+1>", "<0xG1>", "<0x 1>", "<0x1>", "<0x100>", "<0xeb", "hello"]:
297+
self.assertIsNone(StopStringCriteria._token_to_bytes(token, "byte_fallback", None))
298+
178299
def test_stop_string_criteria_vocab_size_mismatch(self):
179300
"""Test that StopStringCriteria handles tokens above len(tokenizer) correctly."""
180301
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

0 commit comments

Comments
 (0)