Skip to content

Commit 03eded6

Browse files
Add num-trials as cli arg (#109)
* add num-trials as cli arg
* add test and format
* fix test
* use global var for default num trials
* implement copilot suggestions
1 parent fe8ba07 commit 03eded6

File tree

3 files changed: +66 −2 lines changed

src/flexible_inference_benchmark/main.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@
3838

3939
logger = logging.getLogger(__name__)
4040

41+
# Default value for num_trials argument
42+
DEFAULT_NUM_TRIALS = 10
43+
MAX_TRIALS = 100 # Maximum trials for prompt generation, warn if exceeded
44+
4145

4246
def return_random_image_by_size(width: int, height: int, convert_to_base64: bool = False) -> Any:
4347

@@ -198,6 +202,8 @@ def generate_prompts(
198202
"User selected sharegpt dataset. "
199203
"Ignoring prompt length distribution and following the prompts from the dataset."
200204
)
205+
if args.num_trials != DEFAULT_NUM_TRIALS: # Check if user specified custom value
206+
logger.warning("num_trials parameter is ignored for ShareGPT dataset as prompts are pre-defined")
201207
prompt_cls = ShareGPT(filename, tokenizer, output_token_dist)
202208
else:
203209
logger.info(f"User selected {args.dataset_name} dataset. Generating prompt from distributions.")
@@ -216,7 +222,12 @@ def generate_prompts(
216222
if args.prefix_len:
217223
prompt_cls = (
218224
Random.with_prefix_len(
219-
args.prefix_len, input_prompt_dist, output_token_dist, tokenizer, args.ignore_input_distribution
225+
args.prefix_len,
226+
input_prompt_dist,
227+
output_token_dist,
228+
tokenizer,
229+
args.ignore_input_distribution,
230+
args.num_trials,
220231
)
221232
if args.dataset_name == "random"
222233
else Textfile.with_prefix_len(
@@ -226,13 +237,19 @@ def generate_prompts(
226237
output_token_dist,
227238
tokenizer,
228239
args.ignore_input_distribution,
240+
args.num_trials,
229241
)
230242
)
231243
else:
232244
prefix_text = args.prefix_text or ""
233245
prompt_cls = (
234246
Random.with_prefix_str(
235-
prefix_text, input_prompt_dist, output_token_dist, tokenizer, args.ignore_input_distribution
247+
prefix_text,
248+
input_prompt_dist,
249+
output_token_dist,
250+
tokenizer,
251+
args.ignore_input_distribution,
252+
args.num_trials,
236253
)
237254
if args.dataset_name == "random"
238255
else Textfile.with_prefix_str(
@@ -242,6 +259,7 @@ def generate_prompts(
242259
output_token_dist,
243260
tokenizer,
244261
args.ignore_input_distribution,
262+
args.num_trials,
245263
)
246264
)
247265

@@ -492,6 +510,15 @@ def add_benchmark_subparser(subparsers: argparse._SubParsersAction) -> Any: # t
492510
help="Number of input tokens to use for validation prompts (default: 128).",
493511
)
494512

513+
benchmark_parser.add_argument(
514+
"--num-trials",
515+
type=int,
516+
default=DEFAULT_NUM_TRIALS,
517+
help="Number of attempts to achieve exact token count when generating prompts (default: 10). "
518+
"Used for 'random' and 'other' datasets. Higher values improve token count precision "
519+
"but may slow down prompt generation. Ignored for ShareGPT datasets.",
520+
)
521+
495522
return benchmark_parser
496523

497524

@@ -604,6 +631,12 @@ def fail(msg: str) -> None:
604631
if args.dataset_path and not args.dataset_name:
605632
args.dataset_name = "other"
606633

634+
# Validate num_trials parameter
635+
if args.num_trials <= 0:
636+
fail("Number of trials must be positive")
637+
if args.num_trials > MAX_TRIALS:
638+
logger.warning(f"High num_trials value ({args.num_trials}) may slow down prompt generation")
639+
607640
return args
608641

609642

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def args_configs():
6969
"temperature": 0.0,
7070
"top_p": None,
7171
"top_k": None,
72+
"num_trials": 8,
7273
}
7374

7475
sharegpt_sample_data_path = "tests/data/sharegpt_sample_test_data.json"

tests/test_data.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import json
77
import flexible_inference_benchmark.engine.data as data
88
import flexible_inference_benchmark.engine.distributions as distributions
9+
from flexible_inference_benchmark.main import parse_args
910
from sharegpt_data import SHAREGPT_DATA
1011

1112
@pytest.mark.parametrize("ignore_input_distribution", [True, False])
@@ -57,3 +58,32 @@ def test_sharegpt():
5758
if os.path.exists("sharegpt_test.json"):
5859
os.remove("sharegpt_test.json")
5960
assert random_data.shape == (10,3)
61+
62+
def test_num_trials_cli_argument():
63+
"""Test that num_trials CLI argument is properly parsed and validated."""
64+
import sys
65+
66+
# Test default value
67+
original_argv = sys.argv
68+
try:
69+
sys.argv = ['fib', 'benchmark', '--model', 'test', '--base-url', 'http://test']
70+
args = parse_args()
71+
assert args.num_trials == 10
72+
73+
# Test custom value
74+
sys.argv = ['fib', 'benchmark', '--model', 'test', '--base-url', 'http://test', '--num-trials', '5']
75+
args = parse_args()
76+
assert args.num_trials == 5
77+
78+
# Test validation - zero value should fail
79+
sys.argv = ['fib', 'benchmark', '--model', 'test', '--base-url', 'http://test', '--num-trials', '0']
80+
with pytest.raises(SystemExit):
81+
parse_args()
82+
83+
# Test validation - negative value should fail
84+
sys.argv = ['fib', 'benchmark', '--model', 'test', '--base-url', 'http://test', '--num-trials', '-1']
85+
with pytest.raises(SystemExit):
86+
parse_args()
87+
88+
finally:
89+
sys.argv = original_argv

0 commit comments

Comments
 (0)