Skip to content

Commit 533f5b6

Browse files
committed
Update sampling_params.py
Signed-off-by: xq25478 <[email protected]>
1 parent 180a38a commit 533f5b6

File tree

1 file changed

+19
-23
lines changed

1 file changed

+19
-23
lines changed

tensorrt_llm/sampling_params.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,9 @@ class LogitBiasLogitsProcessor(LogitsProcessor):
111111
def __init__(self, logit_bias: Dict[str, float]) -> None:
    """Create a processor that adds fixed bias values to selected logits.

    Args:
        logit_bias: Mapping from token id (string-keyed) to the bias value
            to add to that token's logit.

    Raises:
        ValueError: If ``process_logit_bias`` yields no adjustable tokens
            (e.g. an empty ``logit_bias`` was supplied).
    """
    super().__init__()
    self.logit_bias = logit_bias
    self.tokens_to_adjust = self.process_logit_bias(logit_bias)
    # Fail fast at construction time rather than silently becoming a no-op.
    if len(self.tokens_to_adjust) == 0:
        raise ValueError("Empty logit_bias provided - no tokens to adjust")
120117

121118
def process_logit_bias(self,logit_bias: Dict[str, float]) -> Dict[int, float]:
122119
valid = {}
def __call__(self, req_id: int, logits: torch.Tensor,
             token_ids: List[List[int]], stream_ptr: Optional[int],
             client_id: Optional[int]) -> None:
    """Apply the configured per-token biases to ``logits`` in place.

    Args:
        req_id: Request identifier (unused by this processor).
        logits: Logits tensor whose last dimension is the vocabulary;
            modified in place.
        token_ids: Token ids generated so far (unused by this processor).
        stream_ptr: Optional raw CUDA stream pointer; when given, the update
            runs on that stream and is synchronized before returning.
        client_id: Optional client identifier (unused by this processor).

    Raises:
        ValueError: If any biased token id is outside the vocabulary.
    """
    vocab_size = logits.size(-1)
    token_ids_list = list(self.tokens_to_adjust.keys())

    # Validate before allocating anything: all biased ids must be in-vocab.
    invalid_token_ids = [tid for tid in token_ids_list if tid >= vocab_size]
    if invalid_token_ids:
        raise ValueError(
            f"Token ID(s) {invalid_token_ids} exceed vocabulary size (vocab_size={vocab_size})"
        )

    # BUGFIX: build the bias tensor on logits' device and dtype. The original
    # torch.tensor(...) call produced a CPU float32 tensor, which fails with a
    # device-mismatch error against CUDA logits (the common TRT-LLM case) and
    # could promote half-precision logits during the in-place add.
    bias_values = torch.tensor(
        list(self.tokens_to_adjust.values()),
        device=logits.device,
        dtype=logits.dtype,
    )

    # With stream_ptr=None this is a no-op context; otherwise wrap the
    # externally-owned CUDA stream so the update is ordered on it.
    stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr)
    with torch.cuda.stream(stream):
        logits[:, :, token_ids_list] += bias_values

    # Only synchronize when we were handed an external stream to run on.
    if stream is not None:
        stream.synchronize()
162158

163159
@dataclass(slots=True, kw_only=True)
164160
class AdditionalModelOutput:

0 commit comments

Comments
 (0)