@@ -115,100 +115,139 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     ctx_total_gpus = ctx_tp * ctx_pp
     gen_total_gpus = gen_tp * gen_pp
 
-    env_ctx = os.environ.copy()
-    env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
-    env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(ctx_total_gpus)))
-
-    env_gen = os.environ.copy()
-    env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
-    env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
-        map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus)))
-    ctx_server_args = common_args + [
-        "--port", "8001", "--extra_llm_api_options", ctx_server_config_path,
-        f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
-    ]
-    gen_server_args = common_args + [
-        "--port", "8002", "--extra_llm_api_options", gen_server_config_path,
-        f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
-    ]
-    if "max_num_tokens" in ctx_server_config:
-        ctx_server_args.append(
-            f"--max_num_tokens={ctx_server_config['max_num_tokens']}")
-    if "max_num_tokens" in gen_server_config:
-        gen_server_args.append(
-            f"--max_num_tokens={gen_server_config['max_num_tokens']}")
-
-    with (MyThreadPoolExecutor(max_workers=16) as
-          thread_pool, temp_dir, popen(ctx_server_args, env=env_ctx) as
-          ctx_server, popen(gen_server_args, env=env_gen) as gen_server,
-          popen([
-              trtllm_serve_path, "disaggregated", "-c",
-              disaggregated_serving_config_path, "--server_start_timeout",
-              "3600"
-          ]) as disaggregated_server):
-        while True:
-            time.sleep(1)
-            try:
-                print("Checking health endpoint")
-                response = requests.get("http://localhost:8000/health")
-                if response.status_code == 200:
-                    break
-            except requests.exceptions.ConnectionError:
-                continue
-
-        client = openai.OpenAI(api_key="1234567890",
-                               base_url=f"http://localhost:8000/v1")
-
-        def send_request(prompt: str, sampling_params: SamplingParams,
-                         streaming: bool):
-            response = client.completions.create(
-                model=model_name,
-                prompt=prompt,
-                stream=streaming,
-                **({
-                    "max_tokens": sampling_params.max_tokens,
-                    "temperature": sampling_params.temperature,
-                    "top_p": sampling_params.top_p,
-                    "stop": sampling_params.stop,
-                    "seed": sampling_params.seed
-                } if sampling_params else {}))
-            result = Result(id=0,
-                            sampling_params=sampling_params,
-                            outputs=[
-                                CompletionOutput(text=response.choices[0].text,
-                                                 index=0)
-                            ])
-            requested_output = RequestOutput._from_generation_result(
-                result, prompt=prompt)
-            setattr(requested_output, "result", result.result)
-            return requested_output
-
-        def generate_async(prompt: str,
-                           sampling_params: Optional[SamplingParams] = None,
-                           streaming: bool = False):
-            future = thread_pool.submit(send_request, prompt, sampling_params,
-                                        streaming)
-            thread_pool.futures.append(future)
-            return future
-
+    ctx_urls = disaggregated_server_config["context_servers"]["urls"]
+    gen_urls = disaggregated_server_config["generation_servers"]["urls"]
+
+    ctx_ports = [int(url.split(":")[1]) for url in ctx_urls]
+    gen_ports = [int(url.split(":")[1]) for url in gen_urls]
+
+    ctx_servers = []
+    current_gpu_offset = 0
+
+    for i, port in enumerate(ctx_ports):
+        env_ctx = os.environ.copy()
+        env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
+        gpu_range = range(current_gpu_offset,
+                          current_gpu_offset + ctx_total_gpus)
+        env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_range))
+        current_gpu_offset += ctx_total_gpus
+
+        ctx_server_args = common_args + [
+            "--port",
+            str(port), "--extra_llm_api_options", ctx_server_config_path,
+            f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
+        ]
+        if "max_num_tokens" in ctx_server_config:
+            ctx_server_args.append(
+                f"--max_num_tokens={ctx_server_config['max_num_tokens']}")
+
+        ctx_servers.append((env_ctx, ctx_server_args))
+
+    gen_servers = []
+
+    for i, port in enumerate(gen_ports):
+        env_gen = os.environ.copy()
+        env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
+        gpu_range = range(current_gpu_offset,
+                          current_gpu_offset + gen_total_gpus)
+        env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_range))
+        current_gpu_offset += gen_total_gpus
+
+        gen_server_args = common_args + [
+            "--port",
+            str(port), "--extra_llm_api_options", gen_server_config_path,
+            f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
+        ]
+        if "max_num_tokens" in gen_server_config:
+            gen_server_args.append(
+                f"--max_num_tokens={gen_server_config['max_num_tokens']}")
+
+        gen_servers.append((env_gen, gen_server_args))
+
+    @contextlib.contextmanager
+    def multi_popen(server_configs):
+        processes = []
         try:
-            yield DuckLLM(args, generate_async)
-        finally:
-            ctx_server.terminate()
-            gen_server.terminate()
-            disaggregated_server.terminate()
-
-            ctx_server.wait()
-            gen_server.wait()
-            disaggregated_server.wait()
+            for env, args in server_configs:
+                proc = popen(args, env=env)
+                processes.append(proc)
+
+            with contextlib.ExitStack() as stack:
+                opened_processes = [
+                    stack.enter_context(proc) for proc in processes
+                ]
+                yield opened_processes
+        except Exception as e:
+            print(
+                f"Failed to start disaggregated server processes in multi_popen: {e}"
+            )
+            raise
+
+    with (MyThreadPoolExecutor(max_workers=16) as thread_pool, temp_dir):
+        with multi_popen(ctx_servers + gen_servers):
+            with popen([
+                    trtllm_serve_path, "disaggregated", "-c",
+                    disaggregated_serving_config_path, "--server_start_timeout",
+                    "3600"
+            ]):
+                while True:
+                    time.sleep(1)
+                    try:
+                        print("Checking health endpoint")
+                        response = requests.get("http://localhost:8000/health")
+                        if response.status_code == 200:
+                            break
+                    except requests.exceptions.ConnectionError:
+                        continue
+
+                client = openai.OpenAI(api_key="1234567890",
+                                       base_url=f"http://localhost:8000/v1")
+
+                def send_request(prompt: str, sampling_params: SamplingParams,
+                                 streaming: bool):
+                    response = client.completions.create(
+                        model=model_name,
+                        prompt=prompt,
+                        stream=streaming,
+                        **({
+                            "max_tokens": sampling_params.max_tokens,
+                            "temperature": sampling_params.temperature,
+                            "top_p": sampling_params.top_p,
+                            "stop": sampling_params.stop,
+                            "seed": sampling_params.seed
+                        } if sampling_params else {}))
+                    result = Result(id=0,
+                                    sampling_params=sampling_params,
+                                    outputs=[
+                                        CompletionOutput(
+                                            text=response.choices[0].text,
+                                            index=0)
+                                    ])
+                    requested_output = RequestOutput._from_generation_result(
+                        result, prompt=prompt)
+                    setattr(requested_output, "result", result.result)
+                    return requested_output
+
+                def generate_async(
+                        prompt: str,
+                        sampling_params: Optional[SamplingParams] = None,
+                        streaming: bool = False):
+                    future = thread_pool.submit(send_request, prompt,
+                                                sampling_params, streaming)
+                    thread_pool.futures.append(future)
+                    return future
+
+                yield DuckLLM(args, generate_async)
 
 
 def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
-                      ctx_tp: int, gen_pp: int, gen_tp: int,
-                      test_set: LlmapiAccuracyTestHarness):
-    if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count():
+                      ctx_tp: int, gen_pp: int, gen_tp: int, ctx_instances: int,
+                      gen_instances: int, test_set: LlmapiAccuracyTestHarness):
+    total_ctx_gpus = ctx_tp * ctx_pp * ctx_instances
+    total_gen_gpus = gen_tp * gen_pp * gen_instances
+    if total_ctx_gpus + total_gen_gpus > get_device_count():
         pytest.fail(
-            f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and gen_pp={gen_pp}+gen_tp={gen_tp} test"
+            f"Not enough devices for {ctx_instances} ctx instances (ctx_pp={ctx_pp}*ctx_tp={ctx_tp}) + {gen_instances} gen instances (gen_pp={gen_pp}*gen_tp={gen_tp}), total: {total_ctx_gpus + total_gen_gpus}"
         )
 
     kv_cache_config = {
@@ -233,17 +272,21 @@ def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
233272 "backend" : "default"
234273 }
235274 }
275+
276+ ctx_urls = [f"localhost:{ 8001 + i * 2 } " for i in range (ctx_instances )]
277+ gen_urls = [f"localhost:{ 8002 + i * 2 } " for i in range (gen_instances )]
278+
236279 disaggregated_server_config = {
237280 "hostname" : "localhost" ,
238281 "port" : 8000 ,
239282 "backend" : "pytorch" ,
240283 "context_servers" : {
241- "num_instances" : 1 ,
242- "urls" : [ "localhost:8001" ]
284+ "num_instances" : ctx_instances ,
285+ "urls" : ctx_urls
243286 },
244287 "generation_servers" : {
245- "num_instances" : 1 ,
246- "urls" : [ "localhost:8002" ]
288+ "num_instances" : gen_instances ,
289+ "urls" : gen_urls
247290 }
248291 }
249292 with launch_disaggregated_llm (disaggregated_server_config ,
@@ -399,14 +442,21 @@ def test_eagle3(self, overlap_scheduler, eagle3_one_model):
     @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
     def test_tp_pp_symmetric(self, tp, pp, testset):
         return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
-                                 tp, get_accuracy_task(testset))
+                                 tp, 1, 1, get_accuracy_task(testset))
 
+    @pytest.mark.skip_less_device(4)
     @parametrize_with_ids("ctx_pp", [2, 4])
     @parametrize_with_ids("gen_tp", [1, 2])
     @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
     def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
         return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
-                                 gen_tp, get_accuracy_task(testset))
+                                 gen_tp, 1, 1, get_accuracy_task(testset))
+
+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
+    def test_multi_instance(self, testset):
+        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, 1, 1, 1, 1,
+                                 2, 2, get_accuracy_task(testset))
 
 
 @pytest.mark.skip_less_device_memory(140000)
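
Worked example: with the new test_multi_instance parameters (ctx_pp = ctx_tp = gen_pp = gen_tp = 1, ctx_instances = gen_instances = 2), run_parallel_test needs 1*1*2 + 1*1*2 = 4 GPUs (matching the skip_less_device(4) marker) and builds the configuration sketched below; launch_disaggregated_llm then assigns CUDA_VISIBLE_DEVICES=0 and 1 to the two context servers and 2 and 3 to the two generation servers via the running current_gpu_offset. The snippet is an illustration of the values produced by the code above, not part of the change itself.

    # Illustration only: config that run_parallel_test builds for
    # test_multi_instance (tp=pp=1, two context and two generation instances).
    ctx_urls = [f"localhost:{8001 + i * 2}" for i in range(2)]  # 8001, 8003
    gen_urls = [f"localhost:{8002 + i * 2}" for i in range(2)]  # 8002, 8004
    disaggregated_server_config = {
        "hostname": "localhost",
        "port": 8000,
        "backend": "pytorch",
        "context_servers": {"num_instances": 2, "urls": ctx_urls},
        "generation_servers": {"num_instances": 2, "urls": gen_urls}
    }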