from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
+ from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
@@ -97,6 +98,53 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int,
        dist.all_reduce(num_tokens_tensor, group=group)
        return num_tokens_tensor.cpu()

+     @staticmethod
+     def should_ubatch_across_dp(
+             should_ubatch: bool, orig_num_tokens_per_ubatch: int,
+             padded_num_tokens_per_ubatch: int, dp_size: int,
+             dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]:
+         """
+         1. Decides whether each DP rank is going to microbatch. Either all
+         ranks run with microbatching or none of them do. If this function
+         decides not to run with microbatching, it "aborts", meaning that no
+         padding information is returned to the caller and the result is
+         (False, None).
+
+         2. Determines the total number of tokens that each rank will run.
+         All ranks will be padded out so that they run with the same number
+         of tokens.
+
+         Returns: tuple[
+             should_ubatch: Are all DP ranks going to microbatch
+             num_tokens_after_padding: A tensor containing the total number of
+             tokens per-microbatch for each DP rank including padding. Will be
+             None if should_ubatch is False
+         ]
+         """
+
+         device = current_platform.device_type
+         tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
+         tensor[0][dp_rank] = orig_num_tokens_per_ubatch
+         tensor[1][dp_rank] = padded_num_tokens_per_ubatch
+         tensor[2][dp_rank] = 1 if should_ubatch else 0
+
+         from vllm.distributed.parallel_state import get_dp_group
+         dist.all_reduce(tensor, group=get_dp_group().device_group)
+
+         result: bool = bool(torch.all(tensor[2] == 1).item())
+         if not result:
+             return result, None
+
+         orig_num_tokens_tensor = tensor[0, :]
+         padded_num_tokens_tensor = tensor[1, :]
+
+         orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
+         padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
+         if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
+             logger.debug("Aborting ubatching %s %s", orig_min_num_tokens,
+                          padded_max_num_tokens)
+             return False, None
+         return result, padded_num_tokens_tensor.cpu()
+
    @staticmethod
    def make(
        parallel_config: ParallelConfig,
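
The decision protocol above is easier to see in a minimal, single-process sketch. Everything below is illustrative only (`toy_should_ubatch` and its inputs are made-up names, each column stands in for one DP rank, and summing per-rank contributions plays the role of the `all_reduce`); the second check is only an assumption about the intent of `is_second_ubatch_empty`, whose implementation is not shown in this diff.

```python
from typing import Optional

import torch


def toy_should_ubatch(orig_tokens: list[int], padded_tokens: list[int],
                      votes: list[bool]) -> tuple[bool, Optional[torch.Tensor]]:
    dp_size = len(votes)
    tensor = torch.zeros(3, dp_size, dtype=torch.int32)
    tensor[0] = torch.tensor(orig_tokens, dtype=torch.int32)
    tensor[1] = torch.tensor(padded_tokens, dtype=torch.int32)
    tensor[2] = torch.tensor(votes, dtype=torch.int32)

    # Rule 1: every rank must vote for microbatching, otherwise abort.
    if not bool(torch.all(tensor[2] == 1).item()):
        return False, None

    # Rule 2: abort if padding would leave some rank's second microbatch
    # empty. The real code calls is_second_ubatch_empty(); the inequality
    # used here is only an assumption about what that helper checks.
    orig_min = int(tensor[0].min().item())
    padded_max = int(tensor[1].max().item())
    if orig_min <= padded_max // 2:
        return False, None
    return True, tensor[1].clone()


print(toy_should_ubatch([96, 128], [128, 128], [True, True]))   # all ranks agree
print(toy_should_ubatch([96, 128], [128, 128], [True, False]))  # (False, None)
```
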
@@ -119,14 +167,15 @@ def make(
        # If num_tokens_across_dp is None, it will be computed by all_reduce
        # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
-         assert (num_tokens_across_dp is None
-                 or num_tokens_across_dp[dp_rank] == batchsize)
+         assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank]
+                 == batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}"
        if num_tokens_across_dp is None:
            num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
                batchsize, dp_size, dp_rank)
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
-         return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
+         return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu,
+                           num_tokens_across_dp)

    @contextmanager
    def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
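
For reference, a toy computation of the three values `DPMetadata` now carries; the per-rank counts are made up rather than coming from a real all-reduce. Keeping the raw `num_tokens_across_dp` tensor on the metadata object (the new third constructor argument) lets later code consult the per-rank counts without repeating the all-reduce, which is presumably what the microbatching path needs.

```python
import torch

num_tokens_across_dp = torch.tensor([8, 5, 13, 8])            # one entry per DP rank
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)    # largest rank batch
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)

print(max_tokens_across_dp_cpu.item())   # 13
print(cu_tokens_across_dp_cpu.tolist())  # [8, 13, 26, 34]
```
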
@@ -179,9 +228,12 @@ class ForwardContext:
    Type AttentionMetadata for v0,
    Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
    attention layer to its attention metadata
-     set dynamically for each forward pass
+     Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
+     for each microbatch.
+     Set dynamically for each forward pass
    """
-     attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
+     attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"],
+                          list[dict[str, "AttentionMetadata"]]]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass
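
A small self-contained sketch of how a consumer could branch on the three accepted shapes of `attn_metadata` (v0 object, v1 per-layer dict, DBO list of per-microbatch dicts). `PlaceholderMetadata` and `metadata_for` are illustrative names, not vLLM APIs.

```python
class PlaceholderMetadata:
    """Stands in for the real AttentionMetadata type."""


def metadata_for(attn_metadata, layer_name: str, ubatch_idx: int = 0):
    if isinstance(attn_metadata, list):   # DBO: one dict per microbatch
        return attn_metadata[ubatch_idx][layer_name]
    if isinstance(attn_metadata, dict):   # v1: layer_name -> metadata
        return attn_metadata[layer_name]
    return attn_metadata                  # v0: a single metadata object


md = PlaceholderMetadata()
dbo_metadata = [{"layers.0.attn": md}, {"layers.0.attn": md}]
assert metadata_for(dbo_metadata, "layers.0.attn", ubatch_idx=1) is md
```
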
@@ -191,6 +243,8 @@ class ForwardContext:
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
    batch_descriptor: Optional[BatchDescriptor] = None

+     ubatch_slices: Optional[UBatchSlices] = None
+
    def __post_init__(self):
        assert self.cudagraph_runtime_mode in [
            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
@@ -208,6 +262,39 @@ def get_forward_context() -> ForwardContext:
    return _forward_context


+ def create_forward_context(
+         attn_metadata: Any,
+         vllm_config: VllmConfig,
+         virtual_engine: int = 0,
+         dp_metadata: Optional[DPMetadata] = None,
+         cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+         batch_descriptor: Optional[BatchDescriptor] = None,
+         ubatch_slices: Optional[UBatchSlices] = None):
+     return ForwardContext(no_compile_layers=vllm_config.compilation_config.
+                           static_forward_context,
+                           virtual_engine=virtual_engine,
+                           attn_metadata=attn_metadata,
+                           dp_metadata=dp_metadata,
+                           cudagraph_runtime_mode=cudagraph_runtime_mode,
+                           batch_descriptor=batch_descriptor,
+                           ubatch_slices=ubatch_slices)
+
+
+ @contextmanager
+ def override_forward_context(forward_context: Optional[ForwardContext]):
+     """A context manager that overrides the current forward context.
+     This is used to override the forward context for a specific
+     forward pass.
+     """
+     global _forward_context
+     prev_context = _forward_context
+     _forward_context = forward_context
+     try:
+         yield
+     finally:
+         _forward_context = prev_context
+
+
@contextmanager
def set_forward_context(
        attn_metadata: Any,
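
A standalone demo of the save/restore pattern `override_forward_context` relies on: whatever context was active before (even `None`) is reinstated on exit, so overrides can nest, e.g. one per microbatch inside an outer context. `_ctx` and `override_ctx` are stand-ins, not vLLM names.

```python
from contextlib import contextmanager
from typing import Optional

_ctx: Optional[str] = None


@contextmanager
def override_ctx(ctx: Optional[str]):
    global _ctx
    prev_context = _ctx
    _ctx = ctx
    try:
        yield
    finally:
        _ctx = prev_context


with override_ctx("outer"):
    with override_ctx("ubatch-0"):
        assert _ctx == "ubatch-0"
    assert _ctx == "outer"   # inner override unwound
assert _ctx is None          # back to the initial state
```
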
@@ -216,7 +303,8 @@ def set_forward_context(
        num_tokens: Optional[int] = None,
        num_tokens_across_dp: Optional[torch.Tensor] = None,
        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-         batch_descriptor: Optional[BatchDescriptor] = None):
+         batch_descriptor: Optional[BatchDescriptor] = None,
+         ubatch_slices: Optional[UBatchSlices] = None):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.
@@ -225,27 +313,22 @@ def set_forward_context(
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()
+
    dp_metadata: Optional[DPMetadata] = None
    if vllm_config.parallel_config.data_parallel_size > 1 and (
            attn_metadata is not None or num_tokens is not None):
        dp_metadata = DPMetadata.make(vllm_config.parallel_config,
                                      attn_metadata, num_tokens or 0,
                                      num_tokens_across_dp)

-     global _forward_context
-     prev_context = _forward_context
-     _forward_context = ForwardContext(
-         no_compile_layers=vllm_config.compilation_config.
-         static_forward_context,
-         virtual_engine=virtual_engine,
-         attn_metadata=attn_metadata,
-         dp_metadata=dp_metadata,
-         cudagraph_runtime_mode=cudagraph_runtime_mode,
-         batch_descriptor=batch_descriptor,
-     )
+     forward_context = create_forward_context(attn_metadata, vllm_config,
+                                              virtual_engine, dp_metadata,
+                                              cudagraph_runtime_mode,
+                                              batch_descriptor, ubatch_slices)

    try:
-         yield
+         with override_forward_context(forward_context):
+             yield
    finally:
        global last_logging_time, batchsize_logging_interval
        if need_to_track_batchsize:
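
The caller-facing effect of the refactor: `set_forward_context` now builds the context via `create_forward_context` and installs it through `override_forward_context`, so a model-runner call site might look like the sketch below. Every right-hand-side name (`per_layer_attn_metadata`, `vllm_config`, `padded_num_tokens`, `num_tokens_across_dp`, `ubatch_slices`, `model`, `input_ids`, `positions`) is a placeholder that presumes an already-initialized engine; this is a call-shape illustration, not runnable on its own.

```python
# Illustrative call shape only; all inputs are assumed to exist already.
with set_forward_context(attn_metadata=per_layer_attn_metadata,
                         vllm_config=vllm_config,
                         num_tokens=padded_num_tokens,
                         num_tokens_across_dp=num_tokens_across_dp,
                         cudagraph_runtime_mode=CUDAGraphMode.NONE,
                         ubatch_slices=ubatch_slices):
    hidden_states = model(input_ids, positions)
# On exit, the previously active forward context (if any) is restored.
```
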
@@ -282,5 +365,3 @@ def set_forward_context(
                    logger.info(("Batchsize forward time stats "
                                 "(batchsize, count, median_time(ms)): %s"),
                                forward_stats)
-
-         _forward_context = prev_context