@@ -1323,7 +1323,6 @@ def previous_seq_slots_device():
 
         num_tokens = len(input_ids)
         num_draft_tokens = len(draft_tokens)
-        num_requests = len(request_ids)
         total_num_tokens = len(position_ids)
         assert total_num_tokens <= self.max_num_tokens, (
             "total_num_tokens should be less than or equal to max_num_tokens")
@@ -1340,6 +1339,10 @@ def previous_seq_slots_device():
             self.draft_tokens_cuda[:len(draft_tokens)].copy_(draft_tokens,
                                                              non_blocking=True)
         if next_draft_tokens_device is not None:
+            # Initialize these two values to zeros
+            self.previous_pos_id_offsets_cuda *= 0
+            self.previous_kv_lens_offsets_cuda *= 0
+
             if previous_batch_len > 0:
                 previous_slots = previous_seq_slots_device()
                 # previous input ids
@@ -1364,24 +1367,37 @@ def previous_seq_slots_device():
                     pin_memory=True)
                 self.previous_pos_indices_cuda[0:previous_batch_tokens].copy_(
                     previous_pos_indices_host, non_blocking=True)
+
+                # Requests in a batch are ordered as [context requests, generation requests].
+                # Generation requests fall into three groups:
+                #   1) Requests without a previous batch: the overlap scheduler is disabled,
+                #      or this is the first step in the generation server of disaggregated serving.
+                #   2) Requests that already have a previous batch: the previous iteration's requests.
+                #   3) Dummy requests: padding for CUDA graph or attention dp.
+                # self.previous_pos_id_offsets_cuda and self.previous_kv_lens_offsets_cuda
+                # therefore consist of the same three segments:
+                #   For 1), keep previous_pos_id_offsets and previous_kv_lens_offsets at '0' so that
+                #      _preprocess_inputs skips the value changes; the zero-initialization above covers this.
+                #   For 2) (overlap scheduler enabled), set them from new_tokens_lens_device
+                #      and kv_len_offsets_device.
+                #   For 3), keep them at '0' as well; the zero-initialization above covers this.
+
+                num_extend_request_wo_dummy = len(extend_requests) - len(
+                    extend_dummy_requests)
                 self.previous_pos_id_offsets_cuda[
-                    0:previous_batch_tokens].copy_(
+                    (num_extend_request_wo_dummy - previous_batch_len) *
+                    (1 + self.max_draft_len):num_extend_request_wo_dummy *
+                    (1 + self.max_draft_len)].copy_(
                         new_tokens_lens_device[self.previous_pos_indices_cuda[
                             0:previous_batch_tokens]],
                         non_blocking=True)
-                self.previous_kv_lens_offsets_cuda[0:previous_batch_len].copy_(
-                    kv_len_offsets_device[previous_slots], non_blocking=True)
-                # for the requests that do not have previous batch, set the previous_pos_id_offsets and
-                # previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs
-                self.previous_pos_id_offsets_cuda[
-                    previous_batch_tokens:num_requests *
-                    (1 + self.max_draft_len)] *= 0
+
                 self.previous_kv_lens_offsets_cuda[
-                    previous_batch_len:num_requests] *= 0
-            else:
-                # change the data to zeros to skip the value changes in _preprocess_inputs
-                self.previous_pos_id_offsets_cuda *= 0
-                self.previous_kv_lens_offsets_cuda *= 0
+                    num_extend_request_wo_dummy -
+                    previous_batch_len:num_extend_request_wo_dummy].copy_(
+                        kv_len_offsets_device[previous_slots],
+                        non_blocking=True)
+
         elif new_tokens_device is not None:
             seq_slots_device = previous_seq_slots_device()
             max_draft_len = max(draft_lens)
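The slice arithmetic in the last hunk is easier to follow with concrete numbers. The sketch below mimics the three-segment layout described in the new comment; every size and fill value is invented for illustration (the real buffers are preallocated CUDA tensors, and the fills come from new_tokens_lens_device and kv_len_offsets_device), so only the slice computation mirrors the diff.

```python
import torch

# Illustrative sizes; none of these values come from the diff itself.
max_draft_len = 3       # draft tokens per request
num_no_prev = 2         # segment 1): requests without a previous batch
previous_batch_len = 4  # segment 2): requests with a previous batch
num_dummy = 2           # segment 3): dummy requests (CUDA graph / attention dp padding)

num_extend_request_wo_dummy = num_no_prev + previous_batch_len
tokens_per_req = 1 + max_draft_len

# Zero-initialization already covers segments 1) and 3).
previous_pos_id_offsets = torch.zeros(
    (num_extend_request_wo_dummy + num_dummy) * tokens_per_req,
    dtype=torch.int32)
previous_kv_lens_offsets = torch.zeros(
    num_extend_request_wo_dummy + num_dummy, dtype=torch.int32)

# Only segment 2) is overwritten; it sits at the tail of the non-dummy
# requests, which is exactly the slice arithmetic used in the diff.
start = (num_extend_request_wo_dummy - previous_batch_len) * tokens_per_req
end = num_extend_request_wo_dummy * tokens_per_req
previous_pos_id_offsets[start:end] = 1            # stand-in for new_tokens_lens_device values
previous_kv_lens_offsets[
    num_extend_request_wo_dummy - previous_batch_len:
    num_extend_request_wo_dummy] = max_draft_len  # stand-in for kv_len_offsets_device values

# One row per request: rows for segments 1) and 3) stay zero.
print(previous_pos_id_offsets.reshape(-1, tokens_per_req))
print(previous_kv_lens_offsets)
```

Reshaping the per-token buffer to (num_requests, 1 + max_draft_len) makes it visible that only the rows belonging to segment 2) are overwritten, while segments 1) and 3) keep the zeros written by the initialization added at the top of the branch.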