
Commit 5679399

Authored by: SageMoore, LucasWilkinson, yewentao256, robertgshaw2-redhat

[Core/DBO][1/N] Add Dual-Batch Overlap mechanism to VLLM (vllm-project#23693)

Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Sage Moore <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: yewentao256 <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: yewentao256 <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
1 parent 0836928 commit 5679399

File tree: 22 files changed (+1255, -170 lines)

examples/offline_inference/data_parallel.py

Lines changed: 8 additions & 0 deletions
@@ -87,6 +87,11 @@ def parse_args():
         default=0.8,
         help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
     )
+    parser.add_argument(
+        "--enable-dbo",
+        action="store_true",
+        help=("Enable microbatched execution"),
+    )
     parser.add_argument(
         "--compilation-config",
         type=int,
@@ -113,6 +118,7 @@ def main(
     max_model_len,
     compilation_config,
     gpu_memory_utilization,
+    enable_dbo,
     quantization,
 ):
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
@@ -167,6 +173,7 @@ def start(rank):
         max_num_seqs=max_num_seqs,
         max_model_len=max_model_len,
         gpu_memory_utilization=gpu_memory_utilization,
+        enable_dbo=enable_dbo,
         quantization=quantization,
         compilation_config=compilation_config,
     )
@@ -227,6 +234,7 @@ def start(rank):
             args.max_model_len,
             args.compilation_config,
             args.gpu_memory_utilization,
+            args.enable_dbo,
             args.quantization,
         ),
     )
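
For reference, the new flag simply threads through to the engine constructor used by the example. A minimal sketch of the same wiring, assuming the example forwards the flag into the LLM constructor as shown above; the model name below is a placeholder, not taken from this diff:

import argparse

from vllm import LLM

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-dbo",
    action="store_true",
    help="Enable microbatched execution",
)
args = parser.parse_args()

# The parsed flag is forwarded to the engine exactly as in data_parallel.py.
llm = LLM(
    model="facebook/opt-125m",  # placeholder model, not from this diff
    enable_dbo=args.enable_dbo,
)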

tests/v1/attention/test_attention_splitting.py

Lines changed: 5 additions & 5 deletions
@@ -6,7 +6,7 @@
 
 from tests.v1.attention.test_attention_backends import BATCH_SPECS
 from tests.v1.attention.utils import create_common_attn_metadata
-from vllm.v1.attention.backends.utils import (UbatchSlice,
+from vllm.v1.attention.backends.utils import (UBatchSlice,
                                               _make_metadata_with_slice,
                                               slice_query_start_locs,
                                               split_attn_metadata)
@@ -106,7 +106,7 @@ def mixed_small_metadata():
 def test_make_metadata_with_slice_decode_batch(small_decode_metadata):
     """Test slicing decode batch metadata"""
     # Split first request only
-    ubatch_slice = UbatchSlice(slice(0, 1), slice(0, 1))
+    ubatch_slice = UBatchSlice(slice(0, 1), slice(0, 1))
 
     result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata)
 
@@ -120,7 +120,7 @@ def test_make_metadata_with_slice_decode_batch(small_decode_metadata):
 
 def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata):
     """Test slicing mixed batch metadata"""
-    ubatch_slice = UbatchSlice(slice(1, 3),
+    ubatch_slice = UBatchSlice(slice(1, 3),
                                slice(1, 7))  # Requests 1-3, tokens 1-7
 
     result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata)
@@ -137,8 +137,8 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
     num_tokens = large_decode_metadata.num_reqs
     mid_point = num_tokens // 2
     ubatch_slices = [
-        UbatchSlice(slice(0, mid_point), slice(0, mid_point)),
-        UbatchSlice(slice(mid_point, num_tokens), slice(mid_point,
+        UBatchSlice(slice(0, mid_point), slice(0, mid_point)),
+        UBatchSlice(slice(mid_point, num_tokens), slice(mid_point,
                                                         num_tokens)),
     ]
 
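
For context, a UBatchSlice pairs a request slice with a token slice. A minimal sketch of splitting a pure-decode batch at its midpoint, mirroring test_split_attn_metadata_decode_batch above (the numbers are illustrative):

from vllm.v1.attention.backends.utils import UBatchSlice

num_reqs = 4            # pure-decode batch: one token per request
mid_point = num_reqs // 2

# Each UBatchSlice carries (request slice, token slice); for decode-only
# batches the two slices coincide.
ubatch_slices = [
    UBatchSlice(slice(0, mid_point), slice(0, mid_point)),
    UBatchSlice(slice(mid_point, num_reqs), slice(mid_point, num_reqs)),
]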

tests/v1/spec_decode/test_eagle.py

Lines changed: 6 additions & 2 deletions
@@ -365,7 +365,9 @@ def create_deterministic_logits(token_ids):
     # Mock runner for attention metadata building
     proposer.runner = mock.MagicMock()
     proposer.runner.attn_groups.append([mock.MagicMock()])
-    proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
+    proposer.runner.attn_groups[0][0].metadata_builders = [
+        attn_metadata_builder
+    ]
 
     result = proposer.propose(target_token_ids=target_token_ids,
                               target_positions=target_positions,
@@ -489,7 +491,9 @@ def create_deterministic_logits(token_ids, k: int):
     # Mock runner for attention metadata building.
     proposer.runner = mock.MagicMock()
     proposer.runner.attn_groups.append([mock.MagicMock()])
-    proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
+    proposer.runner.attn_groups[0][0].metadata_builders = [
+        attn_metadata_builder
+    ]
 
     # Setup inputs for the proposer.
     target_token_ids = torch.randint(0,
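
The only interface change here is that each attention group now exposes a list of metadata builders rather than a single attribute. A minimal mock-based sketch of the new shape, following the test setup above:

from unittest import mock

attn_metadata_builder = mock.MagicMock()

runner = mock.MagicMock()
runner.attn_groups.append([mock.MagicMock()])
# metadata_builders is now a list (previously a single metadata_builder).
runner.attn_groups[0][0].metadata_builders = [attn_metadata_builder]

# Non-DBO callers simply use the first entry.
builder = runner.attn_groups[0][0].metadata_builders[0]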

vllm/config/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -2848,6 +2848,14 @@ def __post_init__(self):
                 "when cudagraph_mode piecewise cudagraphs is used, "\
                 f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
 
+        if self.parallel_config.enable_dbo:
+            a2a_backend = envs.VLLM_ALL2ALL_BACKEND
+            assert a2a_backend == "deepep_low_latency", \
+                "Microbatching currently only supports the deepep_low_latency "\
+                f"all2all backend. {a2a_backend} is not supported. To fix set "\
+                "the VLLM_ALL2ALL_BACKEND environment variable to "\
+                "deepep_low_latency and install the DeepEP kernels."
+
         if not self.instance_id:
             self.instance_id = random_uuid()[:5]

vllm/config/parallel.py

Lines changed: 8 additions & 0 deletions
@@ -137,6 +137,14 @@ class ParallelConfig:
     disable_custom_all_reduce: bool = False
     """Disable the custom all-reduce kernel and fall back to NCCL."""
 
+    enable_dbo: bool = False
+    """Enable microbatching for the model executor."""
+
+    dbo_decode_token_threshold: int = 32
+    """The threshold for microbatching. If the number of tokens in the
+    request is greater than this threshold, microbatching will be used.
+    Otherwise, the request will be processed in a single batch."""
+
     ray_workers_use_nsight: bool = False
     """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

vllm/distributed/device_communicators/all2all.py

Lines changed: 0 additions & 5 deletions
@@ -251,9 +251,4 @@ def get_handle(self, kwargs):
         logger.debug("DeepEP all2all args %s", buffer_kwargs)
         handle: deep_ep.Buffer = self.handle_cache.get_or_create(
             buffer_kwargs, deep_ep.Buffer)
-        # It is dangerous to set num sms outside this function. num_sms is not
-        # a part of the hash-key that identifies this object. If we are in a
-        # situation where we make objects with different num_sms, the hash key
-        # in get_or_create must be updated.
-        handle.set_num_sms(self.num_sms)
         return handle

vllm/engine/arg_utils.py

Lines changed: 10 additions & 0 deletions
@@ -327,6 +327,9 @@ class EngineArgs:
     data_parallel_hybrid_lb: bool = False
     data_parallel_backend: str = ParallelConfig.data_parallel_backend
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
+    enable_dbo: bool = ParallelConfig.enable_dbo
+    dbo_decode_token_threshold: int = \
+        ParallelConfig.dbo_decode_token_threshold
     eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
     enable_eplb: bool = ParallelConfig.enable_eplb
     expert_placement_strategy: ExpertPlacementStrategy = \
@@ -695,6 +698,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parallel_group.add_argument(
             "--enable-expert-parallel",
             **parallel_kwargs["enable_expert_parallel"])
+        parallel_group.add_argument("--enable-dbo",
+                                    **parallel_kwargs["enable_dbo"])
+        parallel_group.add_argument(
+            "--dbo-decode-token-threshold",
+            **parallel_kwargs["dbo_decode_token_threshold"])
         parallel_group.add_argument("--enable-eplb",
                                     **parallel_kwargs["enable_eplb"])
         parallel_group.add_argument("--eplb-config",
@@ -1339,6 +1347,8 @@ def create_engine_config(
             data_parallel_backend=self.data_parallel_backend,
             data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
             enable_expert_parallel=self.enable_expert_parallel,
+            enable_dbo=self.enable_dbo,
+            dbo_decode_token_threshold=self.dbo_decode_token_threshold,
             enable_eplb=self.enable_eplb,
             eplb_config=self.eplb_config,
             expert_placement_strategy=self.expert_placement_strategy,
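
The same knobs are reachable programmatically via EngineArgs; a minimal sketch (the model name is a placeholder), with the CLI equivalents noted in a comment:

# CLI equivalents registered above: --enable-dbo --dbo-decode-token-threshold 32
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model
    enable_dbo=True,
    dbo_decode_token_threshold=32,
)
# create_engine_config() forwards both fields into ParallelConfig, as shown
# in the last hunk above.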

vllm/forward_context.py

Lines changed: 101 additions & 20 deletions
@@ -14,6 +14,7 @@
 from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -97,6 +98,53 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int,
         dist.all_reduce(num_tokens_tensor, group=group)
         return num_tokens_tensor.cpu()
 
+    @staticmethod
+    def should_ubatch_across_dp(
+            should_ubatch: bool, orig_num_tokens_per_ubatch: int,
+            padded_num_tokens_per_ubatch: int, dp_size: int,
+            dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]:
+        """
+        1. Decides if each DP rank is going to microbatch. Either all ranks
+        run with microbatching or none of them do. If this function decides
+        not to run with microbatching, it will "abort", meaning that no
+        padding information will be returned to the caller; it returns
+        (False, None).
+
+        2. Determines the total number of tokens that each rank will run.
+        All ranks will be padded out so that they run with the same number
+        of tokens.
+
+        Returns: tuple[
+            should_ubatch: Are all DP ranks going to microbatch
+            num_tokens_after_padding: A tensor containing the total number of
+            tokens per-microbatch for each DP rank including padding. Will be
+            None if should_ubatch is False.
+        ]
+        """
+
+        device = current_platform.device_type
+        tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
+        tensor[0][dp_rank] = orig_num_tokens_per_ubatch
+        tensor[1][dp_rank] = padded_num_tokens_per_ubatch
+        tensor[2][dp_rank] = 1 if should_ubatch else 0
+
+        from vllm.distributed.parallel_state import get_dp_group
+        dist.all_reduce(tensor, group=get_dp_group().device_group)
+
+        result: bool = bool(torch.all(tensor[2] == 1).item())
+        if not result:
+            return result, None
+
+        orig_num_tokens_tensor = tensor[0, :]
+        padded_num_tokens_tensor = tensor[1, :]
+
+        orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
+        padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
+        if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
+            logger.debug("Aborting ubatching %s %s", orig_min_num_tokens,
+                         padded_max_num_tokens)
+            return False, None
+        return result, padded_num_tokens_tensor.cpu()
+
     @staticmethod
     def make(
         parallel_config: ParallelConfig,
@@ -119,14 +167,15 @@ def make(
 
         # If num_tokens_across_dp is None, it will be computed by all_reduce
        # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
-        assert (num_tokens_across_dp is None
-                or num_tokens_across_dp[dp_rank] == batchsize)
+        assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank]
+                == batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}"
         if num_tokens_across_dp is None:
             num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
                 batchsize, dp_size, dp_rank)
         max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
         cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
-        return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
+        return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu,
+                          num_tokens_across_dp)
 
     @contextmanager
     def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
@@ -179,9 +228,12 @@ class ForwardContext:
     Type AttentionMetadata for v0,
     Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
     attention layer to its attention metadata
-    set dynamically for each forward pass
+    Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
+    for each microbatch.
+    Set dynamically for each forward pass
     """
-    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
+    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"],
+                         list[dict[str, "AttentionMetadata"]]]
     # TODO: remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
     # set dynamically for each forward pass
@@ -191,6 +243,8 @@ class ForwardContext:
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
     batch_descriptor: Optional[BatchDescriptor] = None
 
+    ubatch_slices: Optional[UBatchSlices] = None
+
     def __post_init__(self):
         assert self.cudagraph_runtime_mode in [
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
@@ -208,6 +262,39 @@ def get_forward_context() -> ForwardContext:
     return _forward_context
 
 
+def create_forward_context(
+        attn_metadata: Any,
+        vllm_config: VllmConfig,
+        virtual_engine: int = 0,
+        dp_metadata: Optional[DPMetadata] = None,
+        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        batch_descriptor: Optional[BatchDescriptor] = None,
+        ubatch_slices: Optional[UBatchSlices] = None):
+    return ForwardContext(no_compile_layers=vllm_config.compilation_config.
+                          static_forward_context,
+                          virtual_engine=virtual_engine,
+                          attn_metadata=attn_metadata,
+                          dp_metadata=dp_metadata,
+                          cudagraph_runtime_mode=cudagraph_runtime_mode,
+                          batch_descriptor=batch_descriptor,
+                          ubatch_slices=ubatch_slices)
+
+
+@contextmanager
+def override_forward_context(forward_context: Optional[ForwardContext]):
+    """A context manager that overrides the current forward context.
+    This is used to override the forward context for a specific
+    forward pass.
+    """
+    global _forward_context
+    prev_context = _forward_context
+    _forward_context = forward_context
+    try:
+        yield
+    finally:
+        _forward_context = prev_context
+
+
 @contextmanager
 def set_forward_context(
     attn_metadata: Any,
@@ -216,7 +303,8 @@ def set_forward_context(
     num_tokens: Optional[int] = None,
     num_tokens_across_dp: Optional[torch.Tensor] = None,
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-    batch_descriptor: Optional[BatchDescriptor] = None):
+    batch_descriptor: Optional[BatchDescriptor] = None,
+    ubatch_slices: Optional[UBatchSlices] = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
@@ -225,27 +313,22 @@
     need_to_track_batchsize = track_batchsize and attn_metadata is not None
     if need_to_track_batchsize:
         forward_start_time = time.perf_counter()
+
     dp_metadata: Optional[DPMetadata] = None
     if vllm_config.parallel_config.data_parallel_size > 1 and (
             attn_metadata is not None or num_tokens is not None):
         dp_metadata = DPMetadata.make(vllm_config.parallel_config,
                                       attn_metadata, num_tokens or 0,
                                       num_tokens_across_dp)
 
-    global _forward_context
-    prev_context = _forward_context
-    _forward_context = ForwardContext(
-        no_compile_layers=vllm_config.compilation_config.
-        static_forward_context,
-        virtual_engine=virtual_engine,
-        attn_metadata=attn_metadata,
-        dp_metadata=dp_metadata,
-        cudagraph_runtime_mode=cudagraph_runtime_mode,
-        batch_descriptor=batch_descriptor,
-    )
+    forward_context = create_forward_context(attn_metadata, vllm_config,
+                                             virtual_engine, dp_metadata,
+                                             cudagraph_runtime_mode,
+                                             batch_descriptor, ubatch_slices)
 
     try:
-        yield
+        with override_forward_context(forward_context):
+            yield
     finally:
         global last_logging_time, batchsize_logging_interval
         if need_to_track_batchsize:
@@ -282,5 +365,3 @@ def set_forward_context(
             logger.info(("Batchsize forward time stats "
                          "(batchsize, count, median_time(ms)): %s"),
                         forward_stats)
-
-        _forward_context = prev_context
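
The new helpers split context creation from installation, which lets a microbatch runner install a separate context per microbatch. A rough sketch of that pattern, assuming hypothetical attn_metadata_per_ubatch and run_ubatch placeholders (the actual DBO runner lives elsewhere in this PR):

from vllm.forward_context import (create_forward_context,
                                  override_forward_context)


def run_microbatches(ubatch_slices, attn_metadata_per_ubatch, vllm_config,
                     run_ubatch):
    # attn_metadata_per_ubatch and run_ubatch are hypothetical placeholders.
    for ubatch_slice, attn_metadata in zip(ubatch_slices,
                                           attn_metadata_per_ubatch):
        ctx = create_forward_context(attn_metadata,
                                     vllm_config,
                                     ubatch_slices=ubatch_slices)
        # Install this microbatch's context only for its forward pass.
        with override_forward_context(ctx):
            run_ubatch(ubatch_slice)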
