7 changes: 2 additions & 5 deletions vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -506,12 +506,9 @@ def __init__(self):
     def get(self, shape: tuple[int, ...], device: torch.device,
             dtype: torch.dtype):
         shape_numel = prod(shape)
-        if self.buffer is None or self.buffer.numel() < shape_numel:
+        if (self.buffer is None or self.buffer.numel() < shape_numel or
+            self.buffer.device != device or self.buffer.dtype != dtype):
             self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
Contributor comment (severity: critical):
While this change correctly handles buffer recreation for different dtypes or devices in sequential runs (like tests), it is not thread-safe.

The SharedResizableBuffer instances are used as class attributes in FusedMoEModularKernel, making them singletons shared across all MoE layers. If multiple threads execute MoE layers concurrently, they can race to check and reallocate the buffer in get(), leading to memory corruption or incorrect results.
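
As an illustration only, here is a minimal, self-contained sketch of that race (this is not vLLM code: UnsafeSharedBuffer, the CPU device, and the iteration count are invented for the example; it simply copies the unlocked get() logic from the diff):

import threading
from math import prod

import torch


class UnsafeSharedBuffer:
    """Stand-in for the unlocked SharedResizableBuffer logic above."""

    def __init__(self):
        self.buffer = None

    def get(self, shape, device, dtype):
        shape_numel = prod(shape)
        if (self.buffer is None or self.buffer.numel() < shape_numel or
                self.buffer.device != device or self.buffer.dtype != dtype):
            # Another thread may replace self.buffer between this check
            # and the slice in the return statement.
            self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
        return self.buffer[:shape_numel].view(*shape)


def worker(buf, dtype, mismatches):
    # Repeatedly request the shared buffer with this thread's dtype and
    # record any call that hands back a view with a different dtype.
    for _ in range(10_000):
        out = buf.get((4, 8), torch.device("cpu"), dtype)
        if out.dtype != dtype:
            mismatches.append(dtype)


buf = UnsafeSharedBuffer()
mismatches = []
threads = [threading.Thread(target=worker, args=(buf, dt, mismatches))
           for dt in (torch.float32, torch.bfloat16)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(f"dtype mismatches observed: {len(mismatches)}")

Whether a particular run reports mismatches is timing-dependent, but the unsynchronized window between the check and the slice is what the lock proposed below closes.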

To ensure thread safety, the read-modify-write operation on self.buffer must be atomic. Please protect it with a threading.Lock.

This would involve:

  1. Adding import threading at the top of the file.
  2. Initializing a lock in SharedResizableBuffer.__init__: self._lock = threading.Lock().
  3. Wrapping the logic in get() with the lock: with self._lock: ....

Example of the final SharedResizableBuffer class:

import threading
from math import prod
...

class SharedResizableBuffer:

    def __init__(self):
        self.buffer = None
        self._lock = threading.Lock()

    def get(self, shape: tuple[int, ...], device: torch.device,
            dtype: torch.dtype):
        # The lock makes the check, the (re)allocation, and the slice below
        # atomic with respect to other threads sharing this buffer.
        with self._lock:
            shape_numel = prod(shape)
            if (self.buffer is None or self.buffer.numel() < shape_numel or
                self.buffer.device != device or self.buffer.dtype != dtype):
                self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
            return self.buffer[:shape_numel].view(*shape)

Since these changes are outside the diff, I cannot provide a direct code suggestion. This is a critical issue that should be addressed to prevent race conditions in production.

-        assert self.buffer.device == device, \
-            f"Buffer device mismatch: {self.buffer.device} != {device}"
-        assert self.buffer.dtype == dtype, \
-            f"Buffer dtype mismatch: {self.buffer.dtype} != {dtype}"
         return self.buffer[:shape_numel].view(*shape)
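
As a quick check of the new behavior (a sketch only, using CPU tensors so it runs without a GPU; it copies the updated get() condition rather than importing vLLM), switching dtype or device across sequential calls now reallocates the buffer instead of tripping the removed asserts:

from math import prod

import torch


class SharedResizableBuffer:

    def __init__(self):
        self.buffer = None

    def get(self, shape, device, dtype):
        shape_numel = prod(shape)
        if (self.buffer is None or self.buffer.numel() < shape_numel or
                self.buffer.device != device or self.buffer.dtype != dtype):
            self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
        return self.buffer[:shape_numel].view(*shape)


buf = SharedResizableBuffer()
a = buf.get((2, 3), torch.device("cpu"), torch.float32)
# Before this change the dtype assert fired here; now a new buffer is
# allocated with the requested dtype.
b = buf.get((2, 3), torch.device("cpu"), torch.bfloat16)
assert b.dtype == torch.bfloat16 and b.shape == (2, 3)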

