3636}
3737
3838
def mpi_publish_name():
    """Open an MPI port and publish it under the well-known service name 'my_port'.

    Workers later resolve the port via ``MPI.Lookup_name('my_port')`` and
    connect to it.

    Returns:
        str: the port name returned by ``MPI.Open_port()``.

    Raises:
        MPI.Exception: if opening or publishing the port fails.
        Exception: any unexpected failure (logged, then re-raised).
    """
    port_name = None
    try:
        port_name = MPI.Open_port()
        MPI.Publish_name('my_port', port_name)
    except MPI.Exception as e:
        # flush=True so the message is not lost if a rank aborts right after.
        print(f"Error publishing port name: {e}", flush=True)
        # Bare raise preserves the original traceback; `raise e` would not.
        raise
    except Exception as e:
        print(f"Unexpected error publishing port name: {e}", flush=True)
        raise

    return port_name
52+
53+
def mpi_initialize_intercomm(port_name):
    """Accept one client connection on *port_name* via COMM_SELF.

    Blocks until a peer calls ``Connect`` on the same port, then returns the
    resulting intercommunicator. Errors are logged (flushed) and re-raised.
    """
    try:
        return MPI.COMM_SELF.Accept(port_name)
    except MPI.Exception as err:
        print(f"Error accepting intercomm: {err}", flush=True)
        raise
    except Exception as err:
        print(f"Unexpected error accepting intercomm: {err}", flush=True)
        raise
65+
66+
def mpi_send_termination_request(intercomm, num_workers=2):
    """Tell each worker rank to shut down by sending a ``None`` payload.

    Args:
        intercomm: intercommunicator connected to the workers, or ``None``
            (no-op — e.g. when ``Accept`` failed before the intercomm existed).
        num_workers: number of worker ranks to notify. Defaults to 2, matching
            the ``MPIPoolExecutor(max_workers=2)`` used by the callers; the
            parameter generalizes the previously hard-coded dests 0 and 1.
    """
    if intercomm is not None:
        # A None message tagged MPI_REQUEST is the agreed termination signal.
        for dest in range(num_workers):
            intercomm.send(None, dest=dest, tag=MPI_REQUEST)
        print("Sent termination requests to the workers.", flush=True)
73+
74+
3975def model_path (model_name ):
4076 llm_models_root = os .environ ["LLM_MODELS_ROOT" ]
4177 for name , path in MODEL_PATHS .items ():
@@ -48,8 +84,15 @@ async def run_worker(kv_cache_config, cache_transceiver_config, pytorch_config,
4884 model_name , rank ):
4985 assert isinstance (pytorch_config , dict )
5086 print (f"Running worker { rank } " )
51- port_name = MPI .Lookup_name ('my_port' )
52- intercomm = MPI .COMM_WORLD .Connect (port_name )
87+ try :
88+ port_name = MPI .Lookup_name ('my_port' )
89+ intercomm = MPI .COMM_WORLD .Connect (port_name )
90+ except MPI .Exception as e :
91+ print (f"Error publishing port name: { e } " )
92+ raise e
93+ except Exception as e :
94+ print (f"Unexpected error publishing port name: { e } " )
95+ raise e
5396
5497 session = MPI .COMM_WORLD .Split (color = rank , key = 0 )
5598 set_mpi_comm (session )
@@ -139,8 +182,7 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
139182 zip (kv_cache_configs , cache_transceiver_configs , worker_pytorch_configs ,
140183 model_names , ranks ))
141184
142- port_name = MPI .Open_port ()
143- MPI .Publish_name ('my_port' , port_name )
185+ port_name = mpi_publish_name ()
144186
145187 with MPIPoolExecutor (max_workers = 2 , env = {"UCX_TLS" : "^ib" }) as executor :
146188 futures = []
@@ -152,9 +194,10 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
152194 print (f"Error in worker { worker_arg } : { e } " )
153195 raise e
154196
197+ intercomm = None
155198 try :
156- print ("Launched all the workers." )
157- intercomm = MPI . COMM_SELF . Accept (port_name )
199+ print ("Launched all the workers." , flush = True )
200+ intercomm = mpi_initialize_intercomm (port_name )
158201
159202 for _ in range (2 ):
160203 intercomm .recv (tag = MPI_READY )
@@ -187,14 +230,15 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
187230 output = responses [0 ]
188231 assert output [0 ].text == expected_output
189232 assert output [0 ].token_ids == expected_output_ids
190-
233+ except Exception as e :
234+ print (f"Exception encountered: { e } " , flush = True )
235+ raise e
191236 finally :
192- # Send termination requests
193- intercomm .send (None , dest = 0 , tag = MPI_REQUEST )
194- intercomm .send (None , dest = 1 , tag = MPI_REQUEST )
195- print ("Sent termination requests to the workers." )
237+ print ("Sending termination request" , flush = True )
238+ mpi_send_termination_request (intercomm )
196239
197240 # Wait for all futures to complete
241+ print ("Waiting for all workers to terminate. " , flush = True )
198242 for future in futures :
199243 future .result ()
200244 print ("All workers terminated." )
@@ -282,8 +326,7 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
282326 zip (kv_cache_configs , cache_transceiver_configs , worker_pytorch_configs ,
283327 model_names , ranks ))
284328
285- port_name = MPI .Open_port ()
286- MPI .Publish_name ('my_port' , port_name )
329+ port_name = mpi_publish_name ()
287330
288331 prompt = "European Union is a political and economic union of 27 countries. The European Union is headquartered in Brussels, Belgium. The first president of the European Union was Jean-Claude Juncker. The current president is Ursula von der Leyen. The European Union is a major economic and political entity."
289332
@@ -297,9 +340,10 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
297340 print (f"Error in worker { worker_arg } : { e } " )
298341 raise e
299342
343+ intercomm = None
300344 try :
301345 print ("Launched all the workers." )
302- intercomm = MPI . COMM_SELF . Accept (port_name )
346+ intercomm = mpi_initialize_intercomm (port_name )
303347
304348 for _ in range (2 ):
305349 intercomm .recv (tag = MPI_READY )
@@ -334,11 +378,11 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
334378 intercomm .send (requests , dest = 1 , tag = MPI_REQUEST )
335379 output = intercomm .recv (source = 1 , tag = MPI_RESULT )
336380
381+ except MPI .Exception as e :
382+ print (f"MPI Error" )
383+ raise e
337384 finally :
338- # Send termination requests
339- intercomm .send (None , dest = 0 , tag = MPI_REQUEST )
340- intercomm .send (None , dest = 1 , tag = MPI_REQUEST )
341- print ("Sent termination requests to the workers." )
385+ mpi_send_termination_request (intercomm )
342386
343387 # Wait for all futures to complete
344388 for future in futures :
@@ -385,8 +429,7 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
385429 zip (kv_cache_configs , cache_transceiver_configs , worker_pytorch_configs ,
386430 model_names , ranks ))
387431
388- port_name = MPI .Open_port ()
389- MPI .Publish_name ('my_port' , port_name )
432+ port_name = mpi_publish_name ()
390433
391434 prompt = "What is the capital of Germany?"
392435
@@ -400,9 +443,10 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
400443 print (f"Error in worker { worker_arg } : { e } " )
401444 raise e
402445
446+ intercomm = None
403447 try :
404448 print ("Launched all the workers." )
405- intercomm = MPI . COMM_SELF . Accept (port_name )
449+ intercomm = mpi_initialize_intercomm (port_name )
406450
407451 for _ in range (2 ):
408452 intercomm .recv (tag = MPI_READY )
@@ -436,11 +480,11 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
436480 intercomm .send (requests , dest = 1 , tag = MPI_REQUEST )
437481 output = intercomm .recv (source = 1 , tag = MPI_RESULT )
438482
483+ except MPI .Exception as e :
484+ print (f"MPI Error" )
485+ raise e
439486 finally :
440- # Send termination requests
441- intercomm .send (None , dest = 0 , tag = MPI_REQUEST )
442- intercomm .send (None , dest = 1 , tag = MPI_REQUEST )
443- print ("Sent termination requests to the workers." )
487+ mpi_send_termination_request (intercomm )
444488
445489 # Wait for all futures to complete
446490 for future in futures :
0 commit comments