1 change: 0 additions & 1 deletion tensorrt_llm/serve/openai_disagg_server.py
@@ -44,7 +44,6 @@ def __init__(self,
server_start_timeout_secs: int = 180,
metadata_server_cfg: Optional[MetadataServerConfig] = None,
metrics_interval_secs: int = 0):

self.ctx_servers, self.gen_servers = get_ctx_gen_server_urls(config.server_configs)
self.metadata_server = create_metadata_server(metadata_server_cfg)
self.ctx_router = create_router(
16 changes: 16 additions & 0 deletions tests/integration/defs/test_e2e.py
@@ -1767,6 +1767,22 @@ def test_trtllm_multimodal_benchmark_serving(llm_root, llm_venv):
])


@pytest.mark.skip_less_device(4)
@pytest.mark.skip_less_device_memory(40000)
@pytest.mark.parametrize("gen_config", ["gen_tp2pp1", "gen_tp1pp2"])
@pytest.mark.parametrize("ctx_config", ["ctx_tp2pp1", "ctx_tp1pp2"])
def test_openai_disagg_multi_nodes_completion(llm_root, llm_venv, ctx_config,
gen_config):
test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd([
"-m",
"pytest",
str(test_root /
f"_test_disagg_serving_multi_nodes.py::test_completion[{ctx_config}-{gen_config}]"
),
])


### PyTorch examples


2 changes: 2 additions & 0 deletions tests/integration/test_lists/qa/llm_function_multinode.txt
@@ -7,3 +7,5 @@ test_e2e.py::test_multi_nodes_eval[llama4-models/nvidia/Llama-4-Maverick-17B-128
test_e2e.py::test_multi_nodes_eval[Qwen3/Qwen3-235B-A22B-tp16-mmlu]
test_e2e.py::test_multi_nodes_eval[Qwen3/saved_models_Qwen3-235B-A22B_nvfp4_hf-tp16-mmlu]
test_e2e.py::test_multi_nodes_eval[DeepSeek-R1/DeepSeek-R1-0528-FP4-tp16-mmlu]
test_e2e.py::test_openai_disagg_multi_nodes_completion[ctx_tp2pp1-gen_tp2pp1]
test_e2e.py::test_openai_disagg_multi_nodes_completion[ctx_tp1pp2-gen_tp1pp2]
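
For reference, the bracketed IDs in the two new entries above are built from `(tp_size, pp_size)` tuples by the `ids` lambdas of the ctx/gen fixtures added in `_test_disagg_serving_multi_nodes.py` below; a minimal sketch of that mapping, with the tuples chosen purely as examples:

```python
# Sketch of how the ctx_/gen_ test IDs are derived from (tp_size, pp_size) tuples,
# mirroring the `ids` lambdas used by the ctx_tp_pp_size/gen_tp_pp_size fixtures.
def ctx_id(tp_pp_size):
    return f"ctx_tp{tp_pp_size[0]}pp{tp_pp_size[1]}"


def gen_id(tp_pp_size):
    return f"gen_tp{tp_pp_size[0]}pp{tp_pp_size[1]}"


assert ctx_id((2, 1)) == "ctx_tp2pp1"  # first new entry pairs this with gen_tp2pp1
assert gen_id((1, 2)) == "gen_tp1pp2"  # second new entry pairs this with ctx_tp1pp2
```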
214 changes: 214 additions & 0 deletions tests/unittest/llmapi/apps/_test_disagg_serving_multi_nodes.py
@@ -0,0 +1,214 @@
import os
import socket
import time

import openai
import pytest
import requests

from ..test_llm import get_model_path
from .openai_server import RemoteDisaggOpenAIServer, RemoteOpenAIServer
from .utils import expand_slurm_nodelist

RANK = int(os.environ.get("SLURM_PROCID", 0))
NODE_RANK = int(os.environ.get("SLURM_NODEID", 0))
NODE_LIST = expand_slurm_nodelist(os.environ.get("SLURM_NODELIST", ""))
SLURM_NTASKS_PER_NODE = int(os.environ.get("SLURM_NTASKS_PER_NODE", 1))

pytestmark = pytest.mark.threadleak(enabled=False)

# This test assumes that there are at least 2 nodes: the ctx server, the disagg server,
# and the pytest client run on the first node, and the gen server runs on the second node.

CTX_SERVER_PORT = 8001
GEN_SERVER_PORT = 8002
DISAGG_SERVER_PORT = 8000


# Exclude the current node from the node list, then return one of the remaining nodes by index idx
def get_the_other_host(idx=0):
assert len(NODE_LIST) >= 2
assert socket.gethostname() in NODE_LIST
node_list = NODE_LIST.copy()
node_list.remove(socket.gethostname())
return node_list[idx]


def is_ctx_node():
return NODE_RANK == 0


def is_gen_node():
return NODE_RANK == 1


def is_disagg_node():
return NODE_RANK == 0


# The test runs on multiple nodes, but only the first node's output is used for assertions
def is_pytest_node():
return NODE_RANK == 0


def env():
    # Remove MPI-related environment variables to isolate the ctx/gen processes so that
    # they do not end up in the same MPI communicator; otherwise rank and world_size may mismatch
return {
k: v
for k, v in os.environ.items()
if not ('PMI_' in k or 'OMPI_' in k or 'PMIX_' in k or 'SLURM_' in k)
}


@pytest.fixture(scope="module")
def model_name():
return "llama-3.1-model/Llama-3.1-8B-Instruct"


@pytest.fixture(scope="module", params=['pytorch'], ids=["pytorch"])
def backend(request):
return request.param


@pytest.fixture(
scope="module",
params=[(1, 1), (2, 1), (1, 2)],
ids=lambda tp_pp_size: f'ctx_tp{tp_pp_size[0]}pp{tp_pp_size[1]}')
def ctx_tp_pp_size(request):
return request.param


@pytest.fixture(
scope="module",
params=[(1, 1), (2, 1), (1, 2)],
ids=lambda tp_pp_size: f'gen_tp{tp_pp_size[0]}pp{tp_pp_size[1]}')
def gen_tp_pp_size(request):
return request.param


@pytest.fixture(scope="module")
def worker(model_name: str, ctx_tp_pp_size: tuple, gen_tp_pp_size: tuple):
host = socket.gethostname()
assert host in NODE_LIST
extra_config = {
"cache_transceiver_config": {
"backend": "UCX"
},
"kv_cache_config": {
"free_gpu_memory_fraction": 0.5,
"enable_block_reuse": False,
},
"disable_overlap_scheduler": True,
}
if is_ctx_node():
print(f"starting ctx_server for rank {RANK} node rank {NODE_RANK}")
model_path = get_model_path(model_name)
tp_size, pp_size = ctx_tp_pp_size
args = ["--tp_size", str(tp_size), "--pp_size", str(pp_size)]
with RemoteOpenAIServer(model_path,
port=CTX_SERVER_PORT,
cli_args=args,
host="0.0.0.0",
env=env(),
llmapi_launch=True,
rank=RANK % SLURM_NTASKS_PER_NODE,
extra_config=extra_config) as server:
yield server
elif is_gen_node():
print(f"starting gen_server for rank {RANK} node rank {NODE_RANK}")
model_path = get_model_path(model_name)
tp_size, pp_size = gen_tp_pp_size
args = ["--tp_size", str(tp_size), "--pp_size", str(pp_size)]
with RemoteOpenAIServer(model_path,
port=GEN_SERVER_PORT,
cli_args=args,
host="0.0.0.0",
env=env(),
rank=RANK % SLURM_NTASKS_PER_NODE,
extra_config=extra_config) as server:
yield server
else:
yield None


def wait_for_endpoint_ready(url: str, timeout: int = 300):
start = time.monotonic()
while time.monotonic() - start < timeout:
try:
time.sleep(1)
if requests.get(url).status_code == 200:
print(f"endpoint {url} is ready")
return
except Exception as err:
print(f"endpoint {url} is not ready, with exception: {err}")


def wait_for_endpoint_down(url: str, timeout: int = 300):
start = time.monotonic()
while time.monotonic() - start < timeout:
try:
            status_code = requests.get(url).status_code
            if status_code >= 100:
                print(f"endpoint {url} returned status code {status_code}")
            time.sleep(1)
except Exception as err:
print(f"endpoint {url} is down, with exception: {err}")
return


@pytest.fixture(scope="module")
def disagg_server(worker: RemoteOpenAIServer):
if is_disagg_node():
print(f"starting disagg_server for rank {RANK} node rank {NODE_RANK}")
        ctx_url = f"localhost:{CTX_SERVER_PORT}"  # Use localhost since the ctx server is on the same node
        # TODO: This assumes NODE_LIST is ordered by NODE_RANK; the test is currently only expected to run with 2 nodes.
        # We need to test with 4 or more nodes in the future, which should be easier with service discovery.
        gen_url = f"{get_the_other_host(0)}:{GEN_SERVER_PORT}"
with RemoteDisaggOpenAIServer(ctx_servers=[ctx_url],
gen_servers=[gen_url],
port=DISAGG_SERVER_PORT,
llmapi_launch=True,
env=env()) as server:
yield server
else:
print(f"skipping disagg_server for rank {RANK} node rank {NODE_RANK}")
url = f"http://{get_the_other_host(0)}:{DISAGG_SERVER_PORT}/health/"
wait_for_endpoint_ready(url, 60)
yield None


@pytest.fixture(scope="module")
def client(disagg_server: RemoteDisaggOpenAIServer):
if is_pytest_node():
return disagg_server.get_client()
else:
print(f"skipping client for rank {RANK} node rank {NODE_RANK}")
return None


def test_completion(client: openai.OpenAI,
disagg_server: RemoteDisaggOpenAIServer, model_name: str):
if is_pytest_node():
print(f"running test_completion on rank {RANK} node rank {NODE_RANK}")
prompt = "What is the result of 1+1? Answer in one word: "
completion = client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=10,
temperature=0.0,
)
print(f"Output: {completion.choices[0].text}")
assert completion.id is not None
message = completion.choices[0].text
assert message.startswith('2.')
disagg_server.terminate()

elif is_gen_node():
        # Keep the gen workers alive until the test ends; again, this assumes NODE_LIST is ordered by NODE_RANK
url = f"http://{get_the_other_host(0)}:{DISAGG_SERVER_PORT}/health/"
wait_for_endpoint_down(url, 60)
assert True
else:
assert True
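
As a side note on the layout encoded by the helpers above: NODE_RANK 0 hosts the ctx server, the disagg server, and the pytest assertions, NODE_RANK 1 hosts the gen server, and `env()` strips MPI/SLURM variables so the workers do not join one MPI communicator. A small standalone sketch of that behavior, with the helpers reimplemented inline and a purely hypothetical environment:

```python
# Illustrative sketch only (not imported by the test): node-role selection and the
# MPI/SLURM environment filtering under an assumed 2-node allocation.
fake_env = {
    "SLURM_NODEID": "1",          # pretend this process sits on the second node
    "SLURM_PROCID": "2",
    "OMPI_COMM_WORLD_RANK": "2",  # MPI launcher variable that gets filtered out
    "PATH": "/usr/bin",           # unrelated variables are kept
}

node_rank = int(fake_env.get("SLURM_NODEID", 0))
runs_ctx_and_disagg = node_rank == 0   # also where the pytest assertions run
runs_gen = node_rank == 1

filtered = {
    k: v
    for k, v in fake_env.items()
    if not ("PMI_" in k or "OMPI_" in k or "PMIX_" in k or "SLURM_" in k)
}

assert runs_gen and not runs_ctx_and_disagg
assert filtered == {"PATH": "/usr/bin"}
```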
88 changes: 83 additions & 5 deletions tests/unittest/llmapi/apps/openai_server.py
@@ -3,11 +3,13 @@
import os
import subprocess
import sys
import tempfile
import time
from typing import List
from typing import List, Optional

import openai
import requests
import yaml

from tensorrt_llm.llmapi.mpi_session import find_free_port

@@ -20,19 +22,33 @@ def __init__(self,
model: str,
cli_args: List[str] = None,
llmapi_launch: bool = False,
port: int = None) -> None:
self.host = "localhost"
port: int = None,
host: str = "localhost",
env: Optional[dict] = None,
rank: int = -1,
extra_config: Optional[dict] = None) -> None:
self.host = host
self.port = port if port is not None else find_free_port()
self.rank = os.environ.get("SLURM_PROCID", 0)

        self.rank = rank if rank != -1 else int(os.environ.get("SLURM_PROCID", 0))
self.extra_config_file = None
args = ["--host", f"{self.host}", "--port", f"{self.port}"]
if cli_args:
args += cli_args
if extra_config:
with tempfile.NamedTemporaryFile(mode="w+",
delete=False,
delete_on_close=False) as f:
f.write(yaml.dump(extra_config))
self.extra_config_file = f.name
args += ["--extra_llm_api_options", self.extra_config_file]
launch_cmd = ["trtllm-serve"] + [model] + args
if llmapi_launch:
            # Start the server with trtllm-llmapi-launch when running on multiple nodes
launch_cmd = ["trtllm-llmapi-launch"] + launch_cmd
if not env:
env = os.environ.copy()
self.proc = subprocess.Popen(launch_cmd,
env=env,
stdout=sys.stdout,
stderr=sys.stderr)
self._wait_for_server(url=self.url_for("health"),
@@ -42,12 +58,23 @@ def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
self.terminate()

def terminate(self):
if self.proc is None:
return
self.proc.terminate()
try:
self.proc.wait(timeout=30)
except subprocess.TimeoutExpired as e:
self.proc.kill()
self.proc.wait(timeout=30)
try:
if self.extra_config_file:
os.remove(self.extra_config_file)
except Exception as e:
print(f"Error removing extra config file: {e}")
self.proc = None

def _wait_for_server(self, *, url: str, timeout: float):
# run health check on the first rank only.
@@ -57,6 +84,8 @@ def _wait_for_server(self, *, url: str, timeout: float):
if self.rank == 0:
if requests.get(url).status_code == 200:
break
else:
time.sleep(0.5)
else:
time.sleep(timeout)
break
@@ -87,3 +116,52 @@ def get_async_client(self, **kwargs):
return openai.AsyncOpenAI(base_url=self.url_for("v1"),
api_key=self.DUMMY_API_KEY,
**kwargs)


class RemoteDisaggOpenAIServer(RemoteOpenAIServer):

def __init__(self,
ctx_servers: List[str],
gen_servers: List[str],
port: int = -1,
env: Optional[dict] = None,
llmapi_launch: bool = False) -> None:
self.ctx_servers = ctx_servers
self.gen_servers = gen_servers
self.host = "localhost"
self.port = find_free_port() if port is None or port < 0 else port
self.rank = 0
with tempfile.NamedTemporaryFile(mode="w+",
delete=False,
delete_on_close=False) as f:
f.write(self._get_extra_config())
f.flush()
self.extra_config_file = f.name
launch_cmd = [
"trtllm-serve", "disaggregated", "-c", self.extra_config_file
]
if llmapi_launch:
            # Start the server with trtllm-llmapi-launch when running on multiple nodes
launch_cmd = ["trtllm-llmapi-launch"] + launch_cmd
if not env:
env = os.environ.copy()
self.proc = subprocess.Popen(launch_cmd,
env=env,
stdout=sys.stdout,
stderr=sys.stderr)
self._wait_for_server(url=self.url_for("health"),
timeout=self.MAX_SERVER_START_WAIT_S)

def _get_extra_config(self):
return yaml.dump({
"context_servers": {
"num_instances": len(self.ctx_servers),
"urls": self.ctx_servers
},
"generation_servers": {
"num_instances": len(self.gen_servers),
"urls": self.gen_servers
},
"port": self.port,
"hostname": self.host,
})
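
For orientation, a minimal sketch of the YAML that `_get_extra_config()` above would emit for one ctx worker and one gen worker; the host names and ports are placeholders matching the defaults used by the multi-node test, and nothing beyond the fields shown here is assumed about the config schema:

```python
# Sketch: reproduce the disaggregated-server config structure built by
# RemoteDisaggOpenAIServer._get_extra_config() for a 1-ctx / 1-gen setup.
import yaml

config = {
    "context_servers": {
        "num_instances": 1,
        "urls": ["localhost:8001"],   # ctx server on the local node
    },
    "generation_servers": {
        "num_instances": 1,
        "urls": ["node-002:8002"],    # hypothetical remote gen node
    },
    "port": 8000,                     # disagg server port
    "hostname": "localhost",
}
print(yaml.dump(config))  # this is the file passed via `trtllm-serve disaggregated -c`
```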