@@ -473,10 +473,12 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors):
473
473
finish_reasons : torch .Tensor
474
474
sequence_lengths : torch .Tensor
475
475
cum_log_probs : torch .Tensor | None = None
476
+ gathered_ids : torch .Tensor | None = None
476
477
477
478
478
479
@dataclass (kw_only = True )
479
480
class SampleStateTRTLLM (SampleState ):
481
+ finalize_events : dict [str , CudaEvent ]
480
482
host : SampleStateTensorsHostTRTLLM
481
483
482
484
@@ -672,6 +674,24 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
672
674
self .store ["decoder_state" ],
673
675
self .store ["decoding_input" ][self .micro_batch_idx ])
674
676
677
+ finalize_events = {}
678
+ gathered_ids = None
679
+ if beam_width > 1 :
680
+ finished_sum_device = self .store ["decoder_state" ].finished_sum
681
+
682
+ for request in scheduled_requests .all_requests ():
683
+ if request .is_context_init_state :
684
+ continue
685
+ if finished_sum_device [request .seq_slot ] == beam_width :
686
+ finalize_events [
687
+ request .request_id ] = self ._finalize_request (
688
+ request , False )
689
+ elif request .streaming :
690
+ finalize_events [
691
+ request .request_id ] = self ._finalize_request (
692
+ request , True )
693
+ gathered_ids = self .store ["decoder_state" ].gathered_ids .to (
694
+ 'cpu' , non_blocking = True )
675
695
new_output_tokens = self .store ["decoder_state" ].all_new_tokens .to (
676
696
'cpu' , non_blocking = True )
677
697
finished_sum = self .store ["decoder_state" ].finished_sum .to (
@@ -698,7 +718,8 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
698
718
finish_reasons = finish_reasons ,
699
719
sequence_lengths = sequence_lengths ,
700
720
log_probs = log_probs ,
701
- cum_log_probs = cum_log_probs )
721
+ cum_log_probs = cum_log_probs ,
722
+ gathered_ids = gathered_ids )
702
723
703
724
sampler_event = torch .cuda .Event ()
704
725
sampler_event .record ()
@@ -709,7 +730,8 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
709
730
return SampleStateTRTLLM (scheduled_requests = scheduled_requests ,
710
731
device = device ,
711
732
host = host ,
712
- sampler_event = sampler_event )
733
+ sampler_event = sampler_event ,
734
+ finalize_events = finalize_events )
713
735
714
736
@torch .inference_mode ()
715
737
def update_requests (self , state : SampleStateTRTLLM ):
@@ -797,7 +819,7 @@ def update_requests_multiple_beams_or_drafting(self,
797
819
) if state .host .cum_log_probs is not None else None
798
820
log_probs_host = state .host .log_probs .tolist (
799
821
) if state .host .log_probs is not None else None
800
- finalize_events = {}
822
+ finalize_events = state . finalize_events
801
823
802
824
reqs = [
803
825
r for r in state .scheduled_requests .context_requests
@@ -865,19 +887,9 @@ def update_requests_multiple_beams_or_drafting(self,
865
887
866
888
if finished_sum_host [seq_slot ] == beam_width :
867
889
request .state = LlmRequestState .GENERATION_COMPLETE
868
- if beam_width > 1 :
869
- finalize_events [
870
- request .request_id ] = self ._finalize_request (
871
- request , False )
872
- elif request .streaming and beam_width > 1 :
873
- finalize_events [request .request_id ] = self ._finalize_request (
874
- request , True )
875
- # post process all requests if necessary
876
- if beam_width > 1 :
877
- for request in reqs :
878
- if request .request_id in finalize_events :
879
- self ._post_process_request (
880
- request , finalize_events [request .request_id ])
890
+ for request in reqs :
891
+ if request .request_id in finalize_events :
892
+ self ._post_process_request (request , state )
881
893
882
894
def _finalize_request (self , request : LlmRequest , streaming : bool ):
883
895
""" Finalizes the request. This is necessary for beam search. """
@@ -888,25 +900,24 @@ def _finalize_request(self, request: LlmRequest, streaming: bool):
888
900
return event
889
901
890
902
def _post_process_request (self , request : LlmRequest ,
891
- finalize_event : CudaEvent ):
903
+ state : SampleStateTRTLLM ):
892
904
""" Post Process the request. Updates the sequence according to the beam search results.
893
905
request: LlmRequest which shall be post processed
894
906
state: SampleStateTRTLLM providing the finalize event for this request and the host tensors to read
895
907
"""
896
908
seq_slot = request .py_seq_slot
897
909
beam_width = request .sampling_config .beam_width
898
910
# synchronize on the finalize event before continuing the post processing.
899
- finalize_event .synchronize ()
911
+ # This should be unnecessary, since update_requests already waits for the sampler event.
912
+ state .finalize_events [request .request_id ].synchronize ()
900
913
901
914
# Get these values again, as they might have changed during the finalize step
902
- output_ids_host = self .store ["decoder_state" ].gathered_ids .to ('cpu' )
903
- sequence_lengths_host = self .store ["decoder_state" ].sequence_lengths .to (
904
- 'cpu' )
915
+ output_ids_host = state .host .gathered_ids
916
+ sequence_lengths_host = state .host .sequence_lengths
905
917
906
918
if request .py_return_log_probs :
907
- log_probs_host = self .store ["decoder_state" ].log_probs .to ('cpu' )
908
- cum_log_probs_host = self .store ["decoder_state" ].cum_log_probs .to (
909
- 'cpu' )
919
+ log_probs_host = state .host .log_probs
920
+ cum_log_probs_host = state .host .cum_log_probs
910
921
911
922
generated_tokens = [[0 ]] * beam_width
912
923
log_probs = [[] for _ in range (beam_width )]
0 commit comments