Skip to content

Commit bbddcec

Browse files
zhtshr, JasonZhang517, gemini-code-assist[bot], helloyongyang
authored
Add disaggregation feature (#893)
# Enable Disaggregation Feature ## Summary This PR introduces a **disaggregation architecture** to LightX2V, enabling distributed deployment of the video generation pipeline across multiple devices or machines. ## What's New ### Core Functionality - **Service Decoupling**: Separate encoder and transformer services that can run independently - **High-Performance Communication**: ZeroMQ and RDMA-based messaging with Mooncake transfer engine - **Flexible Deployment**: Support for single-machine multi-GPU and cross-machine distributed setups ### New Components - `lightx2v/disagg/`: Complete disaggregation package - `conn.py`: Data connection and management - `services/encoder.py`: Encoder service implementation - `services/transformer.py`: Transformer service implementation - `examples/`: Usage examples for WAN I2V and T2V models ## Key Benefits 1. **Resource Flexibility**: Distribute compute-intensive tasks across multiple devices 2. **Scalability**: Easy horizontal scaling for production deployments 3. **Memory Efficiency**: Run large models on hardware-constrained environments 4. **Service-Oriented**: Build microservice-based video generation systems ## Usage Example ```shell python3 lightx2v/disagg/examples/wan_t2v_service.py ``` See `lightx2v/disagg/examples/` for complete working examples. 
## Backward Compatibility ✅ This is an **optional feature** that doesn't affect existing functionality: - Default mode preserves current behavior - All existing APIs remain unchanged - Users can opt-in to use disaggregation when needed ## Testing - ✅ Tested with WAN I2V and T2V models - ✅ Verified cross-device communication stability - ✅ Validated accuracy matches single-machine mode ## Files Changed - Added: `lightx2v/disagg/` package with all disaggregation modules - Modified: None (purely additive) ## Future Enhancements - Automatic service discovery - Load balancing across multiple workers - Enhanced monitoring and health checks --- **Type**: Feature **Breaking Changes**: None **Documentation**: Included in `lightx2v/disagg/examples/` --------- Co-authored-by: jasonzhang517 <yzhang298@e.ntu.edu.sg> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: helloyongyang <yongyang1030@163.com>
1 parent 03e5d2c commit bbddcec

File tree

16 files changed

+2290
-0
lines changed

16 files changed

+2290
-0
lines changed

configs/mooncake_config.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"local_hostname": "localhost",
3+
"metadata_server": "P2PHANDSHAKE",
4+
"protocol": "rdma",
5+
"device_name": ""
6+
}

lightx2v/disagg/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Disaggregation package initialization

lightx2v/disagg/conn.py

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
import struct
5+
import threading
6+
from dataclasses import dataclass
7+
from enum import Enum
8+
from functools import cache
9+
from typing import Dict, List, Optional, Tuple
10+
11+
import numpy as np
12+
import numpy.typing as npt
13+
import zmq
14+
15+
from lightx2v.disagg.mooncake import MooncakeTransferEngine
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class DisaggregationMode(Enum):
    """Role of a service in the disaggregated pipeline."""

    # Disaggregation disabled; DataManager rejects this mode in __init__.
    NULL = "null"
    # This process runs the encoder side and sends data to a transformer.
    ENCODE = "encode"
    # This process runs the transformer side and receives data from an encoder.
    TRANSFORMER = "transformer"
24+
25+
26+
def group_concurrent_contiguous(src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
    """Split paired index arrays into maximal runs contiguous in BOTH arrays.

    A new group starts whenever either the source or the destination index
    breaks the +1 stride, so each returned (src_group, dst_group) pair
    describes one span that is contiguous on both sides at once.

    Args:
        src_indices: Source indices; must be the same length as dst_indices.
        dst_indices: Destination indices, paired element-wise with src_indices.

    Returns:
        (src_groups, dst_groups): parallel lists of index runs. Empty inputs
        yield ([], []) — the original raised IndexError on empty arrays.
    """
    # Fix: guard the empty case before indexing element 0.
    if len(src_indices) == 0:
        return [], []

    src_groups = []
    dst_groups = []
    current_src = [src_indices[0]]
    current_dst = [dst_indices[0]]

    for i in range(1, len(src_indices)):
        src_contiguous = src_indices[i] == src_indices[i - 1] + 1
        dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
        if src_contiguous and dst_contiguous:
            current_src.append(src_indices[i])
            current_dst.append(dst_indices[i])
        else:
            # Stride broken on at least one side: close the current run.
            src_groups.append(current_src)
            dst_groups.append(current_dst)
            current_src = [src_indices[i]]
            current_dst = [dst_indices[i]]

    src_groups.append(current_src)
    dst_groups.append(current_dst)

    return src_groups, dst_groups
48+
49+
50+
@dataclass
class DataArgs:
    """Transfer-endpoint configuration shared by the encoder and transformer sides."""

    # Rank of the sending (encoder) engine; added to DATASENDER_POLLING_PORT
    # to derive the encoder's ZMQ polling port.
    sender_engine_rank: int
    # Rank of the receiving (transformer) engine; added to
    # DATARECEIVER_POLLING_PORT to derive the receiver's ZMQ polling port.
    receiver_engine_rank: int
    # Device addresses of the buffers registered with the transfer engine.
    data_ptrs: list[int]
    # Registered length in bytes of each buffer in data_ptrs.
    data_lens: list[int]
    # Bytes actually copied per buffer by DataManager.send_data.
    data_item_lens: list[int]
    # InfiniBand device name. NOTE(review): not read anywhere in this module —
    # presumably consumed by the transfer engine elsewhere; confirm.
    ib_device: Optional[str] = None
58+
59+
60+
class DataPoll:
    """Transfer status codes for a bootstrap room.

    Deliberately plain integer constants rather than an Enum: statuses are
    sent over ZMQ as ASCII decimal strings and parsed back with int(), so
    the raw values must compare equal on both ends.
    """

    Failed = 0
    Bootstrapping = 1
    WaitingForInput = 2
    Transferring = 3
    Success = 4
66+
67+
68+
# Maps bootstrap_room -> data pointers queued locally for transfer
# (populated by DataManager.enqueue_request).
RequestPoolType = Dict[int, List[int]]
# Maps bootstrap_room -> (endpoint, mooncake_session_id, transformer_ptrs).
# Fix: the encoder's receive loop stores a 3-tuple; the original alias
# annotated a 2-tuple and did not match the stored value.
WaitingPoolType = Dict[int, Tuple[str, str, List[int]]]
# Base ZMQ PULL ports; the engine rank is added to derive the bound port.
DATASENDER_POLLING_PORT = 17788
DATARECEIVER_POLLING_PORT = 27788
72+
73+
74+
class DataManager:
    """Owns the Mooncake transfer engine and the ZMQ control plane for one
    service in the disaggregated pipeline.

    In ENCODE mode two background threads run: one receives transfer requests
    from transformer peers, one performs the actual transfers. In TRANSFORMER
    mode one thread receives status updates pushed by the encoder. Any other
    mode is rejected at construction.
    """

    # TODO: make it general and support multiple transfer backend before merging
    def __init__(self, args: DataArgs, disaggregation_mode: DisaggregationMode):
        self.engine = MooncakeTransferEngine()
        self.data_args = args
        self.disaggregation_mode = disaggregation_mode
        # bootstrap_room -> data pointers queued locally for transfer.
        self.request_pool: RequestPoolType = {}
        # bootstrap_room -> DataPoll status (plain ints).
        self.request_status: Dict[int, DataPoll] = {}
        self.server_socket = zmq.Context().socket(zmq.PULL)
        # Per-endpoint PUSH sockets. An instance-level cache replaces the
        # original functools.cache on _connect, which would have pinned this
        # instance in the global cache forever (ruff B019).
        self._socket_cache: Dict[str, zmq.Socket] = {}
        self.register_buffer_to_engine()
        if self.disaggregation_mode == DisaggregationMode.ENCODE:
            self.waiting_pool: WaitingPoolType = {}
            self.transfer_event = threading.Event()
            self.start_encode_thread()
        elif self.disaggregation_mode == DisaggregationMode.TRANSFORMER:
            self.start_transformer_thread()
        else:
            raise ValueError(f"Unsupported DisaggregationMode: {self.disaggregation_mode}")

    def register_buffer_to_engine(self):
        """Register every (ptr, len) buffer pair with the transfer engine."""
        for data_ptr, data_len in zip(self.data_args.data_ptrs, self.data_args.data_lens):
            self.engine.register(data_ptr, data_len)

    def _connect(self, endpoint: str):
        """Return a cached ZMQ PUSH socket connected to *endpoint*."""
        sock = self._socket_cache.get(endpoint)
        if sock is None:
            sock = zmq.Context().socket(zmq.PUSH)
            sock.connect(endpoint)
            self._socket_cache[endpoint] = sock
        return sock

    def send_data(
        self,
        mooncake_session_id: str,
        encode_data_ptrs: List[int],
        transformer_ptrs: list[int],
    ):
        """Copy each tensor from an encoder buffer to its transformer buffer.

        Returns 0 on success, or the first non-zero engine status on failure
        (remaining tensors are not transferred).
        """
        tensor_num = int(len(self.data_args.data_ptrs))
        for tensor_id in range(tensor_num):
            encode_addr = encode_data_ptrs[tensor_id]
            item_len = self.data_args.data_item_lens[tensor_id]
            transformer_addr = transformer_ptrs[tensor_id]

            # TODO: mooncake transfer engine can do async transfer. Do async later
            status = self.engine.transfer_sync(
                mooncake_session_id,
                encode_addr,
                transformer_addr,
                item_len,
            )
            if status != 0:
                return status
        return 0

    def sync_status_to_transformer_endpoint(self, remote: str, room: int):
        """Push the current status of *room* to the transformer at *remote*."""
        if ":" in remote:
            # Drop any port from the endpoint; the rank-derived port is used.
            remote = remote.split(":")[0]
        self._connect("tcp://" + remote + ":" + str(DATARECEIVER_POLLING_PORT + self.data_args.receiver_engine_rank)).send_multipart(
            [
                str(room).encode("ascii"),
                str(self.request_status[room]).encode("ascii"),
            ]
        )

    def start_encode_thread(self):
        """Bind the sender port and launch the receive and transfer threads."""
        sender_rank_port = DATASENDER_POLLING_PORT + self.data_args.sender_engine_rank
        logger.info("Encoder sender_rank_port=%s", sender_rank_port)
        self.server_socket.bind("tcp://*:" + str(sender_rank_port))

        def encode_thread():
            # Receives (endpoint, session_id, room, packed transformer ptrs)
            # from transformer peers and parks them in the waiting pool.
            while True:
                (
                    endpoint,
                    mooncake_session_id,
                    bootstrap_room,
                    transformer_ptrs,
                ) = self.server_socket.recv_multipart()
                if bootstrap_room.decode("ascii") == "None":
                    # Receiver was created without a room; nothing to match.
                    continue
                endpoint = endpoint.decode("ascii")
                mooncake_session_id = mooncake_session_id.decode("ascii")
                bootstrap_room = int(bootstrap_room.decode("ascii"))
                # Pointers arrive as a run of packed 8-byte unsigned ints.
                transformer_ptrs = list(struct.unpack(f"{len(transformer_ptrs) // 8}Q", transformer_ptrs))
                logger.info(
                    "Encoder received ZMQ: endpoint=%s session_id=%s room=%s transformer_ptrs=%s",
                    endpoint,
                    mooncake_session_id,
                    bootstrap_room,
                    transformer_ptrs,
                )
                self.waiting_pool[bootstrap_room] = (
                    endpoint,
                    mooncake_session_id,
                    transformer_ptrs,
                )
                self.transfer_event.set()

        # NOTE(review): non-daemon threads keep the process alive on shutdown;
        # confirm whether daemon=True is wanted for these service loops.
        threading.Thread(target=encode_thread).start()

        def transfer_thread():
            # Matches rooms that have both local data (request_pool) and a
            # remote request (waiting_pool), then performs the transfer.
            while True:
                self.transfer_event.wait()
                self.transfer_event.clear()
                bootstrap_room_ready = self.request_pool.keys()
                bootstrap_room_request = self.waiting_pool.keys()
                for room in list(bootstrap_room_request):
                    if room not in list(bootstrap_room_ready):
                        continue
                    self.request_status[room] = DataPoll.Transferring
                    (
                        endpoint,
                        mooncake_session_id,
                        transformer_ptrs,
                    ) = self.waiting_pool.pop(room)
                    self.sync_status_to_transformer_endpoint(endpoint, room)
                    encode_data_ptrs = self.request_pool.pop(room)
                    ret = self.send_data(
                        mooncake_session_id,
                        encode_data_ptrs,
                        transformer_ptrs,
                    )
                    if ret != 0:
                        # Fix: record the failure before notifying the peer.
                        # The original assigned DataPoll.Failed to a dead local
                        # variable, so the peer was re-sent "Transferring" and
                        # the local status never reflected the failure.
                        self.request_status[room] = DataPoll.Failed
                        self.sync_status_to_transformer_endpoint(endpoint, room)
                        continue
                    self.request_status[room] = DataPoll.Success
                    self.sync_status_to_transformer_endpoint(endpoint, room)

        threading.Thread(target=transfer_thread).start()

    def start_transformer_thread(self):
        """Bind the receiver port and launch the status-update thread."""
        receiver_rank_port = DATARECEIVER_POLLING_PORT + self.data_args.receiver_engine_rank
        self.server_socket.bind("tcp://*:" + str(receiver_rank_port))

        def transformer_thread():
            # Applies status updates pushed by the encoder.
            while True:
                (bootstrap_room, status) = self.server_socket.recv_multipart()
                status = int(status.decode("ascii"))
                bootstrap_room = int(bootstrap_room.decode("ascii"))
                self.request_status[bootstrap_room] = status

        threading.Thread(target=transformer_thread).start()

    def enqueue_request(
        self,
        bootstrap_room: int,
        data_ptrs: List[int],
    ):
        """Queue local buffers for *bootstrap_room* and mark it waiting."""
        self.request_pool[bootstrap_room] = data_ptrs
        self.request_status[bootstrap_room] = DataPoll.WaitingForInput
        if self.disaggregation_mode == DisaggregationMode.ENCODE:
            self.transfer_event.set()

    def check_status(self, bootstrap_room: int):
        """Return the room's status; on the transformer side, drop the pool
        entry for rooms whose transfer has completed."""
        if self.disaggregation_mode == DisaggregationMode.TRANSFORMER and self.request_status[bootstrap_room] == DataPoll.Success:
            if bootstrap_room in self.request_pool:
                self.request_pool.pop(bootstrap_room)

        return self.request_status[bootstrap_room]

    def set_status(self, bootstrap_room: int, status: DataPoll):
        """Overwrite the recorded status for *bootstrap_room*."""
        self.request_status[bootstrap_room] = status

    def get_localhost(self):
        """Return the transfer engine's local hostname/IP."""
        return self.engine.get_localhost()

    def get_session_id(self):
        """Return the transfer engine's session identifier."""
        return self.engine.get_session_id()
242+
243+
244+
class DataSender:
    """Encoder-side handle for one request identified by *bootstrap_room*."""

    def __init__(self, mgr: DataManager, bootstrap_addr: str, bootstrap_room: int):
        # bootstrap_addr is part of the shared sender/receiver interface but
        # is not used on the sender side.
        self.data_mgr = mgr
        self.bootstrap_room = bootstrap_room
        mgr.set_status(bootstrap_room, DataPoll.WaitingForInput)

    def init(self, num_data_indices: int):
        """Record how many data indices this request will carry."""
        self.num_data_indices = num_data_indices

    def send(self, data_ptrs: List[int]):
        """Hand the request's buffer pointers to the manager for transfer."""
        self.data_mgr.enqueue_request(self.bootstrap_room, data_ptrs)

    def poll(self) -> DataPoll:
        """Report the manager's current status for this request."""
        return self.data_mgr.check_status(self.bootstrap_room)

    def failure_exception(self):
        """Placeholder failure hook."""
        raise Exception("Fake DataSender Exception")
261+
262+
263+
class DataReceiver:
    """Transformer-side handle that requests one room's data from the encoder."""

    def __init__(self, mgr: DataManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None):
        self.bootstrap_room = bootstrap_room
        self.bootstrap_addr = bootstrap_addr
        self.data_mgr = mgr
        # The encoder's PULL socket listens on the sender base port plus rank.
        self.encode_server_url = bootstrap_addr.split(":")[0] + ":" + str(DATASENDER_POLLING_PORT + self.data_mgr.data_args.sender_engine_rank)
        logger.info("DataReceiver encode_server_url=%s", self.encode_server_url)
        self.transformer_ip = self.data_mgr.get_localhost()
        self.session_id = self.data_mgr.get_session_id()
        # Per-endpoint PUSH sockets. Fix: replaces functools.cache on the
        # _connect method, which would have pinned the instance in the global
        # cache for the process lifetime (ruff B019).
        self._socket_cache = {}
        self.data_mgr.set_status(bootstrap_room, DataPoll.WaitingForInput)

    def _connect(self, endpoint: str):
        """Return a cached ZMQ PUSH socket connected to *endpoint*."""
        sock = self._socket_cache.get(endpoint)
        if sock is None:
            sock = zmq.Context().socket(zmq.PUSH)
            sock.connect(endpoint)
            self._socket_cache[endpoint] = sock
        return sock

    def init(self):
        """Enqueue local buffer pointers and announce them to the encoder."""
        packed_data_ptrs = b"".join(struct.pack("Q", ptr) for ptr in self.data_mgr.data_args.data_ptrs)
        # NOTE(review): enqueue_request is annotated List[int] but receives the
        # packed bytes here; confirm the transformer side never indexes it.
        self.data_mgr.enqueue_request(self.bootstrap_room, packed_data_ptrs)
        self._connect("tcp://" + self.encode_server_url).send_multipart(
            [
                self.transformer_ip.encode("ascii"),
                self.session_id.encode("ascii"),
                str(self.bootstrap_room).encode("ascii"),
                packed_data_ptrs,
            ]
        )

    def poll(self) -> DataPoll:
        """Report the manager's current status for this request."""
        return self.data_mgr.check_status(self.bootstrap_room)

    def failure_exception(self):
        """Placeholder failure hook."""
        raise Exception("Fake DataReceiver Exception")
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import torch
2+
import zmq
3+
from mooncake.engine import TransferEngine
4+
5+
6+
def main():
    """Mooncake RDMA client demo.

    Receives the server's buffer metadata over a ZMQ PULL socket, registers a
    local 1 MB CUDA buffer with the transfer engine, writes it to the server
    ten times, then deregisters and cleans up.

    Raises:
        RuntimeError: if RDMA memory registration or deregistration fails.
    """
    # Initialize ZMQ context and socket
    context = zmq.Context()
    socket = context.socket(zmq.PULL)
    # Fix: was a placeholder-less f-string (f"tcp://localhost:5555").
    socket.connect("tcp://localhost:5555")

    # Wait for buffer info from server
    print("Waiting for server buffer information...")
    buffer_info = socket.recv_json()
    server_session_id = buffer_info["session_id"]
    server_ptr = buffer_info["ptr"]
    server_len = buffer_info["len"]
    print(f"Received server info - Session ID: {server_session_id}")
    print(f"Server buffer address: {server_ptr}, length: {server_len}")

    # Initialize client engine
    HOSTNAME = "localhost"  # localhost for simple demo
    METADATA_SERVER = "P2PHANDSHAKE"  # [ETCD_SERVER_URL, P2PHANDSHAKE, ...]
    PROTOCOL = "rdma"  # [rdma, tcp, ...]
    DEVICE_NAME = ""  # auto discovery if empty

    client_engine = TransferEngine()
    client_engine.initialize(HOSTNAME, METADATA_SERVER, PROTOCOL, DEVICE_NAME)
    session_id = f"{HOSTNAME}:{client_engine.get_rpc_port()}"

    # Allocate and initialize client buffer (1MB), filled with ones
    client_buffer = torch.ones(1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0"))
    client_ptr = client_buffer.data_ptr()
    client_len = client_buffer.element_size() * client_buffer.nelement()

    # RDMA requires the buffer to be registered with the engine first
    if PROTOCOL == "rdma":
        ret_value = client_engine.register_memory(client_ptr, client_len)
        if ret_value != 0:
            print("Mooncake memory registration failed.")
            raise RuntimeError("Mooncake memory registration failed.")

    print(f"Client initialized with session ID: {session_id}")

    # Transfer data from client to server
    print("Transferring data to server...")
    for _ in range(10):
        ret = client_engine.transfer_sync_write(
            server_session_id,
            client_ptr,
            server_ptr,
            min(client_len, server_len),  # Transfer minimum of both lengths
        )
        # Report every transfer's result, not just the last one.
        if ret >= 0:
            print("Transfer successful!")
        else:
            print("Transfer failed!")

    # Cleanup: unregister RDMA memory before dropping the buffer
    if PROTOCOL == "rdma":
        ret_value = client_engine.unregister_memory(client_ptr)
        if ret_value != 0:
            print("Mooncake memory deregistration failed.")
            raise RuntimeError("Mooncake memory deregistration failed.")

    socket.close()
    context.term()


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)