Skip to content

Commit bfcb1e0

Browse files
committed
improve gen worker's waiting
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent 8d7b94b commit bfcb1e0

File tree

3 files changed

+64
-40
lines changed

3 files changed

+64
-40
lines changed

tests/unittest/llmapi/apps/_test_disagg_serving_multi_nodes.py

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,12 @@ def find_nic():
5757
print(f"test_ip: {test_ip} for the other host {get_the_other_host()}")
5858
try:
5959
# iproute2 may not be installed
60-
result = subprocess.check_output(
61-
f"ip route get {test_ip} | sed -E 's/.*?dev (\\S+) .*/\\1/;t;d'",
62-
shell=True)
63-
nic_name = result.decode('utf-8').strip()
60+
proc = subprocess.run(f"ip route get {test_ip}",
61+
capture_output=True,
62+
text=True,
63+
shell=True,
64+
check=True)
65+
nic_name = proc.stdout.split()[4]
6466
print(f"get NIC name from ip route, result: {nic_name}")
6567
return nic_name
6668
except Exception as e:
@@ -69,7 +71,8 @@ def find_nic():
6971
# Establish a socket to the test ip, then get the local ip from the socket,
7072
# enumerate the local interfaces and find the one with the local ip
7173
local_ip = get_local_ip(test_ip)
72-
for nic_name, ip in get_local_interfaces().items():
74+
local_ip_dict = get_local_interfaces()
75+
for nic_name, ip in local_ip_dict.items():
7376
if ip == local_ip:
7477
return nic_name
7578
except OSError as e:
@@ -89,7 +92,10 @@ def env():
8992
if nic:
9093
# TODO: integrate this into disagg-serving
9194
# setting TRTLLM_UCX_INTERFACE manually if possible because the interfaces found automatically by TRTLLM can have the same ip across nodes, then cache transceiver may fail to send/receive kv cache
95+
print(f"setting TRTLLM_UCX_INTERFACE to {nic}")
9296
new_env["TRTLLM_UCX_INTERFACE"] = nic
97+
else:
98+
print(f"Failed to find NIC, will use default UCX interface")
9399
return new_env
94100

95101

@@ -164,15 +170,40 @@ def worker(model_name: str, ctx_tp_pp_size: tuple, gen_tp_pp_size: tuple):
164170
yield None
165171

166172

173+
def wait_for_endpoint_ready(url: str, timeout: int = 300):
174+
start = time.monotonic()
175+
while time.monotonic() - start < timeout:
176+
try:
177+
time.sleep(1)
178+
if requests.get(url).status_code == 200:
179+
print(f"endpoint {url} is ready")
180+
return
181+
except Exception as err:
182+
print(f"endpoint {url} is not ready, with exception: {err}")
183+
184+
185+
def wait_for_endpoint_down(url: str, timeout: int = 300):
186+
start = time.monotonic()
187+
while time.monotonic() - start < timeout:
188+
try:
189+
if requests.get(url).status_code >= 100:
190+
print(
191+
f"endpoint {url} returned status code {requests.get(url).status_code}"
192+
)
193+
time.sleep(1)
194+
except Exception as err:
195+
print(f"endpoint {url} is down, with exception: {err}")
196+
return
197+
198+
167199
@pytest.fixture(scope="module")
168200
def disagg_server(worker: RemoteOpenAIServer):
169201
if is_disagg_node():
170202
print(f"starting disagg_server for rank {RANK} node rank {NODE_RANK}")
171203
ctx_url = f"localhost:8001" # Use localhost since the ctx server is on the same node
172-
# TODO: Hopefully the NODE_LIST is ordered by NODE_RANK, this test is only tested with 2 nodes now
173-
# We need to test with 4 nodes or more
204+
# TODO: Hopefully the NODE_LIST is ordered by NODE_RANK, this test is only expected to run with 2 nodes now
205+
# We need to test with 4 nodes or more in the future, which should be easier with service discovery
174206
gen_url = f"{get_the_other_host(0)}:8002"
175-
print(f"ctx_url: {ctx_url} gen_url: {gen_url}")
176207
with RemoteDisaggOpenAIServer(ctx_servers=[ctx_url],
177208
gen_servers=[gen_url],
178209
port=DISAGG_SERVER_PORT,
@@ -181,6 +212,8 @@ def disagg_server(worker: RemoteOpenAIServer):
181212
yield server
182213
else:
183214
print(f"skipping disagg_server for rank {RANK} node rank {NODE_RANK}")
215+
url = f"http://{get_the_other_host(0)}:{DISAGG_SERVER_PORT}/health/"
216+
wait_for_endpoint_ready(url, 60)
184217
yield None
185218

186219

@@ -193,30 +226,8 @@ def client(disagg_server: RemoteDisaggOpenAIServer):
193226
return None
194227

195228

196-
def wait_for_endpoint_ready(url: str, timeout: int = 300):
197-
start = time.time()
198-
while time.time() - start < timeout:
199-
try:
200-
time.sleep(1)
201-
if requests.get(url).status_code == 200:
202-
print(f"endpoint {url} is ready")
203-
return
204-
except Exception:
205-
pass
206-
207-
208-
def wait_for_endpoint_down(url: str, timeout: int = 300):
209-
start = time.time()
210-
while time.time() - start < timeout:
211-
try:
212-
if requests.get(url).status_code >= 100:
213-
time.sleep(1)
214-
except Exception as err:
215-
print(f"endpoint {url} is down, with exception: {err}")
216-
return
217-
218-
219-
def test_completion(client: openai.OpenAI, model_name: str):
229+
def test_completion(client: openai.OpenAI,
230+
disagg_server: RemoteDisaggOpenAIServer, model_name: str):
220231
if is_pytest_node():
221232
print(f"running test_completion on rank {RANK} node rank {NODE_RANK}")
222233
prompt = "What is the result of 1+1? Answer in one word: "
@@ -226,16 +237,16 @@ def test_completion(client: openai.OpenAI, model_name: str):
226237
max_tokens=10,
227238
temperature=0.0,
228239
)
229-
print(f"Completion: {completion}")
230240
print(f"Output: {completion.choices[0].text}")
231241
assert completion.id is not None
232242
message = completion.choices[0].text
233243
assert message.startswith('2.')
244+
disagg_server.terminate()
245+
234246
elif is_gen_node():
235247
# keep gen workers alive until the test ends, again we hope the NODE_LIST is ordered by NODE_RANK
236248
url = f"http://{get_the_other_host(0)}:{DISAGG_SERVER_PORT}/health/"
237-
wait_for_endpoint_ready(url)
238-
wait_for_endpoint_down(url)
249+
wait_for_endpoint_down(url, 60)
239250
assert True
240251
else:
241252
assert True

tests/unittest/llmapi/apps/openai_server.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def __init__(self,
3030
self.host = host
3131
self.port = port if port is not None else find_free_port()
3232
self.rank = rank if rank != -1 else os.environ.get("SLURM_PROCID", 0)
33+
self.extra_config_file = None
3334
args = ["--host", f"{self.host}", "--port", f"{self.port}"]
3435
if cli_args:
3536
args += cli_args
@@ -57,12 +58,23 @@ def __enter__(self):
5758
return self
5859

5960
def __exit__(self, exc_type, exc_value, traceback):
61+
self.terminate()
62+
63+
def terminate(self):
64+
if self.proc is None:
65+
return
6066
self.proc.terminate()
6167
try:
6268
self.proc.wait(timeout=30)
6369
except subprocess.TimeoutExpired as e:
6470
self.proc.kill()
6571
self.proc.wait(timeout=30)
72+
try:
73+
if self.extra_config_file:
74+
os.remove(self.extra_config_file)
75+
except Exception as e:
76+
print(f"Error removing extra config file: {e}")
77+
self.proc = None
6678

6779
def _wait_for_server(self, *, url: str, timeout: float):
6880
# run health check on the first rank only.
@@ -119,13 +131,15 @@ def __init__(self,
119131
self.host = "localhost"
120132
self.port = port if port is not None else find_free_port()
121133
self.rank = 0 # rank is always 0 since there is only one disagg server
122-
self.config = self._get_extra_config()
123134
with tempfile.NamedTemporaryFile(mode="w+",
124135
delete=False,
125136
delete_on_close=False) as f:
126-
self.config_file = f.name
127-
f.write(self.config)
128-
launch_cmd = ["trtllm-serve", "disaggregated", "-c", self.config_file]
137+
f.write(self._get_extra_config())
138+
f.flush()
139+
self.extra_config_file = f.name
140+
launch_cmd = [
141+
"trtllm-serve", "disaggregated", "-c", self.extra_config_file
142+
]
129143
if llmapi_launch:
130144
# start server with llmapi-launch on multi nodes
131145
launch_cmd = ["trtllm-llmapi-launch"] + launch_cmd

tests/unittest/llmapi/apps/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,6 @@ def expand_slurm_nodelist(nodelist_str):
258258
for group in groups:
259259
# Check if this group has bracket notation
260260
bracket_match = re.match(r'^(.+?)\[(.+?)\]$', group)
261-
print(f"nodelist_str: {nodelist_str}, bracket_match: {bracket_match}")
262261
if bracket_match:
263262
prefix = bracket_match.group(1)
264263
range_part = bracket_match.group(2)

0 commit comments

Comments
 (0)