@@ -376,6 +376,14 @@ def run(args):
376376 if args .subsample_seed :
377377 random .seed (args .subsample_seed )
378378 num_excluded_subsamp = 0
379+
380+ # Disable probabilistic sampling when user's request a specific number of
381+ # sequences per group. In this case, users expect deterministic behavior and
382+ # probabilistic behavior is surprising.
383+ probabilistic_sampling = args .probabilistic_sampling
384+ if args .sequences_per_group :
385+ probabilistic_sampling = False
386+
379387 if args .subsample_max_sequences or (args .group_by and args .sequences_per_group ):
380388
381389 #set groups to group_by values
@@ -448,7 +456,7 @@ def run(args):
448456 for sequences_in_group in seq_names_by_group .values ()
449457 ]
450458
451- if args . probabilistic_sampling :
459+ if probabilistic_sampling :
452460 spg = _calculate_fractional_sequences_per_group (
453461 args .subsample_max_sequences ,
454462 length_of_sequences_per_group
@@ -463,7 +471,7 @@ def run(args):
463471 sys .exit (1 )
464472 print ("sampling at {} per group." .format (spg ))
465473
466- if args . probabilistic_sampling :
474+ if probabilistic_sampling :
467475 random_generator = np .random .default_rng ()
468476
469477 # subsample each groups, either by taking the spg highest priority strains or
@@ -480,7 +488,7 @@ def run(args):
480488 subsampling_attempts += 1
481489
482490 for group , sequences_in_group in seq_names_by_group .items ():
483- if args . probabilistic_sampling :
491+ if probabilistic_sampling :
484492 tmp_spg = random_generator .poisson (spg )
485493 else :
486494 tmp_spg = spg
0 commit comments