
Commit d506f62

pcastonguay authored and yuanjingx87 committed
[https://nvbugs/5470840][fix] Disaggregated unit test MPI Init handling (#7139)
Signed-off-by: Patrice Castonguay <[email protected]>
1 parent 88b1446 commit d506f62

File tree: 1 file changed, +69 −25 lines

tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py
Lines changed: 69 additions & 25 deletions
@@ -36,6 +36,42 @@
 }
 
 
+def mpi_publish_name():
+    port_name = None
+    try:
+        port_name = MPI.Open_port()
+        MPI.Publish_name('my_port', port_name)
+    except MPI.Exception as e:
+        print(f"Error publishing port name: {e}")
+        raise e
+    except Exception as e:
+        print(f"Unexpected error publishing port name: {e}")
+        raise e
+
+    return port_name
+
+
+def mpi_initialize_intercomm(port_name):
+    intercomm = None
+    try:
+        intercomm = MPI.COMM_SELF.Accept(port_name)
+    except MPI.Exception as e:
+        print(f"Error accepting intercomm: {e}", flush=True)
+        raise
+    except Exception as e:
+        print(f"Unexpected error accepting intercomm: {e}", flush=True)
+        raise
+    return intercomm
+
+
+def mpi_send_termination_request(intercomm):
+    if intercomm is not None:
+        # Send termination requests
+        intercomm.send(None, dest=0, tag=MPI_REQUEST)
+        intercomm.send(None, dest=1, tag=MPI_REQUEST)
+        print("Sent termination requests to the workers.")
+
+
 def model_path(model_name):
     llm_models_root = os.environ["LLM_MODELS_ROOT"]
     for name, path in MODEL_PATHS.items():
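
The three helpers added above wrap MPI's dynamic process management: the driver opens a port, publishes it under the service name 'my_port', and accepts a connection to form an intercommunicator, while workers look the name up and connect. A minimal standalone sketch of that handshake with mpi4py follows; it is not part of the commit, the service name matches the test but the script name, message, and cleanup are illustrative, and Publish_name/Lookup_name generally require an MPI name server (e.g. ompi-server for Open MPI) to be reachable.

# handshake_sketch.py -- illustrative only, not part of this commit.
# Run one server and one client process, e.g.:
#   mpiexec -n 1 python handshake_sketch.py server &
#   mpiexec -n 1 python handshake_sketch.py client
import sys

from mpi4py import MPI


def server():
    port_name = MPI.Open_port()              # ask the MPI runtime for a port
    MPI.Publish_name('my_port', port_name)   # register it under a service name
    try:
        intercomm = MPI.COMM_SELF.Accept(port_name)  # block until a client connects
        print("server got:", intercomm.recv(source=0, tag=0))
        intercomm.Disconnect()
    finally:
        MPI.Unpublish_name('my_port', port_name)
        MPI.Close_port(port_name)


def client():
    port_name = MPI.Lookup_name('my_port')   # resolve the published service name
    intercomm = MPI.COMM_WORLD.Connect(port_name)
    intercomm.send("hello", dest=0, tag=0)
    intercomm.Disconnect()


if __name__ == "__main__":
    server() if sys.argv[1] == "server" else client()
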
@@ -48,8 +84,15 @@ async def run_worker(kv_cache_config, cache_transceiver_config, pytorch_config,
                      model_name, rank):
     assert isinstance(pytorch_config, dict)
     print(f"Running worker {rank}")
-    port_name = MPI.Lookup_name('my_port')
-    intercomm = MPI.COMM_WORLD.Connect(port_name)
+    try:
+        port_name = MPI.Lookup_name('my_port')
+        intercomm = MPI.COMM_WORLD.Connect(port_name)
+    except MPI.Exception as e:
+        print(f"Error publishing port name: {e}")
+        raise e
+    except Exception as e:
+        print(f"Unexpected error publishing port name: {e}")
+        raise e
 
     session = MPI.COMM_WORLD.Split(color=rank, key=0)
     set_mpi_comm(session)
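
run_worker now mirrors the driver-side helper: Lookup_name and Connect are wrapped so a failed lookup or connect is logged and re-raised instead of hanging silently. mpi4py reports failed MPI calls as MPI.Exception, a RuntimeError subclass that carries the MPI error class and string, which is why it is caught separately from the generic Exception. A hedged sketch of what logging that extra detail could look like; the wording and messages here are illustrative, not the test's:

from mpi4py import MPI

try:
    port_name = MPI.Lookup_name('my_port')        # fails if nothing was published
    intercomm = MPI.COMM_WORLD.Connect(port_name)
except MPI.Exception as e:
    # MPI-level failure: the exception exposes the MPI error class and string.
    print(f"MPI error class {e.Get_error_class()}: {e.Get_error_string()}",
          flush=True)
    raise
except Exception as e:
    # Anything else that went wrong while reaching the driver.
    print(f"Unexpected error connecting to the driver: {e}", flush=True)
    raise
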
@@ -139,8 +182,7 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
         zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs,
             model_names, ranks))
 
-    port_name = MPI.Open_port()
-    MPI.Publish_name('my_port', port_name)
+    port_name = mpi_publish_name()
 
     with MPIPoolExecutor(max_workers=2, env={"UCX_TLS": "^ib"}) as executor:
         futures = []
@@ -152,9 +194,10 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
                 print(f"Error in worker {worker_arg}: {e}")
                 raise e
 
+        intercomm = None
         try:
-            print("Launched all the workers.")
-            intercomm = MPI.COMM_SELF.Accept(port_name)
+            print("Launched all the workers.", flush=True)
+            intercomm = mpi_initialize_intercomm(port_name)
 
             for _ in range(2):
                 intercomm.recv(tag=MPI_READY)
@@ -187,14 +230,15 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
             output = responses[0]
             assert output[0].text == expected_output
             assert output[0].token_ids == expected_output_ids
-
+        except Exception as e:
+            print(f"Exception encountered: {e}", flush=True)
+            raise e
         finally:
-            # Send termination requests
-            intercomm.send(None, dest=0, tag=MPI_REQUEST)
-            intercomm.send(None, dest=1, tag=MPI_REQUEST)
-            print("Sent termination requests to the workers.")
+            print("Sending termination request", flush=True)
+            mpi_send_termination_request(intercomm)
 
             # Wait for all futures to complete
+            print("Waiting for all workers to terminate. ", flush=True)
             for future in futures:
                 future.result()
             print("All workers terminated.")
@@ -282,8 +326,7 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
         zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs,
             model_names, ranks))
 
-    port_name = MPI.Open_port()
-    MPI.Publish_name('my_port', port_name)
+    port_name = mpi_publish_name()
 
     prompt = "European Union is a political and economic union of 27 countries. The European Union is headquartered in Brussels, Belgium. The first president of the European Union was Jean-Claude Juncker. The current president is Ursula von der Leyen. The European Union is a major economic and political entity."
 
@@ -297,9 +340,10 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
                 print(f"Error in worker {worker_arg}: {e}")
                 raise e
 
+        intercomm = None
         try:
             print("Launched all the workers.")
-            intercomm = MPI.COMM_SELF.Accept(port_name)
+            intercomm = mpi_initialize_intercomm(port_name)
 
             for _ in range(2):
                 intercomm.recv(tag=MPI_READY)
@@ -334,11 +378,11 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
             intercomm.send(requests, dest=1, tag=MPI_REQUEST)
             output = intercomm.recv(source=1, tag=MPI_RESULT)
 
+        except MPI.Exception as e:
+            print(f"MPI Error")
+            raise e
         finally:
-            # Send termination requests
-            intercomm.send(None, dest=0, tag=MPI_REQUEST)
-            intercomm.send(None, dest=1, tag=MPI_REQUEST)
-            print("Sent termination requests to the workers.")
+            mpi_send_termination_request(intercomm)
 
             # Wait for all futures to complete
             for future in futures:
@@ -385,8 +429,7 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
         zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs,
             model_names, ranks))
 
-    port_name = MPI.Open_port()
-    MPI.Publish_name('my_port', port_name)
+    port_name = mpi_publish_name()
 
     prompt = "What is the capital of Germany?"
 
@@ -400,9 +443,10 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
                 print(f"Error in worker {worker_arg}: {e}")
                 raise e
 
+        intercomm = None
         try:
             print("Launched all the workers.")
-            intercomm = MPI.COMM_SELF.Accept(port_name)
+            intercomm = mpi_initialize_intercomm(port_name)
 
             for _ in range(2):
                 intercomm.recv(tag=MPI_READY)
@@ -436,11 +480,11 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
             intercomm.send(requests, dest=1, tag=MPI_REQUEST)
             output = intercomm.recv(source=1, tag=MPI_RESULT)
 
+        except MPI.Exception as e:
+            print(f"MPI Error")
+            raise e
         finally:
-            # Send termination requests
-            intercomm.send(None, dest=0, tag=MPI_REQUEST)
-            intercomm.send(None, dest=1, tag=MPI_REQUEST)
-            print("Sent termination requests to the workers.")
+            mpi_send_termination_request(intercomm)
 
             # Wait for all futures to complete
             for future in futures:
