21 changes: 17 additions & 4 deletions vllm/v1/utils.py
@@ -19,6 +19,8 @@
                         kill_process_tree)
 
 if TYPE_CHECKING:
+    import numpy as np
+
     from vllm.v1.engine.coordinator import DPCoordinator
     from vllm.v1.engine.utils import (CoreEngineActorManager,
                                       CoreEngineProcManager)
@@ -97,20 +99,31 @@ def __repr__(self):
 
 
 class CpuGpuBuffer:
+    """Buffer to easily copy tensors between CPU and GPU."""
 
     def __init__(
         self,
-        *args,
+        *size: Union[int, torch.SymInt],
         dtype: torch.dtype,
         device: torch.device,
         pin_memory: bool,
-    ):
-        self.cpu = torch.zeros(*args,
+        with_numpy: bool = True,
+    ) -> None:
+        self.cpu = torch.zeros(*size,
                                dtype=dtype,
                                device="cpu",
                                pin_memory=pin_memory)
-        self.np = self.cpu.numpy()
         self.gpu = self.cpu.to(device)
+        self.np: np.ndarray
+        # To keep type hints simple (avoiding generics and subclasses), we
+        # only conditionally create the numpy array attribute. This can cause
+        # AttributeError if `self.np` is accessed when `with_numpy=False`.
+        if with_numpy:
+            if dtype == torch.bfloat16:
+                raise ValueError(
+                    "Bfloat16 torch tensors cannot be directly cast to a "
+                    "numpy array, so call CpuGpuBuffer with with_numpy=False")
+            self.np = self.cpu.numpy()
 
     def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
         if n is None:
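For reference, a minimal usage sketch of the new constructor (not part of this diff; the CUDA device, sizes, and values are illustrative assumptions):

import torch

from vllm.v1.utils import CpuGpuBuffer

# An int32 buffer keeps the default with_numpy=True, so the CPU tensor is
# also exposed as a numpy view via .np.
idx = CpuGpuBuffer(8, dtype=torch.int32, device=torch.device("cuda"),
                   pin_memory=False)
idx.np[:4] = [1, 2, 3, 4]  # write through the numpy view of the CPU tensor
idx.copy_to_gpu(4)         # stage the first four elements on the GPU

# numpy has no bfloat16 dtype, so a bfloat16 buffer must opt out of the numpy
# view with with_numpy=False; leaving the default would raise ValueError.
embeds = CpuGpuBuffer(8, 16, dtype=torch.bfloat16,
                      device=torch.device("cuda"), pin_memory=False,
                      with_numpy=False)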
30 changes: 20 additions & 10 deletions vllm/v1/worker/gpu_model_runner.py
@@ -250,10 +250,13 @@ def __init__(
         self.query_start_loc = self._make_buffer(self.max_num_reqs + 1,
                                                  dtype=torch.int32)
         self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
-        self.inputs_embeds = torch.zeros(
-            (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
-            device=self.device)
+        # Because inputs_embeds may be bfloat16 and we don't need a numpy
+        # version of this tensor, avoid a RuntimeError by not creating a
+        # numpy buffer.
+        self.inputs_embeds = self._make_buffer(self.max_num_tokens,
+                                               self.hidden_size,
+                                               dtype=self.dtype,
+                                               numpy=False)
 
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
@@ -321,11 +324,18 @@ def __init__(
             device="cpu",
             pin_memory=self.pin_memory)
 
-    def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
-        return CpuGpuBuffer(*args,
+    def _make_buffer(self,
+                     *size: Union[int, torch.SymInt],
+                     dtype: torch.dtype,
+                     numpy: bool = True) -> CpuGpuBuffer:
+        # Bfloat16 torch tensors cannot be directly cast to a numpy array, so
+        # if a bfloat16 buffer is needed without a corresponding numpy array,
+        # don't bother instantiating the numpy array.
+        return CpuGpuBuffer(*size,
                             dtype=dtype,
                             device=self.device,
-                            pin_memory=self.pin_memory)
+                            pin_memory=self.pin_memory,
+                            with_numpy=numpy)
 
     def _init_model_kwargs(self, num_tokens: int):
         model_kwargs = dict[str, Any]()
@@ -1521,11 +1531,11 @@ def execute_model(
             )
 
             # TODO(woosuk): Avoid the copy. Optimize.
-            self.inputs_embeds[:num_scheduled_tokens].copy_(
+            self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(
                 inputs_embeds_scheduled)
 
             input_ids = None
-            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
             model_kwargs = {
                 **self._init_model_kwargs(num_scheduled_tokens),
                 **self._extract_mm_kwargs(scheduler_output),
@@ -2318,7 +2328,7 @@ def _dummy_run(
                     num_scheduled_tokens, remove_lora):
             if self.supports_mm_inputs:
                 input_ids = None
-                inputs_embeds = self.inputs_embeds[:num_tokens]
+                inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
                 model_kwargs = {
                     **self._init_model_kwargs(num_tokens),
                     **self._dummy_mm_kwargs(num_reqs),
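For context, the runner-side pattern above can be exercised standalone roughly as follows (a sketch, not part of this diff; the sizes, CUDA device, and scheduled tensor are made up for illustration):

import torch

from vllm.v1.utils import CpuGpuBuffer

# Mirrors the runner's new inputs_embeds buffer: bfloat16, no numpy view.
max_num_tokens, hidden_size = 64, 16
inputs_embeds = CpuGpuBuffer(max_num_tokens, hidden_size,
                             dtype=torch.bfloat16,
                             device=torch.device("cuda"),
                             pin_memory=False, with_numpy=False)

# Stage the scheduled embeddings into the persistent GPU tensor, then take a
# (possibly padded) slice to feed the model, as execute_model now does via
# self.inputs_embeds.gpu[...].
scheduled = torch.randn(8, hidden_size, dtype=torch.bfloat16, device="cuda")
inputs_embeds.gpu[:8].copy_(scheduled)
model_input = inputs_embeds.gpu[:16]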