Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 55 additions & 14 deletions src/transformers/generation/stopping_criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,29 +241,37 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: str | list[
if isinstance(stop_strings, str):
stop_strings = [stop_strings]
self.stop_strings: tuple[str, ...] = tuple(stop_strings)
self._byte_level = self._is_byte_level_tokenizer(tokenizer)
# ByteLevel tokens may contain only part of a UTF-8 character, so match them before Unicode decoding.
self._matching_stop_strings = (
tuple(stop_string.encode("utf-8") for stop_string in stop_strings)
if self._byte_level
else self.stop_strings
)
vocab = tokenizer.get_vocab()
token_list, token_indices = tuple(vocab.keys()), tuple(vocab.values())
self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = self.clean_and_embed_tokens_with_cache(
token_list, token_indices, tokenizer
)

self.maximum_token_len = max(len(stop_string) for stop_string in self.stop_strings)
self.maximum_token_len = max(len(stop_string) for stop_string in self._matching_stop_strings)
self.num_stop_strings = len(self.stop_strings)
self.target_lens = torch.tensor([len(stop_string) for stop_string in stop_strings], dtype=torch.int32)
self.target_lens = torch.tensor(
[len(stop_string) for stop_string in self._matching_stop_strings], dtype=torch.int32
)

def clean_and_embed_tokens_with_cache(self, token_list, token_indices, tokenizer):
# We don't use the tokenizer in the cache key, because I don't trust it to have well-behaved equality
if (token_list, token_indices, self.stop_strings) in STOP_STRING_EMBEDDING_CACHE:
embedding_vec, max_valid_positions, max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE[
(token_list, token_indices, self.stop_strings)
]
STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, self.stop_strings))
cache_key = (token_list, token_indices, self._matching_stop_strings)
if cache_key in STOP_STRING_EMBEDDING_CACHE:
embedding_vec, max_valid_positions, max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE[cache_key]
STOP_STRING_EMBEDDING_CACHE.move_to_end(cache_key)
else:
clean_token_list, clean_token_indices = self.clean_tokenizer_vocab(tokenizer)
clean_token_list, clean_token_indices = self.clean_tokenizer_vocab(tokenizer, byte_level=self._byte_level)
embedding_vec, max_valid_positions, max_valid_end_lens = self._stop_string_create_embedding_vec(
clean_token_list, clean_token_indices, self.stop_strings
clean_token_list, clean_token_indices, self._matching_stop_strings
)
STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, self.stop_strings)] = (
STOP_STRING_EMBEDDING_CACHE[cache_key] = (
embedding_vec,
max_valid_positions,
max_valid_end_lens,
Expand All @@ -273,22 +281,55 @@ def clean_and_embed_tokens_with_cache(self, token_list, token_indices, tokenizer
return embedding_vec, max_valid_positions, max_valid_end_lens

@staticmethod
def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef"):
def _is_byte_level_tokenizer(tokenizer):
    """
    Return whether `tokenizer` decodes its tokens with a ByteLevel decoder.

    The decoder is read from `tokenizer.backend_tokenizer.decoder`. It counts as ByteLevel when its class is
    named "ByteLevel", or when its serialized state (the bytes returned by `__getstate__`) mentions
    `"ByteLevel"`, which also covers a ByteLevel decoder nested inside another decoder.

    Args:
        tokenizer: Any tokenizer-like object. Objects without a `backend_tokenizer` or without a decoder are
            reported as not ByteLevel.

    Returns:
        `bool`: `True` if a ByteLevel decoder was detected, `False` otherwise.
    """
    decoder = getattr(getattr(tokenizer, "backend_tokenizer", None), "decoder", None)
    if decoder is None:
        return False
    if decoder.__class__.__name__ == "ByteLevel":
        return True
    get_state = getattr(decoder, "__getstate__", None)
    if get_state is None:
        return False
    try:
        decoder_state = get_state()
    except Exception:
        # This is only a feature probe: a decoder whose state cannot be produced (for example one that
        # refuses serialization) must not break stop-string setup, so it is treated as not ByteLevel.
        return False
    return isinstance(decoder_state, bytes) and b'"ByteLevel"' in decoder_state

@staticmethod
def _byte_level_decoder():
    """
    Build the lookup table that maps each ByteLevel vocabulary character back to the byte it stands for.

    Returns:
        `dict[str, int]`: 256 entries, one per byte value. Bytes in the ranges "!".."~", "¡".."¬" and
        "®".."ÿ" are keyed by the character with the same code point; every other byte is keyed by a
        character taken from code point 256 upwards, assigned in increasing byte order.
    """
    base = 2**8
    # Bytes in these three ranges are represented by the character sharing their code point.
    direct = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    table = {chr(value): value for value in direct}
    # The remaining bytes get the next unused code point starting at 256, in increasing byte order.
    leftover = [value for value in range(base) if value not in direct]
    table.update((chr(base + offset), value) for offset, value in enumerate(leftover))
    return table

@staticmethod
def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef", byte_level=False):
    """
    This method turns a tokenizer vocab into a "clean" vocab where each token represents the actual string
    it will yield, without any special prefixes like "##" or "Ġ". This is trickier than it looks - the method
    tokenizer.convert_tokens_to_string() does not always return the correct string because of issues with prefix
    space addition/removal. To work around this, we add a static prefix to the start of the token, then remove
    it (and any prefix that may have been introduced with it) after calling convert_tokens_to_string(). ByteLevel
    vocabularies are converted directly to bytes so incomplete UTF-8 fragments remain matchable.

    Args:
        tokenizer: The tokenizer whose vocabulary (`get_vocab()`) is cleaned.
        static_prefix (`str`, defaults to `"abcdef"`): Text placed in front of each token before decoding
            and stripped from the decoded result afterwards.
        byte_level (`bool`, defaults to `False`): When `True`, every cleaned token is returned as `bytes`.
            Tokens made only of ByteLevel alphabet characters are mapped byte-by-byte; any other token is
            decoded through the tokenizer and then UTF-8 encoded.

    Returns:
        `tuple[tuple, tuple]`: The cleaned tokens (`str`, or `bytes` when `byte_level=True`) and the
        matching token indices, both in vocabulary iteration order.
    """
    vocab = tokenizer.get_vocab()
    clean_token_list = []
    clean_token_indices = []
    byte_decoder = StopStringCriteria._byte_level_decoder() if byte_level else None
    sentence_base = tokenizer(static_prefix, add_special_tokens=False)["input_ids"]
    tokens_base = [tokenizer._convert_id_to_token(tok) for tok in sentence_base]
    for token, token_idx in vocab.items():
        if byte_decoder is not None and all(char in byte_decoder for char in token):
            # Pure ByteLevel token: translate it without going through Unicode decoding, so a token that
            # holds only part of a UTF-8 character keeps its exact bytes.
            token_string = bytes(byte_decoder[char] for char in token)
        else:
            token_string = tokenizer.convert_tokens_to_string(tokens_base + [token])
            token_string = token_string[token_string.index(static_prefix) + len(static_prefix) :]
            if byte_decoder is not None:
                # Keep a single element type (bytes) across the whole cleaned vocab.
                token_string = token_string.encode("utf-8")
        clean_token_list.append(token_string)
        clean_token_indices.append(token_idx)
    return tuple(clean_token_list), tuple(clean_token_indices)
Expand Down
16 changes: 16 additions & 0 deletions tests/generation/test_stopping_criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,22 @@ def test_stop_string_criteria(self):
for i in range(len(false_strings)):
self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores))

def test_stop_string_criteria_byte_fragments(self):
    """
    Check stop-string matching on multi-byte (Korean and Chinese) text.

    NOTE(review): presumably the Qwen2 tokenizer splits some of these characters across several
    byte-fragment tokens, which is what the test name refers to - confirm against the tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    # (text, stop_string, expected): a match is expected only when the text ends with the stop string.
    cases = [
        ("대화 끝", "끝", True),
        ("작업 완료", "완료", True),
        ("응답 종료", "종료", True),
        ("结束", "结束", True),
        ("끝 대화", "끝", False),
        ("완료 후속", "완료", False),
    ]

    for text, stop_string, expected in cases:
        # subTest reports every failing case with its inputs instead of stopping at the first one.
        with self.subTest(text=text, stop_string=stop_string):
            input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"]
            criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=[stop_string])
            self.assertEqual(bool(criteria(input_ids, scores=None)[0]), expected)

def test_stop_string_criteria_vocab_size_mismatch(self):
"""Test that StopStringCriteria handles tokens above len(tokenizer) correctly."""
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
Expand Down