
Commit 305a1cc

refactor: Turn GPUModelRunner.inputs_embeds to a CpuGpuBuffer (vllm-project#24345)
Signed-off-by: Andrew Sansom <[email protected]>
Parent: 6d6c6b0

2 files changed (+37, -14 lines)

vllm/v1/utils.py (17 additions, 4 deletions)

@@ -19,6 +19,8 @@
                         kill_process_tree)
 
 if TYPE_CHECKING:
+    import numpy as np
+
     from vllm.v1.engine.coordinator import DPCoordinator
     from vllm.v1.engine.utils import (CoreEngineActorManager,
                                       CoreEngineProcManager)
@@ -97,20 +99,31 @@ def __repr__(self):
 
 
 class CpuGpuBuffer:
+    """Buffer to easily copy tensors between CPU and GPU."""
 
     def __init__(
         self,
-        *args,
+        *size: Union[int, torch.SymInt],
         dtype: torch.dtype,
         device: torch.device,
         pin_memory: bool,
-    ):
-        self.cpu = torch.zeros(*args,
+        with_numpy: bool = True,
+    ) -> None:
+        self.cpu = torch.zeros(*size,
                                dtype=dtype,
                                device="cpu",
                                pin_memory=pin_memory)
-        self.np = self.cpu.numpy()
         self.gpu = self.cpu.to(device)
+        self.np: np.ndarray
+        # To keep type hints simple (avoiding generics and subclasses), we
+        # only conditionally create the numpy array attribute. This can cause
+        # AttributeError if `self.np` is accessed when `with_numpy=False`.
+        if with_numpy:
+            if dtype == torch.bfloat16:
+                raise ValueError(
+                    "Bfloat16 torch tensors cannot be directly cast to a "
+                    "numpy array, so call CpuGpuBuffer with with_numpy=False")
+            self.np = self.cpu.numpy()
 
     def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
         if n is None:
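
Why the new with_numpy guard exists: NumPy has no bfloat16 dtype, so calling .numpy() on a bfloat16 CPU tensor raises an error, whereas for supported dtypes the returned array shares storage with the tensor, which is what makes host-side writes through it cheap. A minimal standalone sketch of both behaviors (illustrative only, not part of this commit):

    import torch

    # For supported dtypes, .numpy() is a zero-copy view: writes through
    # the array are visible through the tensor and vice versa.
    cpu = torch.zeros(4, dtype=torch.int32)
    view = cpu.numpy()
    view[0] = 7
    assert cpu[0].item() == 7

    # bfloat16 has no NumPy equivalent, so the same call fails outright.
    bf16 = torch.zeros(4, dtype=torch.bfloat16)
    try:
        bf16.numpy()
    except (TypeError, RuntimeError) as exc:
        # PyTorch reports an unsupported ScalarType here; the exact
        # exception class may vary across versions.
        print(f".numpy() failed as expected: {exc}")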

vllm/v1/worker/gpu_model_runner.py (20 additions, 10 deletions)

@@ -303,10 +303,13 @@ def __init__(
         self.query_start_loc = self._make_buffer(self.max_num_reqs + 1,
                                                  dtype=torch.int32)
         self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
-        self.inputs_embeds = torch.zeros(
-            (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
-            device=self.device)
+        # Because inputs_embeds may be bfloat16 and we don't need a numpy
+        # version of this tensor, avoid a RuntimeError by not creating a
+        # numpy buffer.
+        self.inputs_embeds = self._make_buffer(self.max_num_tokens,
+                                               self.hidden_size,
+                                               dtype=self.dtype,
+                                               numpy=False)
 
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
@@ -374,11 +377,18 @@ def __init__(
                                            device="cpu",
                                            pin_memory=self.pin_memory)
 
-    def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
-        return CpuGpuBuffer(*args,
+    def _make_buffer(self,
+                     *size: Union[int, torch.SymInt],
+                     dtype: torch.dtype,
+                     numpy: bool = True) -> CpuGpuBuffer:
+        # Bfloat16 torch tensors cannot be directly cast to a numpy array, so
+        # if a bfloat16 buffer is needed without a corresponding numpy array,
+        # don't bother instantiating the numpy array.
+        return CpuGpuBuffer(*size,
                             dtype=dtype,
                             device=self.device,
-                            pin_memory=self.pin_memory)
+                            pin_memory=self.pin_memory,
+                            with_numpy=numpy)
 
     def _init_model_kwargs(self, num_tokens: int):
         model_kwargs = dict[str, Any]()
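
With this helper in place, the integer metadata buffers (query_start_loc, seq_lens) keep their zero-copy numpy view for cheap host-side writes, while bfloat16 buffers opt out via numpy=False. A hypothetical sketch of the numpy-backed flavor (the buffer size and values are illustrative; CpuGpuBuffer and copy_to_gpu are the real names from the utils.py hunk above):

    import torch

    from vllm.v1.utils import CpuGpuBuffer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # int32 metadata buffer: keeps its numpy view and is filled on the host.
    buf = CpuGpuBuffer(8,
                       dtype=torch.int32,
                       device=device,
                       pin_memory=torch.cuda.is_available())
    buf.np[:4] = [0, 3, 5, 9]  # fast writes through the shared-storage view
    buf.copy_to_gpu(4)         # one host-to-device copy of the filled prefix

The remaining hunks update the call sites that previously treated inputs_embeds as a bare tensor: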
@@ -1645,11 +1655,11 @@ def execute_model(
             )
 
             # TODO(woosuk): Avoid the copy. Optimize.
-            self.inputs_embeds[:num_scheduled_tokens].copy_(
+            self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(
                 inputs_embeds_scheduled)
 
             input_ids = None
-            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
             model_kwargs = {
                 **self._init_model_kwargs(num_scheduled_tokens),
                 **self._extract_mm_kwargs(scheduler_output),
@@ -2484,7 +2494,7 @@ def _dummy_run(
                                  num_scheduled_tokens, remove_lora):
             if self.supports_mm_inputs:
                 input_ids = None
-                inputs_embeds = self.inputs_embeds[:num_tokens]
+                inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
                 model_kwargs = {
                     **self._init_model_kwargs(num_tokens),
                     **self._dummy_mm_kwargs(num_reqs),
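
Put together, the multimodal path now stages embeddings directly in the buffer's persistent GPU tensor and passes padded slices of that same storage to the model. A simplified end-to-end sketch of the new flow (sizes and variable names are illustrative, and the pinned-memory allocation only pays off when CUDA is available):

    import torch

    from vllm.v1.utils import CpuGpuBuffer

    max_num_tokens, hidden_size = 8, 16
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # bfloat16 buffer: created without a numpy view, written on the GPU side.
    inputs_embeds = CpuGpuBuffer(max_num_tokens,
                                 hidden_size,
                                 dtype=torch.bfloat16,
                                 device=device,
                                 pin_memory=torch.cuda.is_available(),
                                 with_numpy=False)

    # As in execute_model: copy the scheduled embeddings into the buffer...
    num_scheduled_tokens, num_input_tokens = 5, 8
    scheduled = torch.randn(num_scheduled_tokens, hidden_size,
                            dtype=torch.bfloat16, device=device)
    inputs_embeds.gpu[:num_scheduled_tokens].copy_(scheduled)

    # ...and hand the model a (possibly padded) slice of the same storage.
    model_input = inputs_embeds.gpu[:num_input_tokens]
    assert model_input.data_ptr() == inputs_embeds.gpu.data_ptr()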
