
Commit abaa043

bad_words_ids no longer slow on mps (#39556)
* fix: bad_words_ids no longer slow on mps
* fix: SequenceBiasLogitsProcessor slow `_prepare_bias_variables` method
* fix: re-adding a deleted comment
* fix: bug in no_bad_words_logits
* Apply style fixes

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 6630c5b commit abaa043
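
For context on the fix (a sketch, not part of the commit): each scalar write like `t[i] = v` on an accelerator backend such as "mps" dispatches its own kernel, so looping over thousands of bad words issues thousands of tiny device operations. The hypothetical micro-benchmark below contrasts that pattern with a single vectorized write; `vocab_size`, `ids`, and `biases` are made-up values for illustration.

```python
# Hypothetical micro-benchmark (not from the commit): per-token scalar writes
# vs. one fancy-index assignment. Falls back to CPU when MPS is unavailable.
import time

import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
vocab_size = 50_000                               # made-up vocabulary size
ids = list(range(0, vocab_size, 7))               # ~7k single-token entries
biases = [float(i % 5) for i in ids]

t = torch.zeros(vocab_size, device=device)
start = time.perf_counter()
for i, b in zip(ids, biases):
    t[i] = b                                      # one tiny device op per token
print(f"scalar writes: {time.perf_counter() - start:.3f}s")

t = torch.zeros(vocab_size, device=device)
start = time.perf_counter()
t[ids] = torch.tensor(biases, device=device)      # single vectorized write
print(f"vectorized:    {time.perf_counter() - start:.3f}s")
```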

1 file changed: src/transformers/generation/logits_process.py (9 additions, 3 deletions)
```diff
@@ -1222,10 +1222,16 @@ def _prepare_bias_variables(self, scores: torch.FloatTensor):
         # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied
         # with simpler logic.
         self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float, device=scores.device)
+        # Extract single-token sequences and their biases
+        single_token_ids = []
+        single_token_biases = []
         for sequence_ids, bias in self.sequence_bias.items():
             if len(sequence_ids) == 1:
-                self.length_1_bias[sequence_ids[-1]] = bias
+                single_token_ids.append(sequence_ids[0])
+                single_token_biases.append(bias)
 
+        if single_token_ids:  # Only if we have any single-token sequences
+            self.length_1_bias[single_token_ids] = torch.tensor(single_token_biases, device=scores.device)
         self.prepared_bias_variables = True
 
     def _validate_arguments(self):
```
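
As a standalone sketch of the hunk above (with a made-up `sequence_bias` dict and vocabulary size): the loop now only collects Python ints and floats, and the device receives one indexed assignment instead of one write per single-token sequence. Both paths build the same tensor:

```python
# Sketch with assumed inputs; `sequence_bias` mirrors the processor attribute.
import torch

sequence_bias = {(42,): -1.5, (7, 9): -2.0, (100,): 0.5}  # made-up entries
vocabulary_size = 128
device = "cpu"  # same logic applies on "mps"

# Old behaviour: one device write per single-token sequence.
old = torch.zeros((vocabulary_size,), dtype=torch.float, device=device)
for sequence_ids, bias in sequence_bias.items():
    if len(sequence_ids) == 1:
        old[sequence_ids[-1]] = bias

# New behaviour: gather ids/biases in Python, then write once.
single_token_ids, single_token_biases = [], []
for sequence_ids, bias in sequence_bias.items():
    if len(sequence_ids) == 1:
        single_token_ids.append(sequence_ids[0])
        single_token_biases.append(bias)

new = torch.zeros((vocabulary_size,), dtype=torch.float, device=device)
if single_token_ids:
    new[single_token_ids] = torch.tensor(single_token_biases, device=device)

assert torch.equal(old, new)
```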
```diff
@@ -1340,10 +1346,10 @@ def __init__(
             eos_token_id = [eos_token_id]
         eos_token_id = torch.tensor(eos_token_id)
 
+        eos_token_id_list = eos_token_id.tolist()  # convert to python list before
         bad_words_ids = list(
-            filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
+            filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id_list), bad_words_ids)
         )
-
         # Forbidding a sequence is equivalent to setting its bias to -inf
         sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids}
         super().__init__(sequence_bias=sequence_bias)
```
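
A sketch of this hunk with assumed token ids: iterating a tensor yields 0-dim tensors, so each `bad_token_seq != [i]` in the old code triggered tensor comparisons inside the lambda; after `.tolist()` the filter compares plain Python lists only.

```python
# Sketch with made-up ids showing the eos filtering on plain Python lists.
import torch

eos_token_id = torch.tensor([2, 50256])           # assumed eos ids
bad_words_ids = [[2], [50256], [13, 7], [99]]     # assumed bad-word sequences

eos_token_id_list = eos_token_id.tolist()         # ints, not 0-dim tensors
bad_words_ids = list(
    filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id_list), bad_words_ids)
)
assert bad_words_ids == [[13, 7], [99]]           # single-token eos entries dropped
```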
