
Commit 251e36b

[TRTLLM-7015] [feat] Enable prompt_logprobs in pytorch backend (NVIDIA#7580)
Signed-off-by: Venky Ganesh <[email protected]>
1 parent a5421ab commit 251e36b

File tree

7 files changed (+195 additions, -32 deletions)

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 41 additions & 4 deletions
@@ -261,6 +261,17 @@ class LlmResponse:
     def has_error(self):
         return self.error_msg is not None
 
+    def clear_context_logits(self):
+        """Clear context logits from the response result.
+
+        This is used to drop context logits after prompt_logprobs have been computed
+        when the user didn't explicitly request them.
+        """
+        if self.result and hasattr(self.result, '_py_result'):
+            py_result = self.result._py_result
+            if hasattr(py_result, '_context_logits'):
+                py_result._context_logits = None
+
 
 class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
     """LlmRequest wraps `bindings.internal.batch_manager.LlmRequest`
@@ -350,10 +361,36 @@ def __init__(
     def is_generation_only_request(self):
         return self.py_llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY
 
-    def create_response(
-            self,
-            use_fast_logits=False,
-            mpi_world_rank=0) -> tensorrt_llm.bindings.executor.Response | None:
+    def create_response(self,
+                        use_fast_logits=False,
+                        mpi_world_rank=0) -> LlmResponse | None:
+        """Create an LlmResponse from the current request state.
+
+        This method generates a response containing the request's execution results,
+        including generated tokens, logits, and completion status. It wraps the
+        parent class's serialized result in a PyTorch-specific LlmResponse object.
+
+        Args:
+            use_fast_logits (bool, optional, default=False): Only applicable to the TRT backend
+                with speculative decoding enabled. When returning generation logits under
+                speculative decoding, `use_fast_logits=True` replaces tensor payloads with tiny
+                metadata so the target pulls logits directly (zero-copy/IPC), reducing overhead;
+                ignored on PyTorch.
+            mpi_world_rank (int, optional, default=0): Only applicable to the TRT backend with
+                speculative decoding enabled and `use_fast_logits=True`. Contains the MPI world
+                rank of the process hosting the draft model that produces the generation logits.
+                This lets logits be transferred from the draft model to the target model without
+                going through the serialization/transport path.
+
+        Returns:
+            LlmResponse | None: An LlmResponse object containing the request results
+                if there is valid output, otherwise None.
+                The response includes:
+                - request_id: The request identifier (parent ID for child requests)
+                - result: LlmResult wrapping both serialized and PyTorch-specific results
+                - client_id: The client identifier for request routing
+
+        Note:
+            Returns None if the serialized result is empty (len(result) == 0),
+            indicating no output was generated for this request iteration.
+        """
         result, is_final = super().create_serialized_result(
             use_fast_logits, mpi_world_rank)
         return LlmResponse(
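
Editor's note: a minimal illustration (mock objects, not the real bindings) of the attribute chain that clear_context_logits walks. The LlmResponse's result wraps a Python-side result object whose _context_logits tensor is set to None once prompt logprobs have been derived from it.

    # Hypothetical mock-up; the real classes live in the pyexecutor / bindings layer.
    from types import SimpleNamespace

    py_result = SimpleNamespace(_context_logits="<[prompt_len, vocab_size] tensor>")
    response = SimpleNamespace(result=SimpleNamespace(_py_result=py_result))

    # The worker derives prompt logprobs from the context logits, then drops them:
    if response.result and hasattr(response.result, '_py_result'):
        inner = response.result._py_result
        if hasattr(inner, '_context_logits'):
            inner._context_logits = None  # same effect as LlmResponse.clear_context_logits()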

tensorrt_llm/executor/result.py

Lines changed: 27 additions & 6 deletions
@@ -219,20 +219,36 @@ def _handle_sequence(self,
         if response_tensors.cum_log_probs is not None:
             output.cumulative_logprob = response_tensors.cum_log_probs[src_idx]
 
-        if logprobs_result:
+        # prompt logprobs handling
+        if logprobs_result and logprobs_result.prompt is not None:  # both backends
             output.prompt_logprobs = logprobs_result.prompt
-            output.logprobs = logprobs_result.generation
-
-        if response_tensors.log_probs is not None:
+        # generation logprobs handling (provenance varies by backend)
+        if logprobs_result and logprobs_result.generation is not None:  # TRT backend
+            # update logprobs from ResponseWrapper (TRT top logprobs WAR)
+            output._last_logprobs_len = len(output.logprobs)
+            output.logprobs += logprobs_result.generation
+        elif response_tensors.log_probs is not None:  # PyTorch backend
+            # handle logprobs directly from the response tensors given by the sampler
             output._last_logprobs_len = len(output.logprobs)
-            output.logprobs = response_tensors.log_probs[src_idx]
+            # In streaming mode, since out-of-order responses are not possible,
+            # each streamed response_tensors.log_probs[src_idx] contains a
+            # monotonically growing list of logprobs, so we accumulate only the
+            # new entries unique to this particular streamed response.
+            assert output._last_logprobs_len <= len(
+                response_tensors.log_probs[src_idx]
+            ), (f"_last_logprobs_len ({output._last_logprobs_len}) > log_probs length ("
+                f"{len(response_tensors.log_probs[src_idx])})")
+            output.logprobs += response_tensors.log_probs[src_idx][
+                output._last_logprobs_len:]
         # overcome some WAR in the cpp executor
         if finish_reasons[src_idx] != tllm.FinishReason.CANCELLED:
+            # Check if logprobs is a list (not a dict or other structure)
             if len(output.logprobs) > output.length:
                 # LlmResult holds a reference to LogProbStorage, which may be updated by the worker before the result is serialized.
                 # Therefore, we treat extra logprobs/logits as expected and only consume what's needed.
                 output.logprobs = output.logprobs[:output.length]
             assert len(output.logprobs) == output.length
+
         if response_tensors.generation_logits is not None:
             output.generation_logits = response_tensors.generation_logits[
                 src_idx, :output.length]
@@ -636,7 +652,12 @@ def compute_logprobs(
     output_token_ids: Optional[list[int]],
 ) -> LogProbsResult:
     """
-    Compute top-K logprobs and ranks for each token position.
+    Compute top-K logprobs from logits when the engine doesn't provide them directly.
+
+    Used for post-processing logits into logprobs.
+    - Prompt logprobs (from context_logits): always computed here.
+    - Generation logprobs (from generation_logits, TRT backend): used when the backend doesn't compute them in the sampler (e.g., TRT).
+    - Generation logprobs (PyTorch backend): not computed here; the sampler provides them.
 
     Returns:
         LogProbsResult, a NamedTuple containing:
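
Editor's note: a minimal sketch (not the library's compute_logprobs implementation) of the prompt-logprobs half of this post-processing: log-softmax over the vocabulary dimension of the context logits, then keep the top-K entries per prompt position. The {token_id: logprob} dict layout is illustrative; the actual LogProbsResult entry types may differ.

    # Minimal sketch, assuming context_logits has shape [prompt_len, vocab_size].
    import torch


    def topk_prompt_logprobs(context_logits: torch.Tensor, k: int) -> list[dict[int, float]]:
        logprobs = torch.log_softmax(context_logits.float(), dim=-1)
        values, indices = torch.topk(logprobs, k, dim=-1)
        return [{int(tok): float(lp)
                 for tok, lp in zip(idx_row, val_row)}
                for idx_row, val_row in zip(indices, values)]


    # Toy example: a 4-token prompt over a 10-token vocabulary, top-2 per position.
    print(topk_prompt_logprobs(torch.randn(4, 10), k=2))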

tensorrt_llm/executor/worker.py

Lines changed: 44 additions & 7 deletions
@@ -14,6 +14,7 @@
 
 from tensorrt_llm.logger import logger
 
+from .._torch.pyexecutor.llm_request import LlmResponse
 from .._utils import (KVCacheEventSerializer, global_mpi_rank, global_mpi_size,
                       mpi_comm, mpi_rank, nvtx_range_debug)
 from ..bindings import executor as tllm
@@ -972,24 +973,60 @@ def _get_params_for_first_rsp(
     return None, None
 
 
+def _compute_pytorch_prompt_logprobs(
+        generation_result: GenerationResult,
+        response: LlmResponse) -> Optional[LogProbsResult]:
+    """Compute prompt logprobs for the PyTorch backend (cached when streaming)."""
+    logprob_params = generation_result._logprob_params  # should be present and non-None
+    assert logprob_params is not None
+    if generation_result._streaming:
+        cached = getattr(generation_result, '_cached_prompt_logprobs', None)
+        if cached is not None:
+            return LogProbsResult(
+                prompt=cached, generation=None
+            )  # generation logprobs, if requested, are provided directly in response.result.log_probs by the sampler.
+    context_logits = response.result.context_logits
+    assert context_logits is not None, "context_logits cannot be None when prompt_logprobs is requested."
+    logprobs_result = compute_logprobs(logprob_params.prompt_logprobs, None,
+                                       context_logits, None, None)
+    if generation_result._streaming:
+        generation_result._cached_prompt_logprobs = logprobs_result.prompt
+
+    return logprobs_result
+
+
 def _get_logprobs(worker,
-                  response: tllm.Response,
+                  response: Union[tllm.Response, LlmResponse],
                   is_pytorch_backend=False) -> Optional[LogProbsResult]:
-    """Compute logprob and prompt logprob and clear out logits if applicable.
+    """Compute logprobs from response logits when needed.
+
+    Logprobs provenance varies by backend:
+    - PyTorch: generation logprobs are computed in the sampler; only prompt logprobs are computed here.
+    - TRT: both prompt and generation logprobs are computed here from logits.
     """
-    if is_pytorch_backend:
-        # _get_logprobs() is a WAR for the TRT backend, where top-k logprobs are computed post runtime.
-        # In the PyTorch backend, logprobs are already computed during runtime if requested.
-        return None
 
     logprobs_result = None
     generation_result = worker._results.get(response.client_id, None)
 
     if not generation_result:
-        return
+        return None
 
     logprob_params = getattr(generation_result, "_logprob_params", None)
     if logprob_params:
+        if is_pytorch_backend:
+            if not logprob_params.prompt_logprobs:
+                # PyTorch: generation logprobs are computed in the sampler; no post-processing needed
+                return None
+            else:
+                logprobs_result = _compute_pytorch_prompt_logprobs(
+                    generation_result, response)
+
+            if logprob_params.drop_context_logits:
+                response.clear_context_logits()
+
+            return logprobs_result
+
+        # TRT backend: compute both prompt and generation logprobs from logits
        logprobs_result = compute_logprobs(logprob_params.prompt_logprobs,
                                           logprob_params.logprobs,
                                           response.result.context_logits,
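
Editor's note: a hedged, self-contained sketch of the streaming path above (mock classes and hypothetical names, not the real worker code): prompt logprobs are computed once from the context logits of the first streamed response, cached on the result object, and the context logits are dropped when the user never asked for them.

    # Sketch only; GenerationResult / LlmResponse are mocked, and the real top-K
    # computation (compute_logprobs) is replaced by a small placeholder.
    from dataclasses import dataclass
    from typing import List, Optional

    import torch


    @dataclass
    class FakeGenerationResult:            # stands in for GenerationResult (assumption)
        streaming: bool
        cached_prompt_logprobs: Optional[List[dict]] = None


    @dataclass
    class FakeLlmResponse:                 # stands in for LlmResponse (assumption)
        context_logits: Optional[torch.Tensor]

        def clear_context_logits(self):
            self.context_logits = None


    def fake_compute_prompt_logprobs(context_logits: torch.Tensor, k: int) -> List[dict]:
        # Placeholder for compute_logprobs(); see the sketch after result.py above.
        values, indices = torch.topk(torch.log_softmax(context_logits.float(), dim=-1), k, dim=-1)
        return [{int(t): float(v) for t, v in zip(i, row)} for i, row in zip(indices, values)]


    def prompt_logprobs_for_chunk(result, response, k, drop_context_logits):
        if result.streaming and result.cached_prompt_logprobs is not None:
            return result.cached_prompt_logprobs          # later chunks hit the cache
        prompt_logprobs = fake_compute_prompt_logprobs(response.context_logits, k)
        if result.streaming:
            result.cached_prompt_logprobs = prompt_logprobs
        if drop_context_logits:
            response.clear_context_logits()               # user never asked for raw logits
        return prompt_logprobs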

tensorrt_llm/llmapi/llm.py

Lines changed: 0 additions & 6 deletions
@@ -565,12 +565,6 @@ def _check_arguments(self, prompt_len: int, query_len: int,
                          is_gen_only: bool) -> None:
 
         if self.args.backend in ["pytorch", "_autodeploy"]:
-            # TODO: remove these checks after PyTorch backend
-            # fully support TopK prompt and generation logprobs.
-            if sampling_params.prompt_logprobs:
-                raise ValueError(
-                    f"`prompt_logprobs` in sampling_params is not supported in the PyTorch backend yet. Received `prompt_logprobs={sampling_params.prompt_logprobs}`. Please unset this field."
-                )
             if sampling_params.logprobs and sampling_params.logprobs > 1:
                 raise ValueError(
                     f"PyTorch backend currently only supports `logprobs=1`. Received `logprobs={sampling_params.logprobs}` (Top{sampling_params.logprobs} logprobs). Please set `logprobs=1` in `sampling_params` instead."

tensorrt_llm/sampling_params.py

Lines changed: 15 additions & 0 deletions
@@ -8,6 +8,7 @@
 from pydantic import BaseModel
 
 from tensorrt_llm.bindings import executor as tllme
+from tensorrt_llm.logger import logger
 
 
 @dataclass(slots=True, kw_only=True)
@@ -446,6 +447,20 @@ def _get_output_config(self, is_pytorch_backend: bool = False) -> tllme.OutputConfig:
 
         if is_pytorch_backend:
             config_kwargs["return_log_probs"] = bool(self.logprobs)
+            if self.prompt_logprobs and not self.return_context_logits:
+                logger.info(
+                    "Since prompt_logprobs is requested but return_context_logits is False, "
+                    "internally enabling context logits for prompt logprobs computation. "
+                    "Context logits will be dropped after computation, as the user didn't explicitly request them."
+                )
+                # TODO(venky): Find a more elegant way to do this.
+                # NOTE: This is an internal hack so we can entirely avoid introducing
+                # `prompt_logprobs` into the executor bindings and further into the
+                # model engine / sampler. prompt_logprobs is a quantity derived from
+                # context logits, and the capability to post-compute it already
+                # exists in the worker (see _get_logprobs in worker.py).
+                config_kwargs["return_context_logits"] = True
         else:
             config_kwargs["return_log_probs"] = self._return_log_probs
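
Editor's note: the pattern above (derive an internal flag, remember to undo it later) can be summarized with a small standalone sketch. The names OutputConfigSketch and build_output_config are illustrative only, not part of the library.

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class OutputConfigSketch:
        return_log_probs: bool = False
        return_context_logits: bool = False


    def build_output_config(logprobs: Optional[int], prompt_logprobs: Optional[int],
                            return_context_logits: bool) -> tuple[OutputConfigSketch, bool]:
        cfg = OutputConfigSketch(return_log_probs=bool(logprobs),
                                 return_context_logits=return_context_logits)
        drop_context_logits = False
        if prompt_logprobs and not return_context_logits:
            # Enable context logits only for the internal prompt-logprobs computation ...
            cfg.return_context_logits = True
            # ... and remember to drop them before the response reaches the user.
            drop_context_logits = True
        return cfg, drop_context_logits


    cfg, drop = build_output_config(logprobs=1, prompt_logprobs=2, return_context_logits=False)
    assert cfg.return_context_logits and drop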

tests/unittest/llmapi/test_llm.py

Lines changed: 22 additions & 7 deletions
@@ -1803,14 +1803,20 @@ def llm_return_logprobs_test_harness(prompt_logprobs: Optional[int],
                                      backend=None):
     LLM_CLASS = LLM
     llm_args_extra = {}
+    kv_cache_args_extra = {}
     if backend in ["pytorch", "autodeploy"]:
         LLM_CLASS = LLM_torch
+        if streaming:
+            # need this so that context_logits / prompt_logprobs are not dropped
+            # in the 2nd reuse of llm.generate() in streaming mode
+            kv_cache_args_extra["enable_block_reuse"] = False
     else:
         llm_args_extra["fast_build"] = True
 
     llm = LLM_CLASS(
         llama_model_path,
-        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4),
+        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4,
+                                      **kv_cache_args_extra),
         build_config=BuildConfig(gather_context_logits=True),
         tensor_parallel_size=tp_size,
         gather_generation_logits=True,
@@ -1863,7 +1869,7 @@ async def task(id: int, prompt: str):
         async for output in llm.generate_async(prompt,
                                                sampling_params,
                                                streaming=True):
-            logprobs_result_streaming += output.outputs[0].logprobs
+            logprobs_result_streaming += output.outputs[0].logprobs_diff
 
         # comparing streaming logprobs result to non-streaming
         assert logprobs_result_streaming == logprobs_result
@@ -1878,15 +1884,24 @@ async def main():
 
 @force_ampere
 @pytest.mark.parametrize(
-    "prompt_logprobs, logprobs, return_context_logits, return_generation_logits",
-    [(2, None, True, False), (None, 2, False, False)])
+    "prompt_logprobs, logprobs, return_context_logits, return_generation_logits, backend",
+    [
+        # TRT backend test cases
+        (2, None, True, False, "trt"),  # prompt_logprobs with context_logits
+        (None, 2, False, False, "trt"),  # generation logprobs only (top-2)
+        (2, None, False, False, "trt"),  # prompt_logprobs without context_logits
+        (None, None, False, False, "trt"),  # no logprobs at all
+    ])
 def test_llm_return_logprobs(prompt_logprobs: Optional[int],
                              logprobs: Optional[int],
                              return_context_logits: bool,
-                             return_generation_logits: bool):
-    llm_return_logprobs_test_harness(prompt_logprobs, logprobs,
+                             return_generation_logits: bool, backend: str):
+    llm_return_logprobs_test_harness(prompt_logprobs,
+                                     logprobs,
                                      return_context_logits,
-                                     return_generation_logits)
+                                     return_generation_logits,
+                                     backend=backend)
 
 
 @force_ampere
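
Editor's note: the switch from logprobs to logprobs_diff reflects that streamed chunks now expose only the newly produced generation logprobs. A hedged sketch of the client-side pattern; llm, prompt, and sampling_params are assumed to be set up as in the harness above.

    import asyncio


    async def collect_streaming_logprobs(llm, prompt, sampling_params):
        collected = []
        async for chunk in llm.generate_async(prompt, sampling_params, streaming=True):
            # logprobs_diff holds only the logprobs added since the previous chunk.
            collected += chunk.outputs[0].logprobs_diff
        return collected

    # Example (requires a constructed llm):
    # full_logprobs = asyncio.run(collect_streaming_logprobs(llm, prompts[0], sampling_params))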

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 46 additions & 2 deletions
@@ -1,5 +1,6 @@
 import random
 from contextlib import contextmanager, nullcontext
+from typing import Optional
 
 import pytest
 
@@ -18,8 +19,9 @@
 from .test_llm import (_test_llm_capture_request_error, get_model_path,
                        global_kvcache_config, llama_model_path,
                        llm_get_stats_async_test_harness,
-                       llm_get_stats_test_harness, llm_test_harness, prompts,
-                       run_llm_abort_request,
+                       llm_get_stats_test_harness,
+                       llm_return_logprobs_test_harness, llm_test_harness,
+                       prompts, run_llm_abort_request,
                        run_llm_with_postprocess_parallel_and_result_handler,
                        tinyllama_logits_processor_test_harness)
 from utils.util import (force_ampere, similar, skip_gpu_memory_less_than_40gb,
@@ -836,3 +838,45 @@ def test_max_num_token_check(self):
                            match="should not exceed max_num_tokens"):
             ids = [random.randint(10, 100) for _ in range(101)]
             llm.generate([ids])
+
+
+@pytest.mark.parametrize(
+    "prompt_logprobs, logprobs, return_context_logits, return_generation_logits, backend",
+    [
+        (2, None, True, False, "pytorch"),  # prompt_logprobs with context_logits
+        (None, 1, False, False, "pytorch"),  # generation logprobs only (top-1, PyTorch limit)
+        (2, None, False, False, "pytorch"),  # prompt_logprobs without context_logits
+        (None, None, False, False, "pytorch"),  # no logprobs at all
+    ])
+def test_llm_return_logprobs(prompt_logprobs: Optional[int],
+                             logprobs: Optional[int],
+                             return_context_logits: bool,
+                             return_generation_logits: bool, backend: str):
+    llm_return_logprobs_test_harness(prompt_logprobs,
+                                     logprobs,
+                                     return_context_logits,
+                                     return_generation_logits,
+                                     backend=backend)
+
+
+@pytest.mark.parametrize(
+    "prompt_logprobs, logprobs, return_context_logits, return_generation_logits",
+    [
+        (None, 1, False, False),  # generation logprobs only (top-1, PyTorch limit)
+        (2, None, True, False),  # prompt_logprobs with context_logits
+        (2, None, False, False),  # prompt_logprobs only
+        (2, 1, False, False),  # both prompt and generation logprobs
+    ])
+def test_llm_return_logprobs_streaming(prompt_logprobs, logprobs,
+                                       return_context_logits,
+                                       return_generation_logits):
+    llm_return_logprobs_test_harness(prompt_logprobs,
+                                     logprobs,
+                                     return_context_logits,
+                                     return_generation_logits,
+                                     streaming=True,
+                                     backend="pytorch")
