@@ -1222,10 +1222,16 @@ def _prepare_bias_variables(self, scores: torch.FloatTensor):
         # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied
         # with simpler logic.
         self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float, device=scores.device)
+        # Extract single-token sequences and their biases
+        single_token_ids = []
+        single_token_biases = []
         for sequence_ids, bias in self.sequence_bias.items():
             if len(sequence_ids) == 1:
-                self.length_1_bias[sequence_ids[-1]] = bias
+                single_token_ids.append(sequence_ids[0])
+                single_token_biases.append(bias)

+        if single_token_ids:  # Only if we have any single-token sequences
+            self.length_1_bias[single_token_ids] = torch.tensor(single_token_biases, device=scores.device)
         self.prepared_bias_variables = True

     def _validate_arguments(self):
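The added lines above replace the per-token scalar writes inside the loop with a single index-put over all single-token sequences. A minimal standalone sketch of that pattern, outside the diff (the vocabulary size and `sequence_bias` contents below are illustrative values, not the library's actual state):

import torch

vocab_size = 10
sequence_bias = {(3,): 2.0, (7,): -1.5, (2, 5): 4.0}  # keys are tuples of token ids

length_1_bias = torch.zeros((vocab_size,), dtype=torch.float)
single_token_ids = [ids[0] for ids, bias in sequence_bias.items() if len(ids) == 1]
single_token_biases = [bias for ids, bias in sequence_bias.items() if len(ids) == 1]

if single_token_ids:
    # A list of indices on the left-hand side writes every bias in one call
    length_1_bias[single_token_ids] = torch.tensor(single_token_biases)

print(length_1_bias)  # 2.0 at index 3, -1.5 at index 7, zeros elsewhere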
@@ -1340,10 +1346,10 @@ def __init__(
             eos_token_id = [eos_token_id]
         eos_token_id = torch.tensor(eos_token_id)

+        eos_token_id_list = eos_token_id.tolist()  # convert to python list before
         bad_words_ids = list(
-            filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
+            filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id_list), bad_words_ids)
         )
-
         # Forbidding a sequence is equivalent to setting its bias to -inf
         sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids}
         super().__init__(sequence_bias=sequence_bias)
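In the second hunk, the filter drops any bad-words sequence that consists of exactly one EOS token. Iterating over the raw 1-D tensor yields 0-dim tensors, so `[i]` would be a list holding a tensor rather than a list holding an int; converting to a plain Python list first keeps the comparison between lists of ints. A small sketch with example values (not taken from the library):

import torch

bad_words_ids = [[42], [7, 9], [3]]   # example token-id sequences to forbid
eos_token_id = torch.tensor([42, 3])  # example EOS ids

eos_token_id_list = eos_token_id.tolist()  # plain Python ints

kept = list(
    filter(lambda seq: all(seq != [i] for i in eos_token_id_list), bad_words_ids)
)
print(kept)  # [[7, 9]] -- the single-EOS sequences are removed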