Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions vllm/v1/worker/gpu/cudagraph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
import torch.nn as nn
from tqdm import tqdm

from vllm.compilation.breakable_cudagraph import (
BreakableCUDAGraphWrapper,
is_breakable_cudagraph_enabled,
)
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
Expand Down Expand Up @@ -116,6 +120,13 @@ def __init__(
)
self._init_candidates()

# Breakable CUDA graph (PW CUDA graph without torch.compile)
self.use_breakable_cg = (
is_breakable_cudagraph_enabled()
and self.cudagraph_mode.has_piecewise_cudagraphs()
)
Comment on lines +124 to +127

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this field seems redundant

The check could be moved inside init_breakable_cg_runner?

Just a suggestion, guessing you might prefer how it already is :)

self.breakable_cg_runner: BreakableCUDAGraphWrapper | None = None

def _init_candidates(self) -> None:
"""Build priority-ordered candidate lists for each token count."""
capture_sizes = self.compilation_config.cudagraph_capture_sizes
Expand Down Expand Up @@ -283,6 +294,20 @@ def run_fullgraph(self, desc: BatchExecutionDescriptor):
get_offloader().sync_prev_onload()
self.graphs[desc].replay()

def init_breakable_cg_runner(self, model: nn.Module) -> None:
if self.breakable_cg_runner is None:
self.breakable_cg_runner = BreakableCUDAGraphWrapper(
model, self.vllm_config
)
self.breakable_cg_runner.graph_pool = self.pool

def run_pw_graph(self, model: nn.Module, model_inputs: dict[str, Any]) -> Any:
if not self.use_breakable_cg:
# Default: Use torch-compiled piecewise cudagraph.
return model(**model_inputs)
assert self.breakable_cg_runner is not None
return self.breakable_cg_runner(**model_inputs)


class ModelCudaGraphManager(CudaGraphManager):
"""CudaGraphManager with model-specific capture and hidden state management."""
Expand Down Expand Up @@ -316,6 +341,8 @@ def capture(
) -> dict[BatchExecutionDescriptor, CapturedAttentionState]:
"""Capture CUDA graphs for model forward pass."""
self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
if self.use_breakable_cg:
self.init_breakable_cg_runner(model)

def create_forward_fn(
desc: BatchExecutionDescriptor,
Expand Down Expand Up @@ -370,11 +397,16 @@ def forward_fn(cg_mode: CUDAGraphMode) -> None:
slot_mapping=slot_mappings,
batch_descriptor=batch_descriptor,
):
model_output = model(**model_inputs)
if cg_mode == CUDAGraphMode.PIECEWISE:
# PIECEWISE graph (compiled PW or breakable, chosen inside
# run_pw_graph).
model_output = self.run_pw_graph(model, model_inputs)
else:
model_output = model(**model_inputs)

if cg_mode == CUDAGraphMode.PIECEWISE:
# PW CUDA graph internally handles the model outputs.
# No need to keep track of the hidden states.
# PW CUDA graph (compiled or breakable) internally handles the
# model outputs. No need to keep track of the hidden states.
return None

if self.is_last_pp_rank:
Expand Down
12 changes: 11 additions & 1 deletion vllm/v1/worker/gpu/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,7 +1189,17 @@ def execute_model(
skip_compiled=skip_compiled,
):
self.kv_connector.pre_forward(scheduler_output)
model_output = self.model(**model_inputs)
if batch_desc.cg_mode == CUDAGraphMode.PIECEWISE:
# Run the PIECEWISE graph (compiled PW cudagraph or breakable
# cudagraph, chosen inside run_pw_graph). cg_mode is only
# PIECEWISE after the cudagraph manager exists.
assert self.cudagraph_manager is not None
model_output = self.cudagraph_manager.run_pw_graph(
self.model, model_inputs
)
else:
# Eager (NONE): call the raw model directly.
model_output = self.model(**model_inputs)

if self.is_last_pp_rank:
if self.use_aux_hidden_state_outputs:
Expand Down
14 changes: 13 additions & 1 deletion vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,22 @@ def run_model(
)
inputs_embeds = self.inputs_embeds[:num_tokens]

ret_hidden_states = self.model(
model_inputs = dict(
input_ids=self.input_buffers.input_ids[:num_tokens],
positions=self.input_buffers.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
inputs_embeds=inputs_embeds,
)
if cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE:
# Draft prefill with PIECEWISE cudagraph (compiled PW or breakable),
# chosen inside run_pw_graph.
assert self.prefill_cudagraph_manager is not None
ret_hidden_states = self.prefill_cudagraph_manager.run_pw_graph(
self.model, model_inputs
)
else:
# Eager (NONE): call the raw model directly.
ret_hidden_states = self.model(**model_inputs)
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = ret_hidden_states
Expand Down Expand Up @@ -431,6 +441,8 @@ def capture(
# For PIECEWISE, only the model's compiled regions are captured
# and the rest (compute_logits, gumbel_sample) runs eagerly.
assert self.prefill_cudagraph_manager is not None
if self.prefill_cudagraph_manager.use_breakable_cg:
self.prefill_cudagraph_manager.init_breakable_cg_runner(self.model)
self.prefill_cudagraph_manager.capture(
self.prefill,
attn_states,
Expand Down
Loading