@@ -177,16 +177,19 @@ def multi_popen(server_configs):
                     stack.enter_context(proc) for proc in processes
                 ]
                 yield opened_processes
-        finally:
-            pass
+        except Exception as e:
+            logger.error(
+                f"Failed to start disaggregated server processes in multi_popen: {e}"
+            )
+            raise
 
     with (MyThreadPoolExecutor(max_workers=16) as thread_pool, temp_dir):
-        with multi_popen(ctx_servers + gen_servers) as server_processes:
+        with multi_popen(ctx_servers + gen_servers):
             with popen([
                     trtllm_serve_path, "disaggregated", "-c",
                     disaggregated_serving_config_path, "--server_start_timeout",
                     "3600"
-            ]) as disaggregated_server:
+            ]):
                 while True:
                     time.sleep(1)
                     try:
@@ -238,11 +241,13 @@ def generate_async(
 
 
 def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
-                      ctx_tp: int, gen_pp: int, gen_tp: int,
-                      test_set: LlmapiAccuracyTestHarness):
-    if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count():
+                      ctx_tp: int, gen_pp: int, gen_tp: int, ctx_instances: int,
+                      gen_instances: int, test_set: LlmapiAccuracyTestHarness):
+    total_ctx_gpus = ctx_tp * ctx_pp * ctx_instances
+    total_gen_gpus = gen_tp * gen_pp * gen_instances
+    if total_ctx_gpus + total_gen_gpus > get_device_count():
         pytest.fail(
-            f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and gen_pp={gen_pp}+gen_tp={gen_tp} test"
+            f"Not enough devices for {ctx_instances} ctx instances (ctx_pp={ctx_pp}*ctx_tp={ctx_tp}) + {gen_instances} gen instances (gen_pp={gen_pp}*gen_tp={gen_tp}), total: {total_ctx_gpus + total_gen_gpus}"
         )
 
     kv_cache_config = {
@@ -267,17 +272,21 @@ def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
             "backend": "default"
         }
     }
+
+    ctx_urls = [f"localhost:{8001 + i * 2}" for i in range(ctx_instances)]
+    gen_urls = [f"localhost:{8002 + i * 2}" for i in range(gen_instances)]
+
     disaggregated_server_config = {
         "hostname": "localhost",
         "port": 8000,
         "backend": "pytorch",
         "context_servers": {
-            "num_instances": 1,
-            "urls": ["localhost:8001"]
+            "num_instances": ctx_instances,
+            "urls": ctx_urls
         },
         "generation_servers": {
-            "num_instances": 1,
-            "urls": ["localhost:8002"]
+            "num_instances": gen_instances,
+            "urls": gen_urls
         }
     }
     with launch_disaggregated_llm(disaggregated_server_config,
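For reference, the URL scheme and GPU budget introduced above interleave ports so the two pools never collide: context servers take 8001, 8003, … and generation servers take 8002, 8004, …, while the device check multiplies tensor and pipeline parallelism by the instance count on each side. A small sketch under those assumptions (the plan helper below is hypothetical, added only to show the arithmetic):

def plan(ctx_pp, ctx_tp, gen_pp, gen_tp, ctx_instances, gen_instances):
    ctx_urls = [f"localhost:{8001 + i * 2}" for i in range(ctx_instances)]
    gen_urls = [f"localhost:{8002 + i * 2}" for i in range(gen_instances)]
    total_gpus = (ctx_tp * ctx_pp * ctx_instances +
                  gen_tp * gen_pp * gen_instances)
    return ctx_urls, gen_urls, total_gpus


# The multi-instance case below (2 ctx + 2 gen instances, tp = pp = 1) yields
# (['localhost:8001', 'localhost:8003'], ['localhost:8002', 'localhost:8004'], 4),
# matching the ports the old hard-coded test_multi_instance used.
print(plan(ctx_pp=1, ctx_tp=1, gen_pp=1, gen_tp=1, ctx_instances=2, gen_instances=2))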
@@ -433,59 +442,21 @@ def test_eagle3(self, overlap_scheduler, eagle3_one_model):
     @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
     def test_tp_pp_symmetric(self, tp, pp, testset):
         return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
-                                 tp, get_accuracy_task(testset))
+                                 tp, 1, 1, get_accuracy_task(testset))
 
+    @pytest.mark.skip_less_device(4)
     @parametrize_with_ids("ctx_pp", [2, 4])
     @parametrize_with_ids("gen_tp", [1, 2])
     @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
     def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
         return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
-                                 gen_tp, get_accuracy_task(testset))
+                                 gen_tp, 1, 1, get_accuracy_task(testset))
 
     @pytest.mark.skip_less_device(4)
-    def test_multi_instance(self):
-        kv_cache_config = {
-            "free_gpu_memory_fraction": 0.5,
-            "enable_block_reuse": False
-        }
-        ctx_server_config = {
-            "pipeline_parallel_size": 1,
-            "tensor_parallel_size": 1,
-            "disable_overlap_scheduler": True,
-            "kv_cache_config": kv_cache_config,
-            "cache_transceiver_config": {
-                "backend": "default"
-            }
-        }
-        gen_server_config = {
-            "tensor_parallel_size": 1,
-            "pipeline_parallel_size": 1,
-            "disable_overlap_scheduler": True,
-            "kv_cache_config": kv_cache_config,
-            "cache_transceiver_config": {
-                "backend": "default"
-            }
-        }
-        disaggregated_server_config = {
-            "hostname": "localhost",
-            "port": 8000,
-            "backend": "pytorch",
-            "context_servers": {
-                "num_instances": 2,
-                "urls": ["localhost:8001", "localhost:8003"]
-            },
-            "generation_servers": {
-                "num_instances": 2,
-                "urls": ["localhost:8002", "localhost:8004"]
-            }
-        }
-        with launch_disaggregated_llm(disaggregated_server_config,
-                                      ctx_server_config, gen_server_config,
-                                      self.MODEL_PATH) as llm:
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+    @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
+    def test_multi_instance(self, testset):
+        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, 1, 1, 1, 1,
+                                 2, 2, get_accuracy_task(testset))
 
 
 @pytest.mark.skip_less_device_memory(140000)