
Commit 83b316a

xq25478, venkywonka, hexiao.xq, and LinPoly authored and committed
[feat]: support logit_bias (NVIDIA#5354)
Signed-off-by: xq25478 <[email protected]>
Signed-off-by: Venky Ganesh <[email protected]>
Signed-off-by: hexiao.xq <[email protected]>
Co-authored-by: Venky Ganesh <[email protected]>
Co-authored-by: hexiao.xq <[email protected]>
Co-authored-by: Pengyun Lin <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>
1 parent f63ece4 commit 83b316a
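
What the feature enables, as a hedged client-side sketch (not part of this commit): per-token additive biases passed through the OpenAI-compatible chat endpoint. The server URL and model name below are placeholders.

import openai

# Assumption: a trtllm-serve (or compatible) endpoint at this placeholder URL.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

resp = client.chat.completions.create(
    model="my-model",  # placeholder model name
    messages=[{"role": "user", "content": "Tell me a fact about Paris"}],
    logit_bias={"1000": 2.0, "2000": -2.0},  # token id (as string) -> additive bias
    max_tokens=20,
)
print(resp.choices[0].message.content)

Keys follow the OpenAI convention of stringified token ids; positive values push the sampler toward a token, negative values away from it.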

File tree

5 files changed: 132 additions, 9 deletions


tensorrt_llm/sampling_params.py

Lines changed: 50 additions & 1 deletion
@@ -2,7 +2,7 @@
 import os
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field, fields
-from typing import List, NamedTuple, Optional, Tuple, Union
+from typing import Dict, List, NamedTuple, Optional, Tuple, Union

 import torch
 from pydantic import BaseModel
@@ -108,6 +108,55 @@ def __call__(
         pass  # noqa


+class LogitBiasLogitsProcessor(LogitsProcessor):
+    def __init__(self, logit_bias: Dict[str, float]) -> None:
+        super().__init__()
+        self.logit_bias = logit_bias
+        self.tokens_to_adjust = self.process_logit_bias(logit_bias)
+        if not self.tokens_to_adjust:
+            raise ValueError("Empty logit_bias provided - no tokens to adjust")
+
+    def process_logit_bias(self, logit_bias: Dict[str, float]) -> Dict[int, float]:
+        valid = {}
+        invalid = {}
+
+        for k, v in logit_bias.items():
+            try:
+                token_id = int(k)
+                valid[token_id] = v
+            except (ValueError, TypeError):
+                invalid[k] = v
+
+        if invalid:
+            raise ValueError(
+                f"Invalid token_ids in logit_bias: {list(invalid.keys())}. "
+                f"All keys must be integers."
+            )
+        return valid
+
+    def __call__(
+        self,
+        req_id: int,
+        logits: torch.Tensor,
+        token_ids: List[List[int]],
+        stream_ptr: Optional[int],
+        client_id: Optional[int],
+    ) -> None:
+        vocab_size = logits.size(-1)
+        token_ids_list = list(self.tokens_to_adjust.keys())
+        bias_values = torch.tensor(list(self.tokens_to_adjust.values()), device=logits.device)
+
+        invalid_token_ids = [tid for tid in token_ids_list if tid >= vocab_size]
+        if invalid_token_ids:
+            raise ValueError(
+                f"Token ID(s) {invalid_token_ids} exceed vocabulary size (vocab_size={vocab_size})"
+            )
+
+        stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr)
+        with torch.cuda.stream(stream):
+            logits[:, :, token_ids_list] += bias_values
+
+
 @dataclass(slots=True, kw_only=True)
 class AdditionalModelOutput:
     """An additional output to gather from the model.

tensorrt_llm/serve/openai_protocol.py

Lines changed: 10 additions & 7 deletions
@@ -16,6 +16,8 @@
 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
 from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams

+from ..sampling_params import LogitBiasLogitsProcessor
+

 class OpenAIBaseModel(BaseModel):
     # OpenAI API does not allow extra fields & allow to initialize by both alias and field name
@@ -248,6 +250,10 @@ def to_sampling_params(self) -> SamplingParams:
                 self.response_format),
             detokenize=self.detokenize,

+            # logits_bias
+            logits_processor=None if not self.logit_bias else
+            LogitBiasLogitsProcessor(self.logit_bias),
+
             # completion-extra-params
             add_special_tokens=self.add_special_tokens,

@@ -539,6 +545,10 @@ def to_sampling_params(self) -> SamplingParams:
             guided_decoding=_response_format_to_guided_decoding_params(
                 self.response_format),

+            # logits_bias
+            logits_processor=None if not self.logit_bias else
+            LogitBiasLogitsProcessor(self.logit_bias),
+
             # chat-completion-extra-params
             add_special_tokens=self.add_special_tokens,

@@ -574,13 +584,6 @@ def check_logprobs(cls, data):
             raise ValueError("top_logprobs is not supported")
         return data

-    @model_validator(mode="before")
-    @classmethod
-    def verify_logit_processor(cls, data):
-        if data.get("logit_bias"):
-            raise ValueError("logit bias is not supported")
-        return data
-
     @model_validator(mode="before")
     @classmethod
     def check_suffix(cls, data):
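
Note the deleted verify_logit_processor validator: logit_bias was previously rejected outright with "logit bias is not supported". After this change, malformed input fails inside LogitBiasLogitsProcessor instead, which the server surfaces as a 400 (see the BadRequestError tests below). A small sketch of that error path, using only the class added in this commit:

from tensorrt_llm.sampling_params import LogitBiasLogitsProcessor

try:
    LogitBiasLogitsProcessor({"invalid_token": 1.0})  # non-integer key
except ValueError as err:
    print(err)  # Invalid token_ids in logit_bias: ['invalid_token']. ...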

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ l0_a10:
   - test_e2e.py::test_openai_misc_example[pytorch]
   - test_e2e.py::test_openai_reasoning[pytorch]
   - test_e2e.py::test_openai_completions_example[pytorch]
-  - test_e2e.py::test_openai_chat_example[pytorch]
+  - test_e2e.py::test_openai_chat_example[pytorch] TIMEOUT (90)
   - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-]
 - condition:
     ranges:

tests/unittest/llmapi/apps/_test_openai_chat.py

Lines changed: 38 additions & 0 deletions
@@ -521,3 +521,41 @@ def test_stop_reason(client: openai.OpenAI, model_name: str, backend: str):
     )
     assert resp.choices[0].finish_reason == "stop"
     assert resp.choices[0].stop_reason == "two"
+
+
+@pytest.mark.asyncio
+async def test_chat_completion_with_logit_bias(async_client: openai.AsyncOpenAI,
+                                               model_name: str):
+    """Test logit_bias in chat completions"""
+    logit_bias = {
+        "1000": 2.0,
+        "2000": -2.0,
+    }
+
+    chat_completion = await async_client.chat.completions.create(
+        model=model_name,
+        messages=[{
+            "role": "user",
+            "content": "Tell me a fact about Paris"
+        }],
+        max_tokens=20,
+        logit_bias=logit_bias,
+        temperature=0.0,
+    )
+    assert chat_completion.choices[0].message.content
+
+
+@pytest.mark.asyncio
+async def test_chat_completion_with_invalid_logit_bias(
+        async_client: openai.AsyncOpenAI, model_name: str):
+    """Test with invalid token IDs (non-integer keys)"""
+    with pytest.raises(openai.BadRequestError):
+        await async_client.chat.completions.create(
+            model=model_name,
+            messages=[{
+                "role": "user",
+                "content": "Tell me a fact about Paris"
+            }],
+            logit_bias={"invalid_token": 1.0},  # Non-integer key
+            max_tokens=5,
+        )

tests/unittest/llmapi/apps/_test_openai_completions.py

Lines changed: 33 additions & 0 deletions
@@ -368,3 +368,36 @@ async def test_completion_streaming(async_client: openai.AsyncOpenAI,
             tokens.extend(chunk.choices[0].token_ids)

     assert tokens == single_output
+
+
+@pytest.mark.asyncio
+async def test_completion_with_logit_bias(async_client: openai.AsyncOpenAI,
+                                          model_name: str):
+    """Test logit_bias with valid token IDs"""
+    logit_bias = {
+        "1000": 80,
+        "2000": -80,
+    }
+
+    completion = await async_client.completions.create(
+        model=model_name,
+        prompt="The capital of France is",
+        max_tokens=10,
+        logit_bias=logit_bias,
+        temperature=0.0,
+    )
+
+    assert completion.choices[0].text
+
+
+@pytest.mark.asyncio
+async def test_completion_with_invalid_logit_bias(
+        async_client: openai.AsyncOpenAI, model_name: str):
+    """Test with invalid token IDs (non-integer keys)"""
+    with pytest.raises(openai.BadRequestError):
+        await async_client.completions.create(
+            model=model_name,
+            prompt="Hello world",
+            logit_bias={"invalid_token": 1.0},  # Non-integer key
+            max_tokens=5,
+        )
