Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions fastdeploy/engine/common_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@
EngineCacheQueue,
EngineWorkerQueue,
IPCSignal,
ZmqClient,
ZmqIpcServer,
ZmqTcpServer,
)
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.metrics.trace_util import start_span, start_span_request
from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.output.token_processor import TokenProcessor
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, envs, llm_logger

Expand Down Expand Up @@ -571,10 +573,21 @@ def _fetch_request():
def start_zmq_service(self, api_server_pid=None):
if api_server_pid is None:
return
self.api_server_pid = api_server_pid
self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL)
self.zmq_server.start_server()
self.zmq_server.create_router()

if envs.FD_ENABLE_INTERNAL_ADAPTER:
self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL)
self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER)
self.external_adapter = InternalAdapter(
cfg=self.cfg, engine=self, dp_rank=self.cfg.parallel_config.local_data_parallel_id
)
else:
self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
self.recv_result_handle_thread = threading.Thread(
target=self.send_response_server.recv_result_handle, daemon=True
)
self.recv_result_handle_thread.start()

time.sleep(3)
self.insert_task_to_scheduler_thread = threading.Thread(target=self._insert_zmq_task_to_scheduler, daemon=True)
self.insert_task_to_scheduler_thread.start()
Expand All @@ -588,9 +601,9 @@ def _insert_zmq_task_to_scheduler(self):
try:
block = True if len(added_requests) == 0 else False
if not self.cfg.model_config.enable_mm:
err, data = self.zmq_server.receive_json_once(block)
err, data = self.recv_request_server.receive_json_once(block)
else:
err, data = self.zmq_server.receive_pyobj_once(block)
err, data = self.recv_request_server.receive_pyobj_once(block)
if err is not None:
llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}")
break
Expand Down Expand Up @@ -644,7 +657,7 @@ def _insert_zmq_task_to_scheduler(self):
)
# Since the request is not in scheduler
# Send result by zmq directly
self.zmq_server.send_multipart(request_id, [error_result])
self.send_response_server.send_response(request_id, [error_result])
except Exception as e:
llm_logger.error(
f"Error happend while receving new request from zmq, details={e}, "
Expand All @@ -662,7 +675,7 @@ def _zmq_send_generated_tokens(self):
time.sleep(0.005)
continue
for request_id, contents in results.items():
self.zmq_server.send_multipart(request_id, contents)
self.send_response_server.send_response(request_id, contents)

except Exception as e:
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
Expand Down
4 changes: 2 additions & 2 deletions fastdeploy/entrypoints/engine_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.multimodal.registry import MultimodalRegistry
from fastdeploy.platforms import current_platform
Expand Down Expand Up @@ -110,7 +110,7 @@ def create_zmq_client(self, model, mode):
"""
Create a ZMQ client.
"""
self.zmq_client = ZmqClient(model, mode)
self.zmq_client = ZmqIpcClient(model, mode)
self.zmq_client.connect()

async def format_and_add_data(self, prompts: dict):
Expand Down
7 changes: 7 additions & 0 deletions fastdeploy/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@
"FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))),
# force disable default chunked prefill
"FD_DISABLE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_DISABLE_CHUNKED_PREFILL", "0"))),
"FD_ENABLE_INTERNAL_ADAPTER": lambda: int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")),
# LLMEngine recieve requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8200"),
# LLMEngine send response port, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"),
# LLMEngine recieve control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
}


Expand Down
13 changes: 11 additions & 2 deletions fastdeploy/inter_communicator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@
from .engine_cache_queue import EngineCacheQueue
from .engine_worker_queue import EngineWorkerQueue
from .ipc_signal import IPCSignal, shared_memory_exists
from .zmq_client import ZmqClient
from .zmq_client import ZmqIpcClient
from .zmq_server import ZmqIpcServer, ZmqTcpServer

__all__ = ["ZmqClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue", "shared_memory_exists"]
__all__ = [
"ZmqIpcClient",
"IPCSignal",
"EngineWorkerQueue",
"EngineCacheQueue",
"ZmqTcpServer",
"ZmqIpcServer",
"shared_memory_exists",
]
197 changes: 33 additions & 164 deletions fastdeploy/inter_communicator/zmq_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,209 +14,78 @@
# limitations under the License.
"""

import os
import threading
import time
import traceback
from abc import ABC, abstractmethod

import msgpack
import zmq

from fastdeploy import envs
from fastdeploy.utils import zmq_client_logger


class ZmqClient:
class ZmqClientBase(ABC):
"""
ZmqClient is a class that provides a client-side interface for sending and receiving messages using ZeroMQ.
ZmqClientBase is a base class that provides a client-side interface for sending and receiving messages using ZeroMQ.
"""

def __init__(self, name, mode):
self.context = zmq.Context(4)
self.socket = self.context.socket(mode)
self.file_name = f"/dev/shm/{name}.socket"
self.router_path = f"/dev/shm/router_{name}.ipc"
def __init__(self):
pass

self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
@abstractmethod
def _create_socket(self):
"""Abstract method to create and return a ZeroMQ socket."""
pass

self.mutex = threading.Lock()
self.req_dict = dict()
self.router = None
self.poller = None
self.running = True
def _ensure_socket(self):
"""Ensure the socket is created before use."""
if self.socket is None:
self.socket = self._create_socket()

@abstractmethod
def connect(self):
"""
Connect to the server using the file name specified in the constructor.
"""
self.socket.connect(f"ipc://{self.file_name}")

def start_server(self):
"""
Start the server using the file name specified in the constructor.
"""
self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
self.socket.setsockopt(zmq.SNDTIMEO, -1)
self.socket.bind(f"ipc://{self.file_name}")
self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)

def create_router(self):
"""
Create a ROUTER socket and bind it to the specified router path.
"""
self.router = self.context.socket(zmq.ROUTER)
self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
self.router.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router.setsockopt(zmq.SNDTIMEO, -1)
self.router.bind(f"ipc://{self.router_path}")
zmq_client_logger.info(f"router path: {self.router_path}")
pass

def send_json(self, data):
"""
Send a JSON-serializable object over the socket.
"""
self._ensure_socket()
self.socket.send_json(data)

def recv_json(self):
"""
Receive a JSON-serializable object from the socket.
"""
self._ensure_socket()
return self.socket.recv_json()

def send_pyobj(self, data):
"""
Send a Pickle-serializable object over the socket.
"""
self._ensure_socket()
self.socket.send_pyobj(data)

def recv_pyobj(self):
"""
Receive a Pickle-serializable object from the socket.
"""
self._ensure_socket()
return self.socket.recv_pyobj()

def pack_aggregated_data(self, data):
"""
Aggregate multiple responses into one and send them to the client.
"""
result = data[0]
if len(data) > 1:
for response in data[1:]:
result.add(response)
result = msgpack.packb([result.to_dict()])
return result

def send_multipart(self, req_id, data):
"""
Send a multipart message to the router socket.
"""
if self.router is None:
raise RuntimeError("Router socket not created. Call create_router() first.")

while self.running:
with self.mutex:
if req_id not in self.req_dict:
try:
client, _, request_id = self.router.recv_multipart(flags=zmq.NOBLOCK)
req_id_str = request_id.decode("utf-8")
self.req_dict[req_id_str] = client
except zmq.Again:
time.sleep(0.001)
continue
else:
break
if self.req_dict[req_id] == -1:
if data[-1].finished:
with self.mutex:
self.req_dict.pop(req_id, None)
return
try:
start_send = time.time()
if self.aggregate_send:
result = self.pack_aggregated_data(data)
else:
result = msgpack.packb([response.to_dict() for response in data])
self.router.send_multipart([self.req_dict[req_id], b"", result])
zmq_client_logger.info(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
except zmq.ZMQError as e:
zmq_client_logger.error(f"[{req_id}] zmq error: {e}")
self.req_dict[req_id] = -1
except Exception as e:
zmq_client_logger.error(f"Send result to zmq client failed: {e}, {str(traceback.format_exc())}")

if data[-1].finished:
with self.mutex:
self.req_dict.pop(req_id, None)
zmq_client_logger.info(f"send_multipart finished, req_id: {req_id}")

def receive_json_once(self, block=False):
"""
Receive a single message from the socket.
"""
if self.socket is None or self.socket.closed:
return "zmp socket has closed", None
try:
flags = zmq.NOBLOCK if not block else 0
return None, self.socket.recv_json(flags=flags)
except zmq.Again:
return None, None
except Exception as e:
self.close()
zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}")
return str(e), None

def receive_pyobj_once(self, block=False):
"""
Receive a single message from the socket.
"""
if self.socket is None or self.socket.closed:
return "zmp socket has closed", None
try:
flags = zmq.NOBLOCK if not block else 0
return None, self.socket.recv_pyobj(flags=flags)
except zmq.Again:
return None, None
except Exception as e:
self.close()
zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}")
return str(e), None

def _clear_ipc(self, name):
"""
Remove the IPC file with the given name.
"""
if os.path.exists(name):
try:
os.remove(name)
except OSError as e:
zmq_client_logger.warning(f"Failed to remove IPC file {name} - {e}")

def close(self):
"""
Close the socket and context, and remove the IPC files.
"""
if not self.running:
return

self.running = False
zmq_client_logger.info("Closing ZMQ connection...")
try:
if hasattr(self, "socket") and not self.socket.closed:
self.socket.close()

if self.router is not None and not self.router.closed:
self.router.close()

if not self.context.closed:
self.context.term()
class ZmqIpcClient(ZmqClientBase):
def __init__(self, name, mode):
self.name = name
self.mode = mode
self.file_name = f"/dev/shm/{name}.socket"
self.context = zmq.Context()
self.socket = self.context.socket(self.mode)

self._clear_ipc(self.file_name)
self._clear_ipc(self.router_path)
except Exception as e:
zmq_client_logger.warning(f"Failed to close ZMQ connection - {e}, {str(traceback.format_exc())}")
return
def _create_socket(self):
"""create and return a ZeroMQ socket."""
self.context = zmq.Context()
return self.context.socket(self.mode)

def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def connect(self):
self._ensure_socket()
self.socket.connect(f"ipc://{self.file_name}")
Loading
Loading