
Commit 227fed5

add helpful comments
Signed-off-by: Venky Ganesh <[email protected]>
1 parent cecdd61 commit 227fed5

File tree (4 files changed: +30 −16 lines)

  tensorrt_llm/executor/result.py
  tensorrt_llm/executor/worker.py
  tensorrt_llm/llmapi/llm.py
  tensorrt_llm/sampling_params.py

tensorrt_llm/executor/result.py

Lines changed: 12 additions & 6 deletions
@@ -244,10 +244,16 @@ def _handle_sequence(self,
             output.cumulative_logprob = response_tensors.cum_log_probs[src_idx]

         if logprobs_result:
-            output.prompt_logprobs = logprobs_result.prompt
+            output.prompt_logprobs = logprobs_result.prompt  # both backends
+            # The next line only matters for the TRT backend, where generation logprobs
+            # are calculated outside the engine, in logprobs_result.generation.
+            # For the PyTorch backend, generation logprobs are calculated in the sampler
+            # and are provided by response_tensors.log_probs in the lines below.
             output.logprobs = logprobs_result.generation

         if response_tensors.log_probs is not None:
+            # response_tensors.log_probs holds per-token generation logprobs that are
+            # coupled to the sampling strategy, hence they are provided by the sampler.
             output._last_logprobs_len = len(
                 output.logprobs) if output.logprobs is not None else 0
             output.logprobs = response_tensors.log_probs[src_idx]
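The ordering here matters: prompt logprobs always come from logprobs_result, while generation logprobs from logprobs_result are only the TRT-backend value, which the sampler-provided response_tensors.log_probs then overrides on PyTorch. A minimal sketch of that precedence, with a hypothetical helper name (the real code assigns fields on the output object directly):

```python
from typing import Optional

def resolve_generation_logprobs(
        postprocessed: Optional[list],     # like logprobs_result.generation (TRT path)
        sampler_provided: Optional[list],  # like response_tensors.log_probs[src_idx] (PyTorch path)
) -> Optional[list]:
    # Sampler-provided logprobs win: they are coupled to the sampling strategy.
    if sampler_provided is not None:
        return sampler_provided
    # Otherwise fall back to logprobs post-processed from generation logits (TRT).
    return postprocessed
```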
@@ -699,12 +705,12 @@ def compute_logprobs(
     output_token_ids: Optional[list[int]],
 ) -> LogProbsResult:
     """
-    Compute top-K logprobs and ranks for each token position.
+    Compute top-K logprobs from logits when the engine doesn't provide them directly.

-    Returns:
-        LogProbsResult, a NamedTuple containing:
-            - prompt: Optional[List[Dict[token_id, Logprob]]] logprobs for prompt tokens.
-            - generation: Optional[List[Dict[token_id, Logprob]]] logprobs for generated tokens.
+    Used for post-processing logits into logprobs:
+    - Prompt logprobs (from context_logits): always computed here.
+    - Generation logprobs (from generation_logits): computed here only when the backend does not produce them in its sampler (i.e., the TRT backend).
+    - Generation logprobs (PyTorch backend): not computed here; the sampler provides them.
     """

     def _topk_logprobs(logits: torch.Tensor, top_k: int,
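To make the post-processing concrete, here is a hedged sketch of the per-position computation compute_logprobs performs on raw logits: a log-softmax over the vocabulary followed by a top-k selection. The helper name and return layout are illustrative, not the library's exact implementation:

```python
import torch

def topk_logprobs_per_position(logits: torch.Tensor,
                               top_k: int) -> list[dict[int, float]]:
    # logits: [num_positions, vocab_size]
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    top_vals, top_ids = torch.topk(logprobs, k=top_k, dim=-1)
    # One {token_id: logprob} dict per position, analogous to the
    # prompt/generation entries in LogProbsResult.
    return [
        {int(tok): float(val) for tok, val in zip(ids, vals)}
        for ids, vals in zip(top_ids, top_vals)
    ]
```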

tensorrt_llm/executor/worker.py

Lines changed: 12 additions & 4 deletions
@@ -1053,7 +1053,11 @@ def _get_params_for_first_rsp(
 def _get_logprobs(worker,
                   response: Union[tllm.Response, LlmResponse],
                   is_pytorch_backend=False) -> Optional[LogProbsResult]:
-    """Compute logprob and prompt logprob and clear out logits if applicable.
+    """Compute logprobs from response logits when needed.
+
+    Logprobs provenance varies by backend:
+    - PyTorch: generation logprobs are computed in the sampler; only prompt logprobs are computed here.
+    - TRT: both prompt and generation logprobs are computed here from logits.
     """

     logprobs_result = None
@@ -1066,10 +1070,14 @@ def _get_logprobs(worker,
     if logprob_params:
         if is_pytorch_backend:
             if not logprob_params.prompt_logprobs:
-                # generation logprobs are already calculated in PyTorch backend sampler
+                # PyTorch: generation logprobs are computed in the sampler, so no post-processing is needed
                 return
             else:
-                # Fallback: compute from context_logits if available
+                # PyTorch: compute only prompt logprobs from context logits.
+                # This can be done as a post-processing step instead of being coupled to the
+                # PyTorch engine, because the prompt_logprobs calculation is not complicated by
+                # generation sampling strategies; it is therefore simpler to do it here than in
+                # the PyTorch engine and to plumb it through the response.
                 logprobs_result = compute_logprobs(
                     logprob_params.prompt_logprobs, None,
                     response.result.context_logits, None, None)
@@ -1079,7 +1087,7 @@ def _get_logprobs(worker,
             response.clear_context_logits()
             return logprobs_result

-    # trt backend
+    # TRT backend: compute both prompt and generation logprobs from logits
     logprobs_result = compute_logprobs(logprob_params.prompt_logprobs,
                                        logprob_params.logprobs,
                                        response.result.context_logits,
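The branching in _get_logprobs reduces to a small decision table. A hedged sketch as a standalone helper (hypothetical name and signature, not the actual function) of which logprobs get post-processed here:

```python
def logprobs_to_postprocess(is_pytorch_backend: bool,
                            want_prompt_logprobs: bool,
                            want_generation_logprobs: bool) -> dict:
    if is_pytorch_backend:
        # PyTorch: generation logprobs already come from the sampler, so at most
        # prompt logprobs need to be derived from context logits here.
        return {"prompt": want_prompt_logprobs, "generation": False}
    # TRT: both prompt and generation logprobs are derived from logits here.
    return {"prompt": want_prompt_logprobs, "generation": want_generation_logprobs}
```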

tensorrt_llm/llmapi/llm.py

Lines changed: 0 additions & 6 deletions
@@ -569,12 +569,6 @@ def _check_arguments(self, prompt_len: int, query_len: int,
                          is_gen_only: bool) -> None:

         if self.args.backend in ["pytorch", "_autodeploy"]:
-            # TODO: remove these checks after PyTorch backend
-            # fully support TopK prompt and generation logprobs.
-            # if sampling_params.prompt_logprobs:
-            #     raise ValueError(
-            #         f"`prompt_logprobs` in sampling_params is not supported in the PyTorch backend yet. Received `prompt_logprobs={sampling_params.prompt_logprobs}`. Please unset this field."
-            #     )
             if sampling_params.logprobs and sampling_params.logprobs > 1:
                 raise ValueError(
                     f"PyTorch backend currently only supports `logprobs=1`. Received `logprobs={sampling_params.logprobs}` (Top{sampling_params.logprobs} logprobs). Please set `logprobs=1` in `sampling_params` instead."

tensorrt_llm/sampling_params.py

Lines changed: 6 additions & 0 deletions
@@ -8,6 +8,7 @@
 from pydantic import BaseModel

 from tensorrt_llm.bindings import executor as tllme
+from tensorrt_llm.logger import logger


 @dataclass(slots=True, kw_only=True)
@@ -453,6 +454,11 @@ def _get_output_config(self, is_pytorch_backend: bool = False) -> tllme.OutputCo
         # we need to internally enable context logits for prompt logprobs computation
         # They will be dropped after computation if the user didn't explicitly request them
         if self.prompt_logprobs and not self.return_context_logits:
+            logger.info(
+                "prompt_logprobs is requested but return_context_logits is False; "
+                "enabling context logits internally for prompt logprobs computation. "
+                "Context logits will be dropped after computation since the user did not explicitly request them."
+            )
             config_kwargs["return_context_logits"] = True
         else:
             config_kwargs["return_log_probs"] = self._return_log_probs
