31 changes: 31 additions & 0 deletions nerfstudio/cameras/cameras.py
@@ -1021,3 +1021,34 @@ def rescale_output_resolution(
self.width = torch.ceil(self.width * scaling_factor).to(torch.int64)
else:
raise ValueError("Scale rounding mode must be 'floor', 'round' or 'ceil'.")

def update_tiling_intrinsics(self, tiling_factor: int) -> None:
"""
Update camera intrinsics based on tiling_factor.
Must match the tiling logic defined in the dataparser.

Args:
tiling_factor: Tiling factor to apply to the camera intrinsics.
"""
if tiling_factor == 1:
return

num_tiles = tiling_factor**2

# Compute tile sizes
base_tile_w, remainder_w = self.width // tiling_factor, self.width % tiling_factor
base_tile_h, remainder_h = self.height // tiling_factor, self.height % tiling_factor

tile_indices = torch.arange(len(self.cx), device=self.cx.device).unsqueeze(1) % num_tiles
row_indices, col_indices = tile_indices // tiling_factor, tile_indices % tiling_factor

x_offsets = col_indices * base_tile_w + torch.minimum(col_indices, remainder_w)
y_offsets = row_indices * base_tile_h + torch.minimum(row_indices, remainder_h)

# Adjust principal points
self.cx = self.cx - x_offsets
self.cy = self.cy - y_offsets

# Adjust height/width
self.width = base_tile_w + (col_indices < remainder_w).to(torch.int)
self.height = base_tile_h + (row_indices < remainder_h).to(torch.int)
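
A minimal sketch (not part of the diff; the resolution and tiling factor are assumed for illustration) of the offsets this method subtracts from the principal points, here for a 1000x750 camera and tiling_factor=3:

import torch

width, height, tiling_factor = 1000, 750, 3
num_tiles = tiling_factor**2
base_tile_w, remainder_w = width // tiling_factor, width % tiling_factor    # 333, 1
base_tile_h, remainder_h = height // tiling_factor, height % tiling_factor  # 250, 0

tile_indices = torch.arange(num_tiles).unsqueeze(1) % num_tiles
row_indices, col_indices = tile_indices // tiling_factor, tile_indices % tiling_factor
x_offsets = col_indices * base_tile_w + torch.minimum(col_indices, torch.tensor(remainder_w))
y_offsets = row_indices * base_tile_h + torch.minimum(row_indices, torch.tensor(remainder_h))
# x_offsets per column: 0, 334, 667 -> tile widths 334, 333, 333 (the spare pixel goes to column 0)
# y_offsets per row:    0, 250, 500 -> tile heights 250, 250, 250
# cx/cy of each tile camera are the original cx/cy minus these offsets, so the principal point
# keeps the same physical location inside its tile.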
144 changes: 132 additions & 12 deletions nerfstudio/data/dataparsers/colmap_dataparser.py
@@ -59,6 +59,8 @@ class ColmapDataParserConfig(DataParserConfig):
"""How much to downscale images. If not set, images are chosen such that the max dimension is <1600px."""
downscale_rounding_mode: Literal["floor", "round", "ceil"] = "floor"
"""How to round downscale image height and Image width."""
tiling_factor: int = 1
"""Tile images into n^2 equal-resolution images, where n is this number. n | H, n | W for image with resolution WxH"""
scene_scale: float = 1.0
"""How much to scale the region of interest by."""
orientation_method: Literal["pca", "up", "vertical", "none"] = "up"
@@ -115,7 +117,8 @@ class ColmapDataParser(DataParser):

The dataparser loads the downscaled images from folders with `_{downscale_factor}` suffix.
If these folders do not exist, the user can choose to automatically downscale the images and
create these folders.
create these folders. If tiling_factor > 1, the images are instead loaded from folders with
`_tiled_{tiling_factor}` suffix.

The loader is compatible with the datasets processed using the ns-process-data script and
can be used as a drop-in replacement. It further supports datasets like Mip-NeRF 360 (although
@@ -327,13 +330,26 @@ def _generate_dataparser_outputs(self, split: str = "train", **kwargs):
image_filenames, mask_filenames, depth_filenames, downscale_factor = self._setup_downscale_factor(
image_filenames, mask_filenames, depth_filenames
)
image_filenames, mask_filenames, depth_filenames = self._setup_tiling(
image_filenames, mask_filenames, depth_filenames
)

num_tiles = self.config.tiling_factor**2

image_filenames = [image_filenames[i] for i in indices]
mask_filenames = [mask_filenames[i] for i in indices] if len(mask_filenames) > 0 else []
depth_filenames = [depth_filenames[i] for i in indices] if len(depth_filenames) > 0 else []
image_filenames = [image_filenames[i * num_tiles + j] for i in indices for j in range(num_tiles)]
mask_filenames = (
[mask_filenames[i * num_tiles + j] for i in indices for j in range(num_tiles)]
if len(mask_filenames) > 0
else []
)
depth_filenames = (
[depth_filenames[i * num_tiles + j] for i in indices for j in range(num_tiles)]
if len(depth_filenames) > 0
else []
)

idx_tensor = torch.tensor(indices, dtype=torch.long)
poses = poses[idx_tensor]
poses = poses[idx_tensor].repeat_interleave(num_tiles, dim=0)

# in x,y,z order
# assumes that the scene is centered at the origin
@@ -344,13 +360,13 @@ def _generate_dataparser_outputs(self, split: str = "train", **kwargs):
)
)

fx = torch.tensor(fx, dtype=torch.float32)[idx_tensor]
fy = torch.tensor(fy, dtype=torch.float32)[idx_tensor]
cx = torch.tensor(cx, dtype=torch.float32)[idx_tensor]
cy = torch.tensor(cy, dtype=torch.float32)[idx_tensor]
height = torch.tensor(height, dtype=torch.int32)[idx_tensor]
width = torch.tensor(width, dtype=torch.int32)[idx_tensor]
distortion_params = torch.stack(distort, dim=0)[idx_tensor]
fx = torch.tensor(fx, dtype=torch.float32)[idx_tensor].repeat_interleave(num_tiles)
fy = torch.tensor(fy, dtype=torch.float32)[idx_tensor].repeat_interleave(num_tiles)
cx = torch.tensor(cx, dtype=torch.float32)[idx_tensor].repeat_interleave(num_tiles)
cy = torch.tensor(cy, dtype=torch.float32)[idx_tensor].repeat_interleave(num_tiles)
height = torch.tensor(height, dtype=torch.int32)[idx_tensor].repeat_interleave(num_tiles)
width = torch.tensor(width, dtype=torch.int32)[idx_tensor].repeat_interleave(num_tiles)
distortion_params = torch.stack(distort, dim=0)[idx_tensor].repeat_interleave(num_tiles, dim=0)

cameras = Cameras(
fx=fx,
@@ -364,6 +380,7 @@ def _generate_dataparser_outputs(self, split: str = "train", **kwargs):
camera_type=camera_type,
)

cameras.update_tiling_intrinsics(tiling_factor=self.config.tiling_factor)
cameras.rescale_output_resolution(
scaling_factor=1.0 / downscale_factor, scale_rounding_mode=self.config.downscale_rounding_mode
)
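
A small sketch (assumed numbers: three selected images, tiling_factor=2, hypothetical filenames) of why the list comprehensions and repeat_interleave calls above stay aligned: entry k in every per-camera tensor and in image_filenames refers to tile k % num_tiles of selected image k // num_tiles.

import torch

indices = [0, 1, 2]            # images kept for this split
num_tiles = 2**2
image_filenames = [f"frame_{i}_{j}.png" for i in indices for j in range(num_tiles)]
fx = torch.tensor([100.0, 110.0, 120.0])[torch.tensor(indices)].repeat_interleave(num_tiles)
# image_filenames -> frame_0_0 ... frame_0_3, frame_1_0 ... frame_2_3 (12 entries)
# fx              -> [100, 100, 100, 100, 110, 110, 110, 110, 120, 120, 120, 120]
# so filename k and intrinsics k always describe the same tile of the same image.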
@@ -464,6 +481,109 @@ def _load_3D_points(self, colmap_path: Path, transform_matrix: torch.Tensor, sca
out["points3D_points2D_xy"] = torch.stack(points3D_image_xy, dim=0)
return out

def _tile_images(self, paths, get_fname, tiling_factor):
"""
Tile images into tiling_factor**2 tiles.
Logic must match the intrinsics update in the Cameras object.
"""
with status(msg="[bold yellow]Tiling images...", spinner="growVertical"):
assert isinstance(tiling_factor, int)
assert tiling_factor > 1

for path in paths:
img = Image.open(path)
w, h = img.size

base_tile_w, remainder_w = divmod(w, tiling_factor)
base_tile_h, remainder_h = divmod(h, tiling_factor)

path_out_base = get_fname(path)
path_out_base.parent.mkdir(parents=True, exist_ok=True)

for row in range(tiling_factor):
for col in range(tiling_factor):
idx = row * tiling_factor + col

# Distribute the remainder among the first remainder_w columns and remainder_h rows
tile_w = base_tile_w + int(col < remainder_w)
tile_h = base_tile_h + int(row < remainder_h)

x_offset = col * base_tile_w + min(col, remainder_w)
y_offset = row * base_tile_h + min(row, remainder_h)

tile = img.crop(
(
x_offset,
y_offset,
x_offset + tile_w,
y_offset + tile_h,
)
)

output_path = path_out_base.with_stem(path_out_base.stem + f"_{idx}")
tile.save(output_path)

CONSOLE.log("[bold green]:tada: Done tiling images.")
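
A short sketch (Pillow only; the image size is chosen so both remainders are non-zero) showing that the crop boxes produced by the loop above cover a 1001x751 image exactly once:

from PIL import Image

img = Image.new("RGB", (1001, 751))
tiling_factor = 2
base_tile_w, remainder_w = divmod(img.size[0], tiling_factor)   # 500, 1
base_tile_h, remainder_h = divmod(img.size[1], tiling_factor)   # 375, 1
boxes = []
for row in range(tiling_factor):
    for col in range(tiling_factor):
        tile_w = base_tile_w + int(col < remainder_w)
        tile_h = base_tile_h + int(row < remainder_h)
        x0 = col * base_tile_w + min(col, remainder_w)
        y0 = row * base_tile_h + min(row, remainder_h)
        boxes.append((x0, y0, x0 + tile_w, y0 + tile_h))
# boxes == [(0, 0, 501, 376), (501, 0, 1001, 376), (0, 376, 501, 751), (501, 376, 1001, 751)]
tiles = [img.crop(box) for box in boxes]
assert sum(t.size[0] * t.size[1] for t in tiles) == img.size[0] * img.size[1]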

def _setup_tiling(self, image_filenames: List[Path], mask_filenames: List[Path], depth_filenames: List[Path]):
"""
Wrapper around self._tile_images() to handle tiling of image, mask, and depth files.
"""
if self.config.tiling_factor == 1:
return image_filenames, mask_filenames, depth_filenames

assert self._downscale_factor == 1, "Tiling not supported with downscaling, please set --downscale_factor=1"

def get_fname(parent: Path, filepath: Path) -> Path:
"""Returns transformed file name when tiling factor is applied"""
rel_part = filepath.relative_to(parent)
base_part = parent.parent / (str(parent.name) + f"_tiled_{self.config.tiling_factor}")
return base_part / rel_part

if not all(get_fname(self.config.data / self.config.images_path, fp).parent.exists() for fp in image_filenames):
self._tile_images(
image_filenames,
partial(get_fname, self.config.data / self.config.images_path),
self.config.tiling_factor,
)
if len(mask_filenames) > 0:
assert self.config.masks_path is not None
self._tile_images(
mask_filenames,
partial(get_fname, self.config.data / self.config.masks_path),
self.config.tiling_factor,
)
if len(depth_filenames) > 0:
assert self.config.depths_path is not None
self._tile_images(
depth_filenames,
partial(get_fname, self.config.data / self.config.depths_path),
self.config.tiling_factor,
)

num_tiles = self.config.tiling_factor**2
image_filenames = [
get_fname(self.config.data / self.config.images_path, fp.with_stem(fp.stem + f"_{i}"))
for fp in image_filenames
for i in range(num_tiles)
]
if len(mask_filenames) > 0:
assert self.config.masks_path is not None
mask_filenames = [
get_fname(self.config.data / self.config.masks_path, fp.with_stem(fp.stem + f"_{i}"))
for fp in mask_filenames
for i in range(num_tiles)
]
if len(depth_filenames) > 0:
assert self.config.depths_path is not None
depth_filenames = [
get_fname(self.config.data / self.config.depths_path, fp.with_stem(fp.stem + f"_{i}"))
for fp in depth_filenames
for i in range(num_tiles)
]

return image_filenames, mask_filenames, depth_filenames
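
A sketch with hypothetical paths (data/scene and frame_0001.png are made up) of the filename rewrite performed above for tiling_factor=2:

from pathlib import Path

data, images_path, tiling_factor = Path("data/scene"), Path("images"), 2

def get_fname(parent: Path, filepath: Path) -> Path:
    rel_part = filepath.relative_to(parent)
    base_part = parent.parent / (str(parent.name) + f"_tiled_{tiling_factor}")
    return base_part / rel_part

src = data / images_path / "frame_0001.png"
tiled = [
    get_fname(data / images_path, src.with_stem(src.stem + f"_{i}"))
    for i in range(tiling_factor**2)
]
# tiled == [Path("data/scene/images_tiled_2/frame_0001_0.png"), ..., Path(".../frame_0001_3.png")]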

def _downscale_images(
self,
paths,