
Commit ddf8e8d

raayandhar authored
[None][feat] adding support for disaggregated multi-instance tests (#6674)
Signed-off-by: raayandhar <[email protected]>
1 parent 64c8788 commit ddf8e8d

File tree

3 files changed: +149 -95 lines changed


tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 145 additions & 95 deletions
@@ -115,100 +115,139 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     ctx_total_gpus = ctx_tp * ctx_pp
     gen_total_gpus = gen_tp * gen_pp
 
-    env_ctx = os.environ.copy()
-    env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
-    env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(ctx_total_gpus)))
-
-    env_gen = os.environ.copy()
-    env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
-    env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
-        map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus)))
-    ctx_server_args = common_args + [
-        "--port", "8001", "--extra_llm_api_options", ctx_server_config_path,
-        f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
-    ]
-    gen_server_args = common_args + [
-        "--port", "8002", "--extra_llm_api_options", gen_server_config_path,
-        f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
-    ]
-    if "max_num_tokens" in ctx_server_config:
-        ctx_server_args.append(
-            f"--max_num_tokens={ctx_server_config['max_num_tokens']}")
-    if "max_num_tokens" in gen_server_config:
-        gen_server_args.append(
-            f"--max_num_tokens={gen_server_config['max_num_tokens']}")
-
-    with (MyThreadPoolExecutor(max_workers=16) as
-          thread_pool, temp_dir, popen(ctx_server_args, env=env_ctx) as
-          ctx_server, popen(gen_server_args, env=env_gen) as gen_server,
-          popen([
-              trtllm_serve_path, "disaggregated", "-c",
-              disaggregated_serving_config_path, "--server_start_timeout",
-              "3600"
-          ]) as disaggregated_server):
-        while True:
-            time.sleep(1)
-            try:
-                print("Checking health endpoint")
-                response = requests.get("http://localhost:8000/health")
-                if response.status_code == 200:
-                    break
-            except requests.exceptions.ConnectionError:
-                continue
-
-        client = openai.OpenAI(api_key="1234567890",
-                               base_url=f"http://localhost:8000/v1")
-
-        def send_request(prompt: str, sampling_params: SamplingParams,
-                         streaming: bool):
-            response = client.completions.create(
-                model=model_name,
-                prompt=prompt,
-                stream=streaming,
-                **({
-                    "max_tokens": sampling_params.max_tokens,
-                    "temperature": sampling_params.temperature,
-                    "top_p": sampling_params.top_p,
-                    "stop": sampling_params.stop,
-                    "seed": sampling_params.seed
-                } if sampling_params else {}))
-            result = Result(id=0,
-                            sampling_params=sampling_params,
-                            outputs=[
-                                CompletionOutput(text=response.choices[0].text,
-                                                 index=0)
-                            ])
-            requested_output = RequestOutput._from_generation_result(
-                result, prompt=prompt)
-            setattr(requested_output, "result", result.result)
-            return requested_output
-
-        def generate_async(prompt: str,
-                           sampling_params: Optional[SamplingParams] = None,
-                           streaming: bool = False):
-            future = thread_pool.submit(send_request, prompt, sampling_params,
-                                        streaming)
-            thread_pool.futures.append(future)
-            return future
-
+    ctx_urls = disaggregated_server_config["context_servers"]["urls"]
+    gen_urls = disaggregated_server_config["generation_servers"]["urls"]
+
+    ctx_ports = [int(url.split(":")[1]) for url in ctx_urls]
+    gen_ports = [int(url.split(":")[1]) for url in gen_urls]
+
+    ctx_servers = []
+    current_gpu_offset = 0
+
+    for i, port in enumerate(ctx_ports):
+        env_ctx = os.environ.copy()
+        env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
+        gpu_range = range(current_gpu_offset,
+                          current_gpu_offset + ctx_total_gpus)
+        env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_range))
+        current_gpu_offset += ctx_total_gpus
+
+        ctx_server_args = common_args + [
+            "--port",
+            str(port), "--extra_llm_api_options", ctx_server_config_path,
+            f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
+        ]
+        if "max_num_tokens" in ctx_server_config:
+            ctx_server_args.append(
+                f"--max_num_tokens={ctx_server_config['max_num_tokens']}")
+
+        ctx_servers.append((env_ctx, ctx_server_args))
+
+    gen_servers = []
+
+    for i, port in enumerate(gen_ports):
+        env_gen = os.environ.copy()
+        env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
+        gpu_range = range(current_gpu_offset,
+                          current_gpu_offset + gen_total_gpus)
+        env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_range))
+        current_gpu_offset += gen_total_gpus
+
+        gen_server_args = common_args + [
+            "--port",
+            str(port), "--extra_llm_api_options", gen_server_config_path,
+            f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
+        ]
+        if "max_num_tokens" in gen_server_config:
+            gen_server_args.append(
+                f"--max_num_tokens={gen_server_config['max_num_tokens']}")
+
+        gen_servers.append((env_gen, gen_server_args))
+
+    @contextlib.contextmanager
+    def multi_popen(server_configs):
+        processes = []
         try:
-            yield DuckLLM(args, generate_async)
-        finally:
-            ctx_server.terminate()
-            gen_server.terminate()
-            disaggregated_server.terminate()
-
-            ctx_server.wait()
-            gen_server.wait()
-            disaggregated_server.wait()
+            for env, args in server_configs:
+                proc = popen(args, env=env)
+                processes.append(proc)
+
+            with contextlib.ExitStack() as stack:
+                opened_processes = [
+                    stack.enter_context(proc) for proc in processes
+                ]
+                yield opened_processes
+        except Exception as e:
+            print(
+                f"Failed to start disaggregated server processes in multi_popen: {e}"
+            )
+            raise
+
+    with (MyThreadPoolExecutor(max_workers=16) as thread_pool, temp_dir):
+        with multi_popen(ctx_servers + gen_servers):
+            with popen([
+                    trtllm_serve_path, "disaggregated", "-c",
+                    disaggregated_serving_config_path, "--server_start_timeout",
+                    "3600"
+            ]):
+                while True:
+                    time.sleep(1)
+                    try:
+                        print("Checking health endpoint")
+                        response = requests.get("http://localhost:8000/health")
+                        if response.status_code == 200:
+                            break
+                    except requests.exceptions.ConnectionError:
+                        continue
+
+                client = openai.OpenAI(api_key="1234567890",
+                                       base_url=f"http://localhost:8000/v1")
+
+                def send_request(prompt: str, sampling_params: SamplingParams,
+                                 streaming: bool):
+                    response = client.completions.create(
+                        model=model_name,
+                        prompt=prompt,
+                        stream=streaming,
+                        **({
+                            "max_tokens": sampling_params.max_tokens,
+                            "temperature": sampling_params.temperature,
+                            "top_p": sampling_params.top_p,
+                            "stop": sampling_params.stop,
+                            "seed": sampling_params.seed
+                        } if sampling_params else {}))
+                    result = Result(id=0,
+                                    sampling_params=sampling_params,
+                                    outputs=[
+                                        CompletionOutput(
+                                            text=response.choices[0].text,
+                                            index=0)
+                                    ])
+                    requested_output = RequestOutput._from_generation_result(
+                        result, prompt=prompt)
+                    setattr(requested_output, "result", result.result)
+                    return requested_output
+
+                def generate_async(
+                        prompt: str,
+                        sampling_params: Optional[SamplingParams] = None,
+                        streaming: bool = False):
+                    future = thread_pool.submit(send_request, prompt,
+                                                sampling_params, streaming)
+                    thread_pool.futures.append(future)
+                    return future
+
+                yield DuckLLM(args, generate_async)
 
 
 def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
-                      ctx_tp: int, gen_pp: int, gen_tp: int,
-                      test_set: LlmapiAccuracyTestHarness):
-    if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count():
+                      ctx_tp: int, gen_pp: int, gen_tp: int, ctx_instances: int,
+                      gen_instances: int, test_set: LlmapiAccuracyTestHarness):
+    total_ctx_gpus = ctx_tp * ctx_pp * ctx_instances
+    total_gen_gpus = gen_tp * gen_pp * gen_instances
+    if total_ctx_gpus + total_gen_gpus > get_device_count():
         pytest.fail(
-            f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and gen_pp={gen_pp}+gen_tp={gen_tp} test"
+            f"Not enough devices for {ctx_instances} ctx instances (ctx_pp={ctx_pp}*ctx_tp={ctx_tp}) + {gen_instances} gen instances (gen_pp={gen_pp}*gen_tp={gen_tp}), total: {total_ctx_gpus + total_gen_gpus}"
         )
 
     kv_cache_config = {
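
Editor's note: the multi_popen helper added above follows a standard contextlib pattern — launch each server process first, then enter them all through an ExitStack so every server that did start is torn down even if a later launch or the health check fails. Below is a minimal standalone sketch of that pattern; it uses subprocess.Popen and placeholder commands rather than the test suite's popen helper, so the names and commands are illustrative only.

import contextlib
import subprocess
from typing import List, Sequence


@contextlib.contextmanager
def multi_popen_sketch(commands: Sequence[List[str]]):
    """Start several child processes; ExitStack guarantees all are cleaned up."""
    with contextlib.ExitStack() as stack:
        procs = [stack.enter_context(subprocess.Popen(cmd)) for cmd in commands]
        try:
            yield procs
        finally:
            for proc in procs:
                proc.terminate()  # Popen.__exit__ then waits on each process


if __name__ == "__main__":
    # Two placeholder "servers" standing in for the ctx/gen trtllm-serve runs.
    with multi_popen_sketch([["sleep", "2"], ["sleep", "2"]]) as procs:
        print(f"started {len(procs)} processes")
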
@@ -233,17 +272,21 @@ def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
             "backend": "default"
         }
     }
+
+    ctx_urls = [f"localhost:{8001 + i * 2}" for i in range(ctx_instances)]
+    gen_urls = [f"localhost:{8002 + i * 2}" for i in range(gen_instances)]
+
     disaggregated_server_config = {
         "hostname": "localhost",
         "port": 8000,
         "backend": "pytorch",
         "context_servers": {
-            "num_instances": 1,
-            "urls": ["localhost:8001"]
+            "num_instances": ctx_instances,
+            "urls": ctx_urls
         },
         "generation_servers": {
-            "num_instances": 1,
-            "urls": ["localhost:8002"]
+            "num_instances": gen_instances,
+            "urls": gen_urls
         }
     }
     with launch_disaggregated_llm(disaggregated_server_config,
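
Editor's note: for orientation, this is roughly the disaggregated_server_config that the port expressions above would produce for ctx_instances=2 and gen_instances=2 — context servers take odd ports starting at 8001, generation servers even ports starting at 8002. The values are derived from the list comprehensions in the diff, not copied verbatim from the commit, and the kv_cache/backend settings built earlier in run_parallel_test are elided.

# Illustrative only: config shape for ctx_instances=2, gen_instances=2.
disaggregated_server_config = {
    "hostname": "localhost",
    "port": 8000,
    "backend": "pytorch",
    "context_servers": {
        "num_instances": 2,
        "urls": ["localhost:8001", "localhost:8003"],
    },
    "generation_servers": {
        "num_instances": 2,
        "urls": ["localhost:8002", "localhost:8004"],
    },
}
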
@@ -399,14 +442,21 @@ def test_eagle3(self, overlap_scheduler, eagle3_one_model):
     @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
     def test_tp_pp_symmetric(self, tp, pp, testset):
         return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
-                                 tp, get_accuracy_task(testset))
+                                 tp, 1, 1, get_accuracy_task(testset))
 
+    @pytest.mark.skip_less_device(4)
     @parametrize_with_ids("ctx_pp", [2, 4])
     @parametrize_with_ids("gen_tp", [1, 2])
     @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
     def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
         return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
-                                 gen_tp, get_accuracy_task(testset))
+                                 gen_tp, 1, 1, get_accuracy_task(testset))
+
+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
+    def test_multi_instance(self, testset):
+        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, 1, 1, 1, 1,
+                                 2, 2, get_accuracy_task(testset))
 
 
 @pytest.mark.skip_less_device_memory(140000)
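
Editor's note: the new test_multi_instance case asks run_parallel_test for two single-GPU context instances and two single-GPU generation instances, which is why it carries @pytest.mark.skip_less_device(4). A small standalone sketch of the device layout the launcher's offset logic would give that configuration follows; the helper name plan_gpu_assignment is illustrative and not part of the commit.

from typing import List, Tuple


def plan_gpu_assignment(ctx_instances: int, gen_instances: int,
                        ctx_gpus_each: int,
                        gen_gpus_each: int) -> List[Tuple[str, str]]:
    """Context instances take the first GPU blocks, then generation instances."""
    assignments, offset = [], 0
    for i in range(ctx_instances):
        gpus = range(offset, offset + ctx_gpus_each)
        assignments.append((f"ctx{i}", ",".join(map(str, gpus))))
        offset += ctx_gpus_each
    for i in range(gen_instances):
        gpus = range(offset, offset + gen_gpus_each)
        assignments.append((f"gen{i}", ",".join(map(str, gpus))))
        offset += gen_gpus_each
    return assignments


if __name__ == "__main__":
    # test_multi_instance: ctx_pp=ctx_tp=gen_pp=gen_tp=1, two instances each.
    for name, devices in plan_gpu_assignment(2, 2, 1, 1):
        print(name, "CUDA_VISIBLE_DEVICES=" + devices)
    # -> ctx0: 0, ctx1: 1, gen0: 2, gen1: 3  (four GPUs in total)
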

tests/integration/test_lists/qa/llm_function_full.txt

Lines changed: 2 additions & 0 deletions
@@ -527,6 +527,8 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[GSM8K]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[MMLU]
 accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True]
 accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False]

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 2 additions & 0 deletions
@@ -50,6 +50,8 @@ l0_dgx_h100:
 - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=2]
 - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=2]
 - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2]
+- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[GSM8K]
+- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[MMLU]
 - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend
 - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
 - test_e2e.py::test_ptp_quickstart_advanced_bs1
