Skip to content

Commit d8556ef

Browse files
committed
minor nitpicks + make it easier to add multi-instance tests
Signed-off-by: raayandhar <[email protected]>
1 parent cbf4bde commit d8556ef

File tree

3 files changed

+32
-59
lines changed

3 files changed

+32
-59
lines changed

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 28 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -177,16 +177,19 @@ def multi_popen(server_configs):
177177
stack.enter_context(proc) for proc in processes
178178
]
179179
yield opened_processes
180-
finally:
181-
pass
180+
except Exception as e:
181+
logger.error(
182+
f"Failed to start disaggregated server processes in multi_popen: {e}"
183+
)
184+
raise
182185

183186
with (MyThreadPoolExecutor(max_workers=16) as thread_pool, temp_dir):
184-
with multi_popen(ctx_servers + gen_servers) as server_processes:
187+
with multi_popen(ctx_servers + gen_servers):
185188
with popen([
186189
trtllm_serve_path, "disaggregated", "-c",
187190
disaggregated_serving_config_path, "--server_start_timeout",
188191
"3600"
189-
]) as disaggregated_server:
192+
]):
190193
while True:
191194
time.sleep(1)
192195
try:
@@ -238,11 +241,13 @@ def generate_async(
238241

239242

240243
def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
241-
ctx_tp: int, gen_pp: int, gen_tp: int,
242-
test_set: LlmapiAccuracyTestHarness):
243-
if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count():
244+
ctx_tp: int, gen_pp: int, gen_tp: int, ctx_instances: int,
245+
gen_instances: int, test_set: LlmapiAccuracyTestHarness):
246+
total_ctx_gpus = ctx_tp * ctx_pp * ctx_instances
247+
total_gen_gpus = gen_tp * gen_pp * gen_instances
248+
if total_ctx_gpus + total_gen_gpus > get_device_count():
244249
pytest.fail(
245-
f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and gen_pp={gen_pp}+gen_tp={gen_tp} test"
250+
f"Not enough devices for {ctx_instances} ctx instances (ctx_pp={ctx_pp}*ctx_tp={ctx_tp}) + {gen_instances} gen instances (gen_pp={gen_pp}*gen_tp={gen_tp}), total: {total_ctx_gpus + total_gen_gpus}"
246251
)
247252

248253
kv_cache_config = {
@@ -267,17 +272,21 @@ def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
267272
"backend": "default"
268273
}
269274
}
275+
276+
ctx_urls = [f"localhost:{8001 + i * 2}" for i in range(ctx_instances)]
277+
gen_urls = [f"localhost:{8002 + i * 2}" for i in range(gen_instances)]
278+
270279
disaggregated_server_config = {
271280
"hostname": "localhost",
272281
"port": 8000,
273282
"backend": "pytorch",
274283
"context_servers": {
275-
"num_instances": 1,
276-
"urls": ["localhost:8001"]
284+
"num_instances": ctx_instances,
285+
"urls": ctx_urls
277286
},
278287
"generation_servers": {
279-
"num_instances": 1,
280-
"urls": ["localhost:8002"]
288+
"num_instances": gen_instances,
289+
"urls": gen_urls
281290
}
282291
}
283292
with launch_disaggregated_llm(disaggregated_server_config,
@@ -433,59 +442,21 @@ def test_eagle3(self, overlap_scheduler, eagle3_one_model):
433442
@pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
434443
def test_tp_pp_symmetric(self, tp, pp, testset):
435444
return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
436-
tp, get_accuracy_task(testset))
445+
tp, 1, 1, get_accuracy_task(testset))
437446

447+
@pytest.mark.skip_less_device(4)
438448
@parametrize_with_ids("ctx_pp", [2, 4])
439449
@parametrize_with_ids("gen_tp", [1, 2])
440450
@pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
441451
def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
442452
return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
443-
gen_tp, get_accuracy_task(testset))
453+
gen_tp, 1, 1, get_accuracy_task(testset))
444454

445455
@pytest.mark.skip_less_device(4)
446-
def test_multi_instance(self):
447-
kv_cache_config = {
448-
"free_gpu_memory_fraction": 0.5,
449-
"enable_block_reuse": False
450-
}
451-
ctx_server_config = {
452-
"pipeline_parallel_size": 1,
453-
"tensor_parallel_size": 1,
454-
"disable_overlap_scheduler": True,
455-
"kv_cache_config": kv_cache_config,
456-
"cache_transceiver_config": {
457-
"backend": "default"
458-
}
459-
}
460-
gen_server_config = {
461-
"tensor_parallel_size": 1,
462-
"pipeline_parallel_size": 1,
463-
"disable_overlap_scheduler": True,
464-
"kv_cache_config": kv_cache_config,
465-
"cache_transceiver_config": {
466-
"backend": "default"
467-
}
468-
}
469-
disaggregated_server_config = {
470-
"hostname": "localhost",
471-
"port": 8000,
472-
"backend": "pytorch",
473-
"context_servers": {
474-
"num_instances": 2,
475-
"urls": ["localhost:8001", "localhost:8003"]
476-
},
477-
"generation_servers": {
478-
"num_instances": 2,
479-
"urls": ["localhost:8002", "localhost:8004"]
480-
}
481-
}
482-
with launch_disaggregated_llm(disaggregated_server_config,
483-
ctx_server_config, gen_server_config,
484-
self.MODEL_PATH) as llm:
485-
task = MMLU(self.MODEL_NAME)
486-
task.evaluate(llm)
487-
task = GSM8K(self.MODEL_NAME)
488-
task.evaluate(llm)
456+
@pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
457+
def test_multi_instance(self, testset):
458+
return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, 1, 1, 1, 1,
459+
2, 2, get_accuracy_task(testset))
489460

490461

491462
@pytest.mark.skip_less_device_memory(140000)

tests/integration/test_lists/qa/llm_function_full.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,8 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen
526526
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4]
527527
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2]
528528
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4]
529-
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance
529+
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[GSM8K]
530+
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[MMLU]
530531
accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False]
531532
accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True]
532533
accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False]

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ l0_dgx_h100:
5050
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=2]
5151
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=2]
5252
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2]
53-
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance
53+
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[GSM8K]
54+
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[MMLU]
5455
- test_e2e.py::test_ptp_quickstart_advanced_bs1
5556
- test_e2e.py::test_ptp_quickstart_advanced_deepseek_v3_lite_4gpus_adp_balance[DeepSeek-V3-Lite-FP8-DeepSeek-V3-Lite/fp8]
5657
- condition:

0 commit comments

Comments (0)