Skip to content

Commit 688ba70

Browse files
committed
Use mask rejection sampling in PatchPixelSampler
1 parent 4c7297c commit 688ba70

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

nerfstudio/data/pixel_samplers.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,10 @@ class PatchPixelSamplerConfig(PixelSamplerConfig):
420420
patch_size: int = 32
421421
"""Side length of patch. This must be consistent in the method
422422
config in order for samples to be reshaped into patches correctly."""
423+
rejection_sample_mask: bool = True
424+
"""Whether or not to use rejection sampling when sampling images with masks"""
425+
max_num_iterations: int = 100
426+
"""If rejection sampling masks, the maximum number of times to sample"""
423427

424428

425429
class PatchPixelSampler(PixelSampler):
@@ -458,9 +462,20 @@ def sample_method(
458462
sub_bs = batch_size // (self.config.patch_size**2)
459463
half_patch_size = int(self.config.patch_size / 2)
460464
m = erode_mask(mask.permute(0, 3, 1, 2).float(), pixel_radius=half_patch_size)
461-
nonzero_indices = torch.nonzero(m[:, 0], as_tuple=False).to(device)
462-
chosen_indices = random.sample(range(len(nonzero_indices)), k=sub_bs)
463-
indices = nonzero_indices[chosen_indices]
465+
466+
if self.config.rejection_sample_mask:
467+
indices = self.rejection_sample_mask(
468+
mask=m,
469+
num_samples=sub_bs,
470+
num_images=num_images,
471+
image_height=image_height,
472+
image_width=image_width,
473+
device=device,
474+
)
475+
else:
476+
nonzero_indices = torch.nonzero(m[:, 0], as_tuple=False).to(device)
477+
chosen_indices = random.sample(range(len(nonzero_indices)), k=sub_bs)
478+
indices = nonzero_indices[chosen_indices]
464479

465480
indices = (
466481
indices.view(sub_bs, 1, 1, 3)

0 commit comments

Comments
 (0)