@@ -661,19 +661,22 @@ def test_trtllm_bench_mig_launch(llm_root, llm_venv, model_name, model_subdir,
     print("-" * 60)

     for idx, val in enumerate(concurrency_list):
-        if hasattr(results[val], 'get'):
-            throughput = float(results[val].get('throughput', 0))
-            latency = float(results[val].get('latency', 0))
-            num_requests = int(results[val].get('num_requests', 0))
-            assert throughput > 0, f"Throughput is 0 for concurrency {concurrency}"
-            assert latency > 0, f"Latency is 0 for concurrency {concurrency}"
-            print(
-                f"{concurrency:<15} {throughput:<15} {latency:<15} {num_requests:<15}"
-            )
-            if idx > 0:
-                assert throughput > float(
-                    results[concurrency_list[idx - 1]].get('throughput', 0)
-                ) * 1.3, f"Throughput is not increasing for concurrency {concurrency_list[idx]}"
+        metrics = results.get(val)
+        if not isinstance(metrics, dict):
+            pytest.fail(f"Unexpected benchmark result type for concurrency {val}: {type(metrics)}")
+        try:
+            throughput = float(metrics.get('throughput', 0))
+            latency = float(metrics.get('latency', 0))
+            num_requests = int(metrics.get('num_requests', 0))
+        except (ValueError, TypeError) as e:
+            pytest.fail(f"Failed to parse benchmark results for concurrency {val}: {e}")
+        assert throughput > 0, f"Throughput is 0 for concurrency {val}"
+        assert latency > 0, f"Latency is 0 for concurrency {val}"
+        print(f"{val:<15} {throughput:<15} {latency:<15} {num_requests:<15}")
+        if idx > 0:
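+            # Scaling check: throughput must exceed 1.3x that of the previous (lower) concurrency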
+            prev_throughput = float(results[concurrency_list[idx - 1]].get('throughput', 0))
+            assert throughput > prev_throughput * 1.3, f"Throughput is not increasing for concurrency {concurrency_list[idx]}"


 @pytest.mark.parametrize(