Merged
45 changes: 41 additions & 4 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -278,6 +278,17 @@ class LlmResponse:
def has_error(self):
return self.error_msg is not None

def clear_context_logits(self):
"""Clear context logits from the response result.

This is used to drop context logits after prompt_logprobs have been computed
when the user didn't explicitly request them.
"""
if self.result and hasattr(self.result, '_py_result'):
py_result = self.result._py_result
if hasattr(py_result, '_context_logits'):
py_result._context_logits = None


class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
"""LlmRequest wraps `bindings.internal.batch_manager.LlmRequest`
@@ -377,10 +388,36 @@ def __init__(
def is_generation_only_request(self):
return self.py_llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY

def create_response(
self,
use_fast_logits=False,
mpi_world_rank=0) -> tensorrt_llm.bindings.executor.Response | None:
def create_response(self,
use_fast_logits=False,
mpi_world_rank=0) -> LlmResponse | None:
"""Create an LlmResponse from the current request state.

This method generates a response containing the request's execution results,
including generated tokens, logits, and completion status. It wraps the
parent class's serialized result in a PyTorch-specific LlmResponse object.

Args:
use_fast_logits (bool, optional, default=False): Only applicable to the TRT backend with speculative decoding enabled. When returning generation logits under speculative decoding,
`use_fast_logits=True` replaces tensor payloads with small metadata so the target pulls logits
directly (zero-copy/IPC), reducing overhead; ignored on the PyTorch backend.
mpi_world_rank (int, optional, default=0): Only applicable to the TRT backend with speculative decoding
enabled and `use_fast_logits=True`. The MPI world rank of the process hosting the draft
model that produces the generation logits. This lets logits be transferred from the draft model to the
target model without going through the serialization/transport path.

Returns:
LlmResponse | None: An LlmResponse object containing the request results
if there is valid output, otherwise None.
The response includes:
- request_id: The request identifier (parent ID for child requests)
- result: LlmResult wrapping both serialized and PyTorch-specific results
- client_id: The client identifier for request routing

Note:
Returns None if the serialized result is empty (len(result) == 0),
indicating no output was generated for this request iteration.
"""
result, is_final = super().create_serialized_result(
use_fast_logits, mpi_world_rank)
return LlmResponse(
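Taken together with the worker and sampling-params changes below, these additions let PyTorch-backend users request prompt logprobs without also opting into context logits. A rough usage sketch of the resulting behavior (the import alias, model path, and prompt are placeholders mirroring how the new tests construct the LLM; this is not code from the PR):

from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM as LLM_torch  # assumed import path, aliased as in the tests

llm = LLM_torch(model="/path/to/llama")  # placeholder checkpoint path
result = llm.generate(["The future of AI is"],
                      SamplingParams(max_tokens=8, prompt_logprobs=2))[0]
print(result.outputs[0].prompt_logprobs)  # top-2 logprobs per prompt token
print(result.context_logits)              # None: context logits were enabled internally, then dropped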
51 changes: 44 additions & 7 deletions tensorrt_llm/executor/base_worker.py
@@ -11,6 +11,7 @@

from tensorrt_llm.logger import logger

from .._torch.pyexecutor.llm_request import LlmResponse
from .._utils import (global_mpi_rank, global_mpi_size, mpi_comm, mpi_rank,
nvtx_range_debug)
from ..bindings import executor as tllm
@@ -701,24 +702,60 @@ def _get_params_for_first_rsp(
return None, None


def _compute_pytorch_prompt_logprobs(
generation_result: GenerationResult,
response: LlmResponse) -> Optional[LogProbsResult]:
"""Compute prompt logprobs for PyTorch backend (cached when streaming) """
logprob_params = generation_result._logprob_params # should be present and non None
assert logprob_params is not None
if generation_result._streaming:
cached = getattr(generation_result, '_cached_prompt_logprobs', None)
if cached is not None:
return LogProbsResult(
prompt=cached, generation=None
) # generation logprobs, if requested, are provided directly in response.result.log_probs by the sampler.
context_logits = response.result.context_logits
assert context_logits is not None, "context_logits cannot be None when prompt_logprobs is requested."
logprobs_result = compute_logprobs(logprob_params.prompt_logprobs, None,
context_logits, None, None)
if generation_result._streaming:
generation_result._cached_prompt_logprobs = logprobs_result.prompt

return logprobs_result


def _get_logprobs(worker,
response: tllm.Response,
response: Union[tllm.Response, LlmResponse],
is_pytorch_backend=False) -> Optional[LogProbsResult]:
"""Compute logprob and prompt logprob and clear out logits if applicable.
"""Compute logprobs from response logits when needed.

Logprobs provenance varies by backend:
- PyTorch: Generation logprobs computed in sampler, only prompt logprobs computed here
- TRT: Both prompt and generation logprobs computed here from logits
"""
if is_pytorch_backend:
# _get_logprobs() is a WAR for the TRT backend, where top-k logprobs are computed post runtime.
# In the PyTorch backend, logprobs are already computed during runtime if requested.
return None

logprobs_result = None
generation_result = worker._results.get(response.client_id, None)

if not generation_result:
return
return None

logprob_params = getattr(generation_result, "_logprob_params", None)
if logprob_params:
if is_pytorch_backend:
if not logprob_params.prompt_logprobs:
# PyTorch: generation logprobs computed in sampler, no post-processing needed
return None
else:
logprobs_result = _compute_pytorch_prompt_logprobs(
generation_result, response)

if logprob_params.drop_context_logits:
response.clear_context_logits()

return logprobs_result

# TRT backend: compute both prompt and generation logprobs from logits
logprobs_result = compute_logprobs(logprob_params.prompt_logprobs,
logprob_params.logprobs,
response.result.context_logits,
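Conceptually, the prompt path of compute_logprobs that _compute_pytorch_prompt_logprobs relies on is a top-k over log-softmaxed context logits. A minimal, self-contained sketch of that idea (illustrative only; the real implementation and its return types live in result.py):

import torch

def topk_prompt_logprobs(context_logits: torch.Tensor, k: int) -> list[dict[int, float]]:
    """Illustrative only: top-k logprobs per prompt position from [prompt_len, vocab] logits."""
    logprobs = torch.log_softmax(context_logits.float(), dim=-1)  # [prompt_len, vocab]
    top_vals, top_ids = torch.topk(logprobs, k, dim=-1)           # [prompt_len, k]
    return [{int(t): float(v) for t, v in zip(ids, vals)}
            for ids, vals in zip(top_ids, top_vals)]

fake_context_logits = torch.randn(5, 32000)  # 5 prompt tokens, toy vocab size
assert all(len(d) == 2 for d in topk_prompt_logprobs(fake_context_logits, k=2))

In streaming mode this is computed once, from the first response that still carries context logits, and cached on the GenerationResult as _cached_prompt_logprobs so later chunks can reuse it.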
31 changes: 25 additions & 6 deletions tensorrt_llm/executor/result.py
@@ -244,22 +244,36 @@ def _handle_sequence(self,
if response_tensors.cum_log_probs is not None:
output.cumulative_logprob = response_tensors.cum_log_probs[src_idx]

if logprobs_result:
# prompt logprobs handling
if logprobs_result and logprobs_result.prompt is not None: # both backends
output.prompt_logprobs = logprobs_result.prompt
# generation logprobs handling (provenance varies by backend)
if logprobs_result and logprobs_result.generation is not None: # TRT backend
# update logprobs from ResponseWrapper (TRT top logprobs WAR)
output._last_logprobs_len = len(output.logprobs)
output.prompt_logprobs = logprobs_result.prompt
output.logprobs += logprobs_result.generation
elif response_tensors.log_probs is not None:
# handle logprobs directly from response tensors
elif response_tensors.log_probs is not None: # PyTorch backend
# handle logprobs directly from response tensors given by sampler
output._last_logprobs_len = len(output.logprobs)
output.logprobs = response_tensors.log_probs[src_idx]
# In streaming mode, out-of-order responses are not possible,
# so each streamed response_tensors.log_probs[src_idx]
# contains a monotonically growing list of logprobs for the stream,
# and we need to accumulate only the entries new to this particular streamed response.
assert output._last_logprobs_len <= len(
response_tensors.log_probs[src_idx]
), (f"_last_logprobs_len ({output._last_logprobs_len}) > log_probs length ("
f"{len(response_tensors.log_probs[src_idx])})")
output.logprobs += response_tensors.log_probs[src_idx][
output._last_logprobs_len:]
# overcome some WAR in the cpp executor
if finish_reasons[src_idx] != tllm.FinishReason.CANCELLED:
# Check if logprobs is a list (not a dict or other structure)
if len(output.logprobs) > output.length:
# LlmResult holds a reference to LogProbStorage, which may be updated by the worker before the result is serialized.
# Therefore, we treat extra logprobs/logits as expected and only consume what's needed.
output.logprobs = output.logprobs[:output.length]
assert len(output.logprobs) == output.length

if response_tensors.generation_logits is not None:
output.generation_logits = response_tensors.generation_logits[
src_idx, :output.length]
@@ -698,7 +712,12 @@ def compute_logprobs(
output_token_ids: Optional[list[int]],
) -> LogProbsResult:
"""
Compute top-K logprobs and ranks for each token position.
Compute top-K logprobs from logits when engine doesn't provide them directly.

Used for post-processing logits into logprobs.
- Prompt logprobs (from context_logits): computed here for both backends.
- Generation logprobs (from generation_logits): computed here only for the TRT backend, which does not compute them in the sampler.
- Generation logprobs (PyTorch backend): not computed here; the sampler provides them directly.

Returns:
LogProbsResult, a NamedTuple containing:
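The streaming bookkeeping added to _handle_sequence reduces to slicing off only the entries that are new relative to what has already been accumulated. A tiny standalone sketch of that logic, with plain lists standing in for the sampler's growing log-prob storage (names are illustrative):

# Successive streamed responses carry a monotonically growing list of logprobs
# for the same sequence; storage may even run ahead of what is needed.
chunks = [
    [-0.1],
    [-0.1, -0.5],
    [-0.1, -0.5, -0.3, -0.9],
]

accumulated: list[float] = []
last_len = 0  # plays the role of output._last_logprobs_len
for log_probs in chunks:
    last_len = len(accumulated)
    assert last_len <= len(log_probs)
    accumulated += log_probs[last_len:]  # keep only the entries new to this response

assert accumulated == [-0.1, -0.5, -0.3, -0.9]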
6 changes: 0 additions & 6 deletions tensorrt_llm/llmapi/llm.py
@@ -574,12 +574,6 @@ def _check_arguments(self, prompt_len: int, query_len: int,
is_gen_only: bool) -> None:

if self.args.backend in ["pytorch", "_autodeploy"]:
# TODO: remove these checks after PyTorch backend
# fully support TopK prompt and generation logprobs.
if sampling_params.prompt_logprobs:
raise ValueError(
f"`prompt_logprobs` in sampling_params is not supported in the PyTorch backend yet. Received `prompt_logprobs={sampling_params.prompt_logprobs}`. Please unset this field."
)
if sampling_params.logprobs and sampling_params.logprobs > 1:
raise ValueError(
f"PyTorch backend currently only supports `logprobs=1`. Received `logprobs={sampling_params.logprobs}` (Top{sampling_params.logprobs} logprobs). Please set `logprobs=1` in `sampling_params` instead."
15 changes: 15 additions & 0 deletions tensorrt_llm/sampling_params.py
@@ -8,6 +8,7 @@
from pydantic import BaseModel

from tensorrt_llm.bindings import executor as tllme
from tensorrt_llm.logger import logger


@dataclass(slots=True, kw_only=True)
@@ -449,6 +450,20 @@ def _get_output_config(self, is_pytorch_backend: bool = False) -> tllme.OutputCo

if is_pytorch_backend:
config_kwargs["return_log_probs"] = bool(self.logprobs)
if self.prompt_logprobs and not self.return_context_logits:
logger.info(
"Since prompt_logprobs is requested but return_context_logits is False, "
"internally enabling context logits for prompt logprobs computation. "
"context logits will be dropped after computation as the user didn't explicitly request them."
)
# TODO(venky): Find a more elegant way to do this.
# NOTE: This is an internal hack, so we can entirely avoid introducing
# `prompt_logprobs` into the executor bindings and further into
# model engine / sampler.
# This is because prompt_logprobs is a quantity derived from
# context logits, and the capability to post-compute it
# already exists in the worker. (see _get_logprobs in worker.py)
config_kwargs["return_context_logits"] = True
else:
config_kwargs["return_log_probs"] = self._return_log_probs

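A rough sketch of the enable-then-drop behavior this block implements, assuming SamplingParams accepts prompt_logprobs as a constructor field and exposes the private _get_output_config helper shown in this diff (not code from the PR):

from tensorrt_llm import SamplingParams

sp = SamplingParams(prompt_logprobs=2)                # user asks only for prompt logprobs
cfg = sp._get_output_config(is_pytorch_backend=True)  # private helper shown above
assert cfg.return_context_logits                      # enabled internally for the computation
assert not sp.return_context_logits                   # user-visible request is unchanged, so the
                                                      # worker drops the logits after computing prompt logprobs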
31 changes: 22 additions & 9 deletions tests/unittest/llmapi/test_llm.py
@@ -1804,14 +1804,20 @@ def llm_return_logprobs_test_harness(prompt_logprobs: Optional[int],
backend=None):
LLM_CLASS = LLM
llm_args_extra = {}
kv_cache_args_extra = {}
if backend in ["pytorch", "autodeploy"]:
LLM_CLASS = LLM_torch
if streaming:
# needed so that context_logits / prompt_logprobs are not dropped
# on the second llm.generate() call in streaming mode
kv_cache_args_extra["enable_block_reuse"] = False
else:
llm_args_extra["fast_build"] = True

llm = LLM_CLASS(
llama_model_path,
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4),
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4,
**kv_cache_args_extra),
build_config=BuildConfig(gather_context_logits=True),
tensor_parallel_size=tp_size,
gather_generation_logits=True,
@@ -1864,7 +1870,7 @@ async def task(id: int, prompt: str):
async for output in llm.generate_async(prompt,
sampling_params,
streaming=True):
logprobs_result_streaming += output.outputs[0].logprobs
logprobs_result_streaming += output.outputs[0].logprobs_diff

# comparing streaming logprobs result to non-streaming
assert logprobs_result_streaming == logprobs_result
@@ -1877,21 +1883,28 @@ async def main():
asyncio.run(main())


@pytest.mark.skip(reason="https://nvbugs/5516660")
@force_ampere
@pytest.mark.parametrize(
"prompt_logprobs, logprobs, return_context_logits, return_generation_logits",
[(2, None, True, False), (None, 2, False, False)])
"prompt_logprobs, logprobs, return_context_logits, return_generation_logits, backend",
[
# TRT backend test cases
(2, None, True, False, "trt"), # prompt_logprobs with context_logits
(None, 2, False, False, "trt"), # generation logprobs only (top-2)
(2, None, False, False,
"trt"), # prompt_logprobs without context_logits
(None, None, False, False, "trt"), # no logprobs at all
])
def test_llm_return_logprobs(prompt_logprobs: Optional[int],
logprobs: Optional[int],
return_context_logits: bool,
return_generation_logits: bool):
llm_return_logprobs_test_harness(prompt_logprobs, logprobs,
return_generation_logits: bool, backend: str):
llm_return_logprobs_test_harness(prompt_logprobs,
logprobs,
return_context_logits,
return_generation_logits)
return_generation_logits,
backend=backend)


@pytest.mark.skip(reason="https://nvbugs/5516660")
@force_ampere
def test_llm_return_logprobs_streaming():
llm_return_logprobs_test_harness(2, 2, False, True, streaming=True)
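The harness change above switches the streaming comparison to per-chunk diffs. The client-side pattern it exercises looks roughly like this (sketch only; assumes an llm and sampling_params built as in the harness):

async def collect_streamed_logprobs(llm, prompt, sampling_params):
    accumulated = []
    async for output in llm.generate_async(prompt, sampling_params, streaming=True):
        accumulated += output.outputs[0].logprobs_diff  # only the entries new in this chunk
    return accumulated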
48 changes: 46 additions & 2 deletions tests/unittest/llmapi/test_llm_pytorch.py
@@ -1,5 +1,6 @@
import random
from contextlib import contextmanager, nullcontext
from typing import Optional

import pytest

@@ -19,8 +20,9 @@
from .test_llm import (_test_llm_capture_request_error, get_model_path,
global_kvcache_config, llama_model_path,
llm_get_stats_async_test_harness,
llm_get_stats_test_harness, llm_test_harness, prompts,
run_llm_abort_request,
llm_get_stats_test_harness,
llm_return_logprobs_test_harness, llm_test_harness,
prompts, run_llm_abort_request,
run_llm_with_postprocess_parallel_and_result_handler,
tinyllama_logits_processor_test_harness)
from utils.util import (force_ampere, similar, skip_gpu_memory_less_than_40gb,
@@ -892,3 +894,45 @@ def test_min_tokens(use_speculative: bool):

assert len(res.outputs) == 1
assert len(res.outputs[0].token_ids) == output_len


@pytest.mark.parametrize(
"prompt_logprobs, logprobs, return_context_logits, return_generation_logits, backend",
[
(2, None, True, False,
"pytorch"), # prompt_logprobs with context_logits
(None, 1, False, False,
"pytorch"), # generation logprobs only (top-1, PyTorch limit)
(2, None, False, False,
"pytorch"), # prompt_logprobs without context_logits
(None, None, False, False, "pytorch"), # no logprobs at all
])
def test_llm_return_logprobs(prompt_logprobs: Optional[int],
logprobs: Optional[int],
return_context_logits: bool,
return_generation_logits: bool, backend: str):
llm_return_logprobs_test_harness(prompt_logprobs,
logprobs,
return_context_logits,
return_generation_logits,
backend=backend)


@pytest.mark.parametrize(
"prompt_logprobs, logprobs, return_context_logits, return_generation_logits",
[
(None, 1, False,
False), # generation logprobs only (top-1, PyTorch limit)
(2, None, True, False), # prompt_logprobs with context_logits
(2, None, False, False), # prompt_logprobs only
(2, 1, False, False), # both prompt and generation logprobs
])
def test_llm_return_logprobs_streaming(prompt_logprobs, logprobs,
return_context_logits,
return_generation_logits):
llm_return_logprobs_test_harness(prompt_logprobs,
logprobs,
return_context_logits,
return_generation_logits,
streaming=True,
backend="pytorch")