Skip to content

Commit 2486eb7

Browse files
authored
[TRTLLM-6651][feat] Enable Overlap scheduler + Beam Search in TRTLLM Sampler (#6223)
Signed-off-by: Stefan Niebler <[email protected]>
1 parent 2b0fa24 commit 2486eb7

File tree

3 files changed

+107
-27
lines changed

3 files changed

+107
-27
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,6 @@ def __init__(self,
239239
self.event_loop = self._executor_loop_pp
240240
else:
241241
self.event_loop = self._executor_loop if disable_overlap_scheduler else self._executor_loop_overlap
242-
if not disable_overlap_scheduler and model_engine.max_beam_width > 1:
243-
raise NotImplementedError(
244-
"Overlap scheduler is not supported for beam search.")
245242
if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
246243
self.event_loop = trace_func(self.event_loop)
247244

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -473,10 +473,12 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors):
473473
finish_reasons: torch.Tensor
474474
sequence_lengths: torch.Tensor
475475
cum_log_probs: torch.Tensor | None = None
476+
gathered_ids: torch.Tensor | None = None
476477

477478

478479
@dataclass(kw_only=True)
479480
class SampleStateTRTLLM(SampleState):
481+
finalize_events: dict[str, CudaEvent]
480482
host: SampleStateTensorsHostTRTLLM
481483

482484

@@ -672,6 +674,24 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
672674
self.store["decoder_state"],
673675
self.store["decoding_input"][self.micro_batch_idx])
674676

677+
finalize_events = {}
678+
gathered_ids = None
679+
if beam_width > 1:
680+
finished_sum_device = self.store["decoder_state"].finished_sum
681+
682+
for request in scheduled_requests.all_requests():
683+
if request.is_context_init_state:
684+
continue
685+
if finished_sum_device[request.seq_slot] == beam_width:
686+
finalize_events[
687+
request.request_id] = self._finalize_request(
688+
request, False)
689+
elif request.streaming:
690+
finalize_events[
691+
request.request_id] = self._finalize_request(
692+
request, True)
693+
gathered_ids = self.store["decoder_state"].gathered_ids.to(
694+
'cpu', non_blocking=True)
675695
new_output_tokens = self.store["decoder_state"].all_new_tokens.to(
676696
'cpu', non_blocking=True)
677697
finished_sum = self.store["decoder_state"].finished_sum.to(
@@ -698,7 +718,8 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
698718
finish_reasons=finish_reasons,
699719
sequence_lengths=sequence_lengths,
700720
log_probs=log_probs,
701-
cum_log_probs=cum_log_probs)
721+
cum_log_probs=cum_log_probs,
722+
gathered_ids=gathered_ids)
702723

703724
sampler_event = torch.cuda.Event()
704725
sampler_event.record()
@@ -709,7 +730,8 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
709730
return SampleStateTRTLLM(scheduled_requests=scheduled_requests,
710731
device=device,
711732
host=host,
712-
sampler_event=sampler_event)
733+
sampler_event=sampler_event,
734+
finalize_events=finalize_events)
713735

714736
@torch.inference_mode()
715737
def update_requests(self, state: SampleStateTRTLLM):
@@ -797,7 +819,7 @@ def update_requests_multiple_beams_or_drafting(self,
797819
) if state.host.cum_log_probs is not None else None
798820
log_probs_host = state.host.log_probs.tolist(
799821
) if state.host.log_probs is not None else None
800-
finalize_events = {}
822+
finalize_events = state.finalize_events
801823

802824
reqs = [
803825
r for r in state.scheduled_requests.context_requests
@@ -865,19 +887,9 @@ def update_requests_multiple_beams_or_drafting(self,
865887

866888
if finished_sum_host[seq_slot] == beam_width:
867889
request.state = LlmRequestState.GENERATION_COMPLETE
868-
if beam_width > 1:
869-
finalize_events[
870-
request.request_id] = self._finalize_request(
871-
request, False)
872-
elif request.streaming and beam_width > 1:
873-
finalize_events[request.request_id] = self._finalize_request(
874-
request, True)
875-
# post process all requests if necessary
876-
if beam_width > 1:
877-
for request in reqs:
878-
if request.request_id in finalize_events:
879-
self._post_process_request(
880-
request, finalize_events[request.request_id])
890+
for request in reqs:
891+
if request.request_id in finalize_events:
892+
self._post_process_request(request, state)
881893

882894
def _finalize_request(self, request: LlmRequest, streaming: bool):
883895
""" Finalizes the request. This is necessary for beam search. """
@@ -888,25 +900,24 @@ def _finalize_request(self, request: LlmRequest, streaming: bool):
888900
return event
889901

890902
def _post_process_request(self, request: LlmRequest,
891-
finalize_event: CudaEvent):
903+
state: SampleStateTRTLLM):
892904
""" Post Process the request. Updates the sequence according to the beam search results.
893905
request: LlmRequest which shall be post processed
894906
finalize_event: CudaEvent to wait for the finalize step to finish
895907
"""
896908
seq_slot = request.py_seq_slot
897909
beam_width = request.sampling_config.beam_width
898910
# synchronize on the finalize event before continuing the post processing.
899-
finalize_event.synchronize()
911+
# should be unnecessary, as already wait for the sampler event in update_requests
912+
state.finalize_events[request.request_id].synchronize()
900913

901914
# Get these values again, as they might have changed during the finalize step
902-
output_ids_host = self.store["decoder_state"].gathered_ids.to('cpu')
903-
sequence_lengths_host = self.store["decoder_state"].sequence_lengths.to(
904-
'cpu')
915+
output_ids_host = state.host.gathered_ids
916+
sequence_lengths_host = state.host.sequence_lengths
905917

906918
if request.py_return_log_probs:
907-
log_probs_host = self.store["decoder_state"].log_probs.to('cpu')
908-
cum_log_probs_host = self.store["decoder_state"].cum_log_probs.to(
909-
'cpu')
919+
log_probs_host = state.host.log_probs
920+
cum_log_probs_host = state.host.cum_log_probs
910921

911922
generated_tokens = [[0]] * beam_width
912923
log_probs = [[] for _ in range(beam_width)]

tests/unittest/_torch/test_beam_search.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,24 @@ def llm(fixed_params, input_prompts):
5151
)
5252

5353

54+
@pytest.fixture(scope="module")
55+
def llm_overlap(fixed_params, input_prompts):
56+
return LLM(
57+
model=os.path.join(llm_models_root(), "llama-models-v2",
58+
"TinyLlama-1.1B-Chat-v1.0"),
59+
kv_cache_config=KvCacheConfig(max_tokens=10000),
60+
max_batch_size=fixed_params["max_beam_width"] * len(
61+
input_prompts
62+
), # use small batch size to prevent large buffers from possibly hiding wrong data accesses.
63+
max_seq_len=32,
64+
enable_trtllm_sampler=True,
65+
max_beam_width=fixed_params["max_beam_width"],
66+
disable_overlap_scheduler=False,
67+
#TODO: remove this once we have a proper fix for CUDA graph in beam search
68+
cuda_graph_config=None,
69+
)
70+
71+
5472
@force_ampere # Save H100 resource
5573
@pytest.mark.parametrize("return_log_probs", [True, False])
5674
@pytest.mark.parametrize("gather_generation_logits", [True, False])
@@ -105,3 +123,57 @@ def test_beam_search_output_shapes(gather_context_logits: bool,
105123
assert similar(
106124
beam.text,
107125
expected_outputs[input_prompts[output_idx]][beam_idx])
126+
127+
128+
@force_ampere # Save H100 resource
129+
@pytest.mark.parametrize("return_log_probs", [True, False])
130+
@pytest.mark.parametrize("gather_generation_logits", [True, False])
131+
@pytest.mark.parametrize("gather_context_logits", [True, False])
132+
@pytest.mark.parametrize("num_output_beams", [1, 2])
133+
@pytest.mark.parametrize("num_prompts", [1, 2])
134+
@pytest.mark.threadleak(enabled=False)
135+
def test_beam_search_output_shapes_overlap(
136+
gather_context_logits: bool, gather_generation_logits: bool,
137+
return_log_probs: bool, num_output_beams: int, num_prompts: int,
138+
llm_overlap, fixed_params, input_prompts, expected_outputs):
139+
if return_log_probs and num_prompts > 1:
140+
pytest.skip(
141+
"Beam search currently does not support return_log_probs with multiple prompts"
142+
)
143+
sampling_params = SamplingParams(
144+
max_tokens=fixed_params["max_tokens"],
145+
n=num_output_beams,
146+
best_of=fixed_params["max_beam_width"],
147+
use_beam_search=True,
148+
return_context_logits=gather_context_logits,
149+
return_generation_logits=gather_generation_logits,
150+
logprobs=return_log_probs,
151+
)
152+
outputs = llm_overlap.generate(input_prompts[:num_prompts],
153+
sampling_params=sampling_params)
154+
assert len(outputs) == num_prompts
155+
for output_idx, output in enumerate(outputs):
156+
if gather_context_logits:
157+
assert output.context_logits is not None
158+
assert len(
159+
output.prompt_token_ids) == output.context_logits.shape[0]
160+
else:
161+
assert output.context_logits is None
162+
assert len(output.outputs) == num_output_beams
163+
for beam_idx, beam in enumerate(output.outputs):
164+
if gather_generation_logits:
165+
gen_logits = beam.generation_logits
166+
assert gen_logits is not None
167+
assert gen_logits.ndim == 2
168+
assert gen_logits.shape[0] == sampling_params.max_tokens
169+
else:
170+
assert beam.generation_logits is None
171+
172+
if return_log_probs:
173+
assert len(beam.logprobs) == sampling_params.max_tokens
174+
else:
175+
assert len(beam.logprobs) == 0
176+
# Check output similarity
177+
assert similar(
178+
beam.text,
179+
expected_outputs[input_prompts[output_idx]][beam_idx])

0 commit comments

Comments
 (0)