Skip to content

Commit 7196cd0

Browse files
authored
imp(logitbias) improve logitbias code
1 parent 147fb84 commit 7196cd0

File tree

1 file changed

+30
-8
lines changed

1 file changed

+30
-8
lines changed

tensorrt_llm/sampling_params.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
from abc import ABC, abstractmethod
44
from dataclasses import dataclass, field, fields
5-
from typing import List, NamedTuple, Optional, Tuple, Union,Dict
5+
from typing import List, NamedTuple, Optional, Tuple, Union, Dict
66

77
import torch
88
from pydantic import BaseModel
def __init__(self, logit_bias: Dict[str, float]) -> None:
    """Store the raw logit-bias mapping and pre-convert its keys to ints.

    Args:
        logit_bias: Mapping from token id (given as a string or other
            int-convertible key) to the bias to add to that token's logit.

    Raises:
        ValueError: propagated (after logging) when ``logit_bias`` contains
            keys that cannot be converted to integer token ids.
    """
    super().__init__()
    self.logit_bias = logit_bias
    # Keep the attribute defined even if conversion below raises.
    self.tokens_to_adjust: Dict[int, float] = {}
    try:
        self.tokens_to_adjust = self.process_logit_bias(logit_bias)
    except ValueError as err:
        logger.error(err)
        raise
def process_logit_bias(self, logit_bias: Dict[str, float]) -> Dict[int, float]:
    """Convert the keys of ``logit_bias`` to integer token ids.

    Args:
        logit_bias: Mapping whose keys should be int-convertible token ids
            and whose values are the bias floats to apply.

    Returns:
        A dict mapping ``int`` token ids to their bias values.

    Raises:
        ValueError: if any key cannot be converted to an ``int``; the
            message lists every offending key.
    """

    def _as_token_id(key):
        # int() raises TypeError for e.g. None, ValueError for "abc".
        try:
            return int(key)
        except (ValueError, TypeError):
            return None

    converted: Dict[int, float] = {}
    bad_keys = []
    for key, bias in logit_bias.items():
        token_id = _as_token_id(key)
        if token_id is None:
            bad_keys.append(key)
        else:
            converted[token_id] = bias

    # Report all invalid keys at once rather than failing on the first.
    if bad_keys:
        raise ValueError(
            f"Invalid token_ids in logit_bias: {bad_keys}. "
            f"All keys must be integers."
        )
    return converted
138+
122139
def __call__(self, req_id: int, logits: torch.Tensor,
             token_ids: List[List[int]], stream_ptr: Optional[int],
             client_id: Optional[int]) -> None:
    """Add the pre-computed biases to ``logits`` in place.

    Args:
        req_id: Request id (unused here).
        logits: Logits tensor; the indexing below assumes it is 3-D with
            the vocabulary on the last axis — presumably
            (batch, beams, vocab_size); TODO confirm against caller.
        token_ids: Generated token ids so far (unused here).
        stream_ptr: Optional raw CUDA stream pointer to run the update on.
        client_id: Client id (unused here).

    Raises:
        ValueError: if any biased token id is >= vocab_size.
    """
    if not self.tokens_to_adjust:
        return

    vocab_size = logits.size(-1)
    token_ids_list = list(self.tokens_to_adjust.keys())

    # Validate before allocating the bias tensor.
    # NOTE(review): negative ids are not rejected and would index from the
    # end of the vocab axis — confirm that is intended.
    invalid_token_ids = [tid for tid in token_ids_list if tid >= vocab_size]
    if invalid_token_ids:
        raise ValueError(
            f"Token ID(s) {invalid_token_ids} exceed vocabulary size (vocab_size={vocab_size})"
        )

    # Bug fix: build the bias tensor on the same device/dtype as `logits`.
    # A bare torch.tensor(...) is CPU float32, which fails (device mismatch)
    # or silently upcasts when `logits` is a GPU / half-precision tensor.
    bias_values = torch.tensor(
        list(self.tokens_to_adjust.values()),
        device=logits.device,
        dtype=logits.dtype,
    )

    stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr)
    with torch.cuda.stream(stream):
        logits[:, :, token_ids_list] += bias_values

0 commit comments

Comments
 (0)