@@ -420,6 +420,10 @@ class PatchPixelSamplerConfig(PixelSamplerConfig):
420
420
patch_size : int = 32
421
421
"""Side length of patch. This must be consistent in the method
422
422
config in order for samples to be reshaped into patches correctly."""
423
+ rejection_sample_mask : bool = True
424
+ """Whether or not to use rejection sampling when sampling images with masks"""
425
+ max_num_iterations : int = 100
426
+ """If rejection sampling masks, the maximum number of times to sample"""
423
427
424
428
425
429
class PatchPixelSampler (PixelSampler ):
@@ -458,9 +462,20 @@ def sample_method(
458
462
sub_bs = batch_size // (self .config .patch_size ** 2 )
459
463
half_patch_size = int (self .config .patch_size / 2 )
460
464
m = erode_mask (mask .permute (0 , 3 , 1 , 2 ).float (), pixel_radius = half_patch_size )
461
- nonzero_indices = torch .nonzero (m [:, 0 ], as_tuple = False ).to (device )
462
- chosen_indices = random .sample (range (len (nonzero_indices )), k = sub_bs )
463
- indices = nonzero_indices [chosen_indices ]
465
+
466
+ if self .config .rejection_sample_mask :
467
+ indices = self .rejection_sample_mask (
468
+ mask = m ,
469
+ num_samples = sub_bs ,
470
+ num_images = num_images ,
471
+ image_height = image_height ,
472
+ image_width = image_width ,
473
+ device = device ,
474
+ )
475
+ else :
476
+ nonzero_indices = torch .nonzero (m [:, 0 ], as_tuple = False ).to (device )
477
+ chosen_indices = random .sample (range (len (nonzero_indices )), k = sub_bs )
478
+ indices = nonzero_indices [chosen_indices ]
464
479
465
480
indices = (
466
481
indices .view (sub_bs , 1 , 1 , 3 )
0 commit comments