@@ -111,12 +111,9 @@ class LogitBiasLogitsProcessor(LogitsProcessor):
     def __init__(self, logit_bias: Dict[str, float]) -> None:
         super().__init__()
         self.logit_bias = logit_bias
-        self.tokens_to_adjust = {}
-        try:
-            self.tokens_to_adjust = self.process_logit_bias(logit_bias)
-        except ValueError as e:
-            logger.error(e)
-            raise
+        self.tokens_to_adjust = self.process_logit_bias(logit_bias)
+        if not self.tokens_to_adjust:
+            raise ValueError("Empty logit_bias provided - no tokens to adjust")

     def process_logit_bias(self, logit_bias: Dict[str, float]) -> Dict[int, float]:
         valid = {}
@@ -140,25 +137,24 @@ def __call__(self, req_id: int, logits: torch.Tensor,
                  token_ids: List[List[int]], stream_ptr: Optional[int],
                  client_id: Optional[int]) -> None:

-        if self.tokens_to_adjust:
-            vocab_size = logits.size(-1)
-            token_ids_list = list(self.tokens_to_adjust.keys())
-            bias_values = torch.tensor(
-                list(self.tokens_to_adjust.values())
-            )
-
-            invalid_token_ids = [tid for tid in token_ids_list if tid >= vocab_size]
-            if invalid_token_ids:
-                raise ValueError(
-                    f"Token ID(s) {invalid_token_ids} exceed vocabulary size (vocab_size={vocab_size})"
-                )
+        vocab_size = logits.size(-1)
+        token_ids_list = list(self.tokens_to_adjust.keys())
+        bias_values = torch.tensor(
+            list(self.tokens_to_adjust.values())
+        )

-            stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr)
-            with torch.cuda.stream(stream):
-                logits[:, :, token_ids_list] += bias_values
+        invalid_token_ids = [tid for tid in token_ids_list if tid >= vocab_size]
+        if invalid_token_ids:
+            raise ValueError(
+                f"Token ID(s) {invalid_token_ids} exceed vocabulary size (vocab_size={vocab_size})"
+            )
+
+        stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr)
+        with torch.cuda.stream(stream):
+            logits[:, :, token_ids_list] += bias_values

-            if stream is not None:
-                stream.synchronize()
+        if stream is not None:
+            stream.synchronize()

 @dataclass(slots=True, kw_only=True)
 class AdditionalModelOutput:
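
For reference, a minimal, self-contained sketch of the bias-application step that __call__ performs, run on a CPU tensor outside the engine. The dummy shapes and the tokens_to_adjust mapping are assumptions for illustration; in the real processor, logits and stream_ptr come from the runtime and the addition runs on the external CUDA stream shown in the diff.

import torch

# Hypothetical token-id -> bias mapping, as process_logit_bias() would produce
tokens_to_adjust = {42: 5.0, 7: -100.0}

# Dummy logits with the shape the processor receives: (batch, beams, vocab_size)
logits = torch.zeros(1, 1, 1024)

vocab_size = logits.size(-1)
token_ids_list = list(tokens_to_adjust.keys())
bias_values = torch.tensor(list(tokens_to_adjust.values()))

# Same validation as in __call__: every biased token must exist in the vocabulary
invalid_token_ids = [tid for tid in token_ids_list if tid >= vocab_size]
if invalid_token_ids:
    raise ValueError(
        f"Token ID(s) {invalid_token_ids} exceed vocabulary size (vocab_size={vocab_size})")

# In-place additive bias over the vocab dimension; on GPU this would run under
# torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)) as in the diff above
logits[:, :, token_ids_list] += bias_values

print(logits[0, 0, 42].item(), logits[0, 0, 7].item())  # 5.0 -100.0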