logger = logging.getLogger(__name__)

+# Default value for num_trials argument
+DEFAULT_NUM_TRIALS = 10
+MAX_TRIALS = 100  # Maximum trials for prompt generation, warn if exceeded
+


def return_random_image_by_size(width: int, height: int, convert_to_base64: bool = False) -> Any:
@@ -198,6 +202,8 @@ def generate_prompts(
            "User selected sharegpt dataset. "
            "Ignoring prompt length distribution and following the prompts from the dataset."
        )
+        if args.num_trials != DEFAULT_NUM_TRIALS:  # Check if user specified custom value
+            logger.warning("num_trials parameter is ignored for ShareGPT dataset as prompts are pre-defined")
        prompt_cls = ShareGPT(filename, tokenizer, output_token_dist)
    else:
        logger.info(f"User selected {args.dataset_name} dataset. Generating prompt from distributions.")
@@ -216,7 +222,12 @@ def generate_prompts(
        if args.prefix_len:
            prompt_cls = (
                Random.with_prefix_len(
-                    args.prefix_len, input_prompt_dist, output_token_dist, tokenizer, args.ignore_input_distribution
+                    args.prefix_len,
+                    input_prompt_dist,
+                    output_token_dist,
+                    tokenizer,
+                    args.ignore_input_distribution,
+                    args.num_trials,
                )
                if args.dataset_name == "random"
                else Textfile.with_prefix_len(
@@ -226,13 +237,19 @@ def generate_prompts(
                    output_token_dist,
                    tokenizer,
                    args.ignore_input_distribution,
+                    args.num_trials,
                )
            )
        else:
            prefix_text = args.prefix_text or ""
            prompt_cls = (
                Random.with_prefix_str(
-                    prefix_text, input_prompt_dist, output_token_dist, tokenizer, args.ignore_input_distribution
+                    prefix_text,
+                    input_prompt_dist,
+                    output_token_dist,
+                    tokenizer,
+                    args.ignore_input_distribution,
+                    args.num_trials,
                )
                if args.dataset_name == "random"
                else Textfile.with_prefix_str(
@@ -242,6 +259,7 @@ def generate_prompts(
                    output_token_dist,
                    tokenizer,
                    args.ignore_input_distribution,
+                    args.num_trials,
                )
            )
@@ -492,6 +510,15 @@ def add_benchmark_subparser(subparsers: argparse._SubParsersAction) -> Any:  # t
        help="Number of input tokens to use for validation prompts (default: 128).",
    )

+    benchmark_parser.add_argument(
+        "--num-trials",
+        type=int,
+        default=DEFAULT_NUM_TRIALS,
+        help="Number of attempts to achieve exact token count when generating prompts (default: 10). "
+        "Used for 'random' and 'other' datasets. Higher values improve token count precision "
+        "but may slow down prompt generation. Ignored for ShareGPT datasets.",
+    )
+
    return benchmark_parser
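The help text above frames num_trials as a retry budget for hitting an exact token count. As a rough, illustrative sketch only (the Random and Textfile generators that this diff passes num_trials into are not shown here, and every name below is hypothetical), such a trial loop could look like the following, assuming a HuggingFace-style tokenizer with an encode() method:

    # Illustrative sketch only -- not the implementation behind this diff.
    # Retry up to num_trials times until the tokenized prompt reaches target_len.
    import random
    import string


    def generate_with_trials(tokenizer, target_len: int, num_trials: int) -> str:
        best_text, best_diff = "", float("inf")
        n_words = target_len  # rough first guess: one word per token
        for _ in range(num_trials):
            words = [
                "".join(random.choices(string.ascii_lowercase, k=random.randint(3, 8)))
                for _ in range(n_words)
            ]
            text = " ".join(words)
            diff = len(tokenizer.encode(text)) - target_len
            if diff == 0:
                return text  # exact token count reached, stop early
            if abs(diff) < best_diff:
                best_text, best_diff = text, abs(diff)
            n_words = max(1, n_words - diff)  # steer the word count toward the target and retry
        return best_text  # closest attempt after num_trials tries

More trials give the loop more chances to land exactly on target_len, which is why the help text warns that large values trade generation speed for token-count precision.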
@@ -604,6 +631,12 @@ def fail(msg: str) -> None:
    if args.dataset_path and not args.dataset_name:
        args.dataset_name = "other"

+    # Validate num_trials parameter
+    if args.num_trials <= 0:
+        fail("Number of trials must be positive")
+    if args.num_trials > MAX_TRIALS:
+        logger.warning(f"High num_trials value ({args.num_trials}) may slow down prompt generation")
+
    return args
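For reference, the new option would be supplied like any other benchmark flag; the program name and the dataset flag in this invocation are placeholders for illustration, not values taken from this diff:

    python benchmark.py benchmark --dataset-name random --num-trials 25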