@@ -36,6 +36,42 @@
}


def mpi_publish_name():
    """Open an MPI port and publish it under the name 'my_port' so workers can connect."""
    port_name = None
try:
port_name = MPI.Open_port()
MPI.Publish_name('my_port', port_name)
except MPI.Exception as e:
print(f"Error publishing port name: {e}")
raise e
except Exception as e:
print(f"Unexpected error publishing port name: {e}")
raise e

return port_name


def mpi_initialize_intercomm(port_name):
    """Accept a worker connection on the published port and return the intercommunicator."""
    intercomm = None
try:
intercomm = MPI.COMM_SELF.Accept(port_name)
except MPI.Exception as e:
print(f"Error accepting intercomm: {e}", flush=True)
raise
except Exception as e:
print(f"Unexpected error accepting intercomm: {e}", flush=True)
raise
return intercomm


def mpi_send_termination_request(intercomm):
    """Tell both worker ranks to shut down by sending them a None request."""
    if intercomm is not None:
# Send termination requests
intercomm.send(None, dest=0, tag=MPI_REQUEST)
intercomm.send(None, dest=1, tag=MPI_REQUEST)
print("Sent termination requests to the workers.")
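
The three helpers above wrap mpi4py's dynamic process management (name publishing) API. For readers unfamiliar with that API, here is a minimal, self-contained sketch of the rendezvous pattern they rely on; the function names below are illustrative and not part of this change, while the 'my_port' service name matches the one used in this file.

from mpi4py import MPI

def server_rendezvous(service_name='my_port'):
    # Orchestrator side: open a connectable port, publish it under a
    # well-known service name, then block until a client connects.
    port_name = MPI.Open_port()
    MPI.Publish_name(service_name, port_name)
    intercomm = MPI.COMM_SELF.Accept(port_name)
    return port_name, intercomm

def client_rendezvous(service_name='my_port'):
    # Worker side: resolve the published service name to a port and connect.
    port_name = MPI.Lookup_name(service_name)
    return MPI.COMM_WORLD.Connect(port_name)

def rendezvous_teardown(intercomm, service_name, port_name):
    # Cleanup (not shown in this test): disconnect, unpublish, close the port.
    intercomm.Disconnect()
    MPI.Unpublish_name(service_name, port_name)
    MPI.Close_port(port_name)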


def model_path(model_name):
llm_models_root = os.environ["LLM_MODELS_ROOT"]
for name, path in MODEL_PATHS.items():
@@ -48,8 +84,15 @@ async def run_worker(kv_cache_config, cache_transceiver_config, pytorch_config,
model_name, rank):
assert isinstance(pytorch_config, dict)
print(f"Running worker {rank}")
port_name = MPI.Lookup_name('my_port')
intercomm = MPI.COMM_WORLD.Connect(port_name)
try:
port_name = MPI.Lookup_name('my_port')
intercomm = MPI.COMM_WORLD.Connect(port_name)
    except MPI.Exception as e:
        print(f"Error looking up or connecting to port: {e}")
        raise e
    except Exception as e:
        print(f"Unexpected error looking up or connecting to port: {e}")
        raise e

session = MPI.COMM_WORLD.Split(color=rank, key=0)
set_mpi_comm(session)
@@ -139,8 +182,7 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs,
model_names, ranks))

port_name = MPI.Open_port()
MPI.Publish_name('my_port', port_name)
port_name = mpi_publish_name()

with MPIPoolExecutor(max_workers=2, env={"UCX_TLS": "^ib"}) as executor:
futures = []
@@ -152,9 +194,10 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
print(f"Error in worker {worker_arg}: {e}")
raise e

intercomm = None
try:
print("Launched all the workers.")
intercomm = MPI.COMM_SELF.Accept(port_name)
print("Launched all the workers.", flush=True)
intercomm = mpi_initialize_intercomm(port_name)

for _ in range(2):
intercomm.recv(tag=MPI_READY)
@@ -187,14 +230,15 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
output = responses[0]
assert output[0].text == expected_output
assert output[0].token_ids == expected_output_ids

except Exception as e:
print(f"Exception encountered: {e}", flush=True)
raise e
finally:
# Send termination requests
intercomm.send(None, dest=0, tag=MPI_REQUEST)
intercomm.send(None, dest=1, tag=MPI_REQUEST)
print("Sent termination requests to the workers.")
print("Sending termination request", flush=True)
mpi_send_termination_request(intercomm)

# Wait for all futures to complete
print("Waiting for all workers to terminate. ", flush=True)
for future in futures:
future.result()
print("All workers terminated.")
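
The run_worker body that consumes these tags is not shown in this diff, so the loop below is an assumption about the worker side of the protocol. It illustrates how the MPI_READY / MPI_REQUEST / MPI_RESULT tags and the None termination sentinel sent by mpi_send_termination_request fit together; the worker_loop name and the llm.generate call are placeholders.

def worker_loop(intercomm, llm):
    # Announce readiness to the orchestrator (rank 0 of the remote group).
    intercomm.send(None, dest=0, tag=MPI_READY)
    while True:
        requests = intercomm.recv(source=0, tag=MPI_REQUEST)
        if requests is None:
            # None is the termination sentinel sent by mpi_send_termination_request.
            break
        results = llm.generate(requests)  # placeholder for the real per-request work
        intercomm.send(results, dest=0, tag=MPI_RESULT)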
@@ -282,8 +326,7 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs,
model_names, ranks))

port_name = MPI.Open_port()
MPI.Publish_name('my_port', port_name)
port_name = mpi_publish_name()

prompt = "European Union is a political and economic union of 27 countries. The European Union is headquartered in Brussels, Belgium. The first president of the European Union was Jean-Claude Juncker. The current president is Ursula von der Leyen. The European Union is a major economic and political entity."

@@ -297,9 +340,10 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
print(f"Error in worker {worker_arg}: {e}")
raise e

intercomm = None
try:
print("Launched all the workers.")
intercomm = MPI.COMM_SELF.Accept(port_name)
intercomm = mpi_initialize_intercomm(port_name)

for _ in range(2):
intercomm.recv(tag=MPI_READY)
@@ -334,11 +378,11 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
intercomm.send(requests, dest=1, tag=MPI_REQUEST)
output = intercomm.recv(source=1, tag=MPI_RESULT)

    except MPI.Exception as e:
        print(f"MPI Error: {e}", flush=True)
        raise e
finally:
# Send termination requests
intercomm.send(None, dest=0, tag=MPI_REQUEST)
intercomm.send(None, dest=1, tag=MPI_REQUEST)
print("Sent termination requests to the workers.")
mpi_send_termination_request(intercomm)

# Wait for all futures to complete
for future in futures:
@@ -385,8 +429,7 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs,
model_names, ranks))

port_name = MPI.Open_port()
MPI.Publish_name('my_port', port_name)
port_name = mpi_publish_name()

prompt = "What is the capital of Germany?"

@@ -400,9 +443,10 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
print(f"Error in worker {worker_arg}: {e}")
raise e

intercomm = None
try:
print("Launched all the workers.")
intercomm = MPI.COMM_SELF.Accept(port_name)
intercomm = mpi_initialize_intercomm(port_name)

for _ in range(2):
intercomm.recv(tag=MPI_READY)
@@ -436,11 +480,11 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
intercomm.send(requests, dest=1, tag=MPI_REQUEST)
output = intercomm.recv(source=1, tag=MPI_RESULT)

    except MPI.Exception as e:
        print(f"MPI Error: {e}", flush=True)
        raise e
finally:
# Send termination requests
intercomm.send(None, dest=0, tag=MPI_REQUEST)
intercomm.send(None, dest=1, tag=MPI_REQUEST)
print("Sent termination requests to the workers.")
mpi_send_termination_request(intercomm)

# Wait for all futures to complete
for future in futures: