@@ -449,7 +449,9 @@ def __init__(self,
                  skip_engine_build: bool = False,
                  quant: Optional[str] = None,
                  extra_llm_api_options: Optional[str] = None,
-                 use_mpirun: bool = False):
+                 use_mpirun: bool = False,
+                 concurrency: Optional[int] = None,
+                 num_requests: int = 10):
 
         llm_models = llm_models_root()
         assert llm_models is not None
@@ -474,12 +476,14 @@ def __init__(self,
         else:
             self.mpirun_cmd = ""
         self.engine_path = None
+        self.concurrency = concurrency
+        self.num_requests = num_requests
 
     def __call__(self):
         self.prepare_dataset()
         if not (self.skip_engine_build or self.use_pytorch_backend):
             self.build_engine()
-        self.run_bench()
+        return self.run_bench()
 
     def prepare_dataset(self):
         dataset_tool = Path(self.llm_root, "benchmarks", "cpp",
@@ -502,7 +506,7 @@ def prepare_dataset(self):
             "--output-stdev",
             "0",
             "--num-requests",
-            "10",
+            str(self.num_requests),
         ]
         print(f"Running command: {' '.join(command)}")
         dataset_output = self.llm_venv.run_cmd(
@@ -556,7 +560,43 @@ def run_bench(self):
 
         if self.extra_llm_api_options:
             benchmark_cmd += f" --extra_llm_api_options {self.extra_llm_api_options}"
-        check_call(benchmark_cmd, shell=True, env=self.llm_venv._new_env)
+        if self.concurrency:
+            benchmark_cmd += f" --concurrency {self.concurrency}"
+        if self.num_requests:
+            benchmark_cmd += f" --num_requests {self.num_requests}"
+
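+        # Capture stdout (instead of check_call) so throughput/latency can be parsed from it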
+        benchmark_output = check_output(benchmark_cmd,
+                                        shell=True,
+                                        env=self.llm_venv._new_env)
+        return self.parse_benchmark_output(benchmark_output)
+
+    def parse_benchmark_output(self, output):
+        """Parse the benchmark output to extract key metrics."""
+        result = {
+            'concurrency': self.concurrency,
+            'num_requests': self.num_requests,
+            'throughput': 0,
+            'latency': 0
+        }
+
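+        # Scan each output line for the benchmark's throughput and latency summary fields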
+        lines = output.split('\n')
+        for line in lines:
+            line = line.strip()
+            if 'total token throughput' in line.lower(
+            ) and 'tokens/sec' in line.lower():
+                try:
+                    throughput = line.split(":")[1].strip()
+                    result['throughput'] = throughput
+                except (IndexError, ValueError):
+                    pass
+            elif 'total latency' in line.lower() and 'ms' in line.lower():
+                try:
+                    latency = line.split(":")[1].strip()
+                    result['latency'] = latency
+                except (IndexError, ValueError):
+                    pass
+
+        return result
 
 
 @pytest.mark.parametrize("model_name", ["meta-llama/Meta-Llama-3-8B-Instruct"],
@@ -579,6 +619,61 @@ def test_trtllm_bench_llmapi_launch(llm_root, llm_venv, model_name,
     runner()
 
 
+@skip_pre_hopper
+@pytest.mark.skip_less_device_memory(80000)
+@pytest.mark.parametrize("model_name", ["meta/Meta-Llama-3.1-8B"],
+                         ids=["llama3_1-8b"])
+@pytest.mark.parametrize("model_subdir", ["llama-3.1-model/Meta-Llama-3.1-8B"],
+                         ids=["llama_v3_1"])
+@pytest.mark.parametrize("use_pytorch_backend", [False], ids=["trt_backend"])
+def test_trtllm_bench_mig_launch(llm_root, llm_venv, model_name, model_subdir,
+                                 use_pytorch_backend):
+    """Run the benchmark in MIG mode and check that throughput increases with concurrency."""
+    skip_engine_build = False
+    results = {}
+    concurrency_list = [1, 32, 64, 128]
+
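+    # Sweep increasing concurrency levels, scaling num_requests with concurrency so each run has enough work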
+    for concurrency in concurrency_list:
+        num_requests = concurrency * 10
+        runner = BenchRunner(llm_root=llm_root,
+                             llm_venv=llm_venv,
+                             model_name=model_name,
+                             model_subdir=model_subdir,
+                             streaming=False,
+                             use_pytorch_backend=use_pytorch_backend,
+                             use_mpirun=False,
+                             tp_size=1,
+                             concurrency=concurrency,
+                             num_requests=num_requests,
+                             skip_engine_build=skip_engine_build)
+
+        output = runner()
+        results[concurrency] = output
+
+    print(f"\n=== Benchmark Results Comparison ===")
+    print(f"Model: {model_name}")
+    print(f"Backend: {'PyTorch' if use_pytorch_backend else 'TensorRT'}")
+    print(
+        f"{'Concurrency':<15} {'Throughput':<15} {'Latency':<15} {'Num Requests':<15}"
+    )
+    print("-" * 60)
+
+    for idx, val in enumerate(concurrency_list):
+        if hasattr(results[val], 'get'):
+            throughput = float(results[val].get('throughput', 0))
+            latency = float(results[val].get('latency', 0))
+            num_requests = int(results[val].get('num_requests', 0))
+            assert throughput > 0, f"Throughput is 0 for concurrency {val}"
+            assert latency > 0, f"Latency is 0 for concurrency {val}"
+            print(
+                f"{val:<15} {throughput:<15} {latency:<15} {num_requests:<15}"
+            )
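+            # Require at least a 30% throughput gain over the previous (lower) concurrency level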
+            if idx > 0:
+                assert throughput > float(
+                    results[concurrency_list[idx - 1]].get('throughput', 0)
+                ) * 1.3, f"Throughput did not increase by at least 30% at concurrency {concurrency_list[idx]}"
+
+
 @pytest.mark.parametrize(
     "model_name, llama_model_root",
     [pytest.param("TinyLlama-1.1B-Chat-v1.0", "TinyLlama-1.1B-Chat-v1.0")],