diff --git a/nerfstudio/cameras/cameras.py b/nerfstudio/cameras/cameras.py index b1561e65b2..e971b41f61 100644 --- a/nerfstudio/cameras/cameras.py +++ b/nerfstudio/cameras/cameras.py @@ -1021,3 +1021,34 @@ def rescale_output_resolution( self.width = torch.ceil(self.width * scaling_factor).to(torch.int64) else: raise ValueError("Scale rounding mode must be 'floor', 'round' or 'ceil'.") + + def update_tiling_intrinsics(self, tiling_factor: int) -> None: + """ + Update camera intrinsics based on tiling_factor. + Must match tiling logic as defined in dataparser. + + Args: + tiling_factor: Tiling factor to apply to the camera intrinsics. + """ + if tiling_factor == 1: + return + + num_tiles = tiling_factor**2 + + # Compute tile sizes + base_tile_w, remainder_w = self.width // tiling_factor, self.width % tiling_factor + base_tile_h, remainder_h = self.height // tiling_factor, self.height % tiling_factor + + tile_indices = torch.arange(len(self.cx), device=self.cx.device).unsqueeze(1) % num_tiles + row_indices, col_indices = tile_indices // tiling_factor, tile_indices % tiling_factor + + x_offsets = col_indices * base_tile_w + torch.minimum(col_indices, remainder_w) + y_offsets = row_indices * base_tile_h + torch.minimum(row_indices, remainder_h) + + # Adjust principal points + self.cx = self.cx - x_offsets + self.cy = self.cy - y_offsets + + # Adjust height/width + self.width = base_tile_w + (col_indices < remainder_w).to(torch.int) + self.height = base_tile_h + (row_indices < remainder_h).to(torch.int) diff --git a/nerfstudio/data/dataparsers/colmap_dataparser.py b/nerfstudio/data/dataparsers/colmap_dataparser.py index e9bfd4bb4a..dede3a3be5 100644 --- a/nerfstudio/data/dataparsers/colmap_dataparser.py +++ b/nerfstudio/data/dataparsers/colmap_dataparser.py @@ -59,6 +59,8 @@ class ColmapDataParserConfig(DataParserConfig): """How much to downscale images. If not set, images are chosen such that the max dimension is <1600px.""" downscale_rounding_mode: Literal["floor", "round", "ceil"] = "floor" """How to round downscale image height and Image width.""" + tiling_factor: int = 1 + """Tile images into n^2 equal-resolution images, where n is this number. n | H, n | W for image with resolution WxH""" scene_scale: float = 1.0 """How much to scale the region of interest by.""" orientation_method: Literal["pca", "up", "vertical", "none"] = "up" @@ -115,7 +117,8 @@ class ColmapDataParser(DataParser): The dataparser loads the downscaled images from folders with `_{downscale_factor}` suffix. If these folders do not exist, the user can choose to automatically downscale the images and - create these folders. + create these folders. If tiling_factor > 1, the images are instead loaded from folders with + `_tiled_{tiling_factor}` suffix. The loader is compatible with the datasets processed using the ns-process-data script and can be used as a drop-in replacement. It further supports datasets like Mip-NeRF 360 (although @@ -327,13 +330,26 @@ def _generate_dataparser_outputs(self, split: str = "train", **kwargs): image_filenames, mask_filenames, depth_filenames, downscale_factor = self._setup_downscale_factor( image_filenames, mask_filenames, depth_filenames ) + image_filenames, mask_filenames, depth_filenames = self._setup_tiling( + image_filenames, mask_filenames, depth_filenames + ) + + num_tiles = self.config.tiling_factor**2 - image_filenames = [image_filenames[i] for i in indices] - mask_filenames = [mask_filenames[i] for i in indices] if len(mask_filenames) > 0 else [] - depth_filenames = [depth_filenames[i] for i in indices] if len(depth_filenames) > 0 else [] + image_filenames = [image_filenames[i * num_tiles + j] for i in indices for j in range(num_tiles)] + mask_filenames = ( + [mask_filenames[i * num_tiles + j] for i in indices for j in range(num_tiles)] + if len(mask_filenames) > 0 + else [] + ) + depth_filenames = ( + [depth_filenames[i * num_tiles + j] for i in indices for j in range(num_tiles)] + if len(depth_filenames) > 0 + else [] + ) idx_tensor = torch.tensor(indices, dtype=torch.long) - poses = poses[idx_tensor] + poses = poses[idx_tensor].repeat_interleave(num_tiles, dim=0) # in x,y,z order # assumes that the scene is centered at the origin @@ -344,13 +360,13 @@ def _generate_dataparser_outputs(self, split: str = "train", **kwargs): ) ) - fx = torch.tensor(fx, dtype=torch.float32)[idx_tensor] - fy = torch.tensor(fy, dtype=torch.float32)[idx_tensor] - cx = torch.tensor(cx, dtype=torch.float32)[idx_tensor] - cy = torch.tensor(cy, dtype=torch.float32)[idx_tensor] - height = torch.tensor(height, dtype=torch.int32)[idx_tensor] - width = torch.tensor(width, dtype=torch.int32)[idx_tensor] - distortion_params = torch.stack(distort, dim=0)[idx_tensor] + fx = torch.tensor(fx, dtype=torch.float32)[idx_tensor].repeat_interleave(num_tiles) + fy = torch.tensor(fy, dtype=torch.float32)[idx_tensor].repeat_interleave(num_tiles) + cx = torch.tensor(cx, dtype=torch.float32)[idx_tensor].repeat_interleave(num_tiles) + cy = torch.tensor(cy, dtype=torch.float32)[idx_tensor].repeat_interleave(num_tiles) + height = torch.tensor(height, dtype=torch.int32)[idx_tensor].repeat_interleave(num_tiles) + width = torch.tensor(width, dtype=torch.int32)[idx_tensor].repeat_interleave(num_tiles) + distortion_params = torch.stack(distort, dim=0)[idx_tensor].repeat_interleave(num_tiles, dim=0) cameras = Cameras( fx=fx, @@ -364,6 +380,7 @@ def _generate_dataparser_outputs(self, split: str = "train", **kwargs): camera_type=camera_type, ) + cameras.update_tiling_intrinsics(tiling_factor=self.config.tiling_factor) cameras.rescale_output_resolution( scaling_factor=1.0 / downscale_factor, scale_rounding_mode=self.config.downscale_rounding_mode ) @@ -464,6 +481,109 @@ def _load_3D_points(self, colmap_path: Path, transform_matrix: torch.Tensor, sca out["points3D_points2D_xy"] = torch.stack(points3D_image_xy, dim=0) return out + def _tile_images(self, paths, get_fname, tiling_factor): + """ + Tile images into self.tiling_factor^2 tiles. + Logic must match intrinsics update in Cameras object. + """ + with status(msg="[bold yellow]Tiling images...", spinner="growVertical"): + assert isinstance(tiling_factor, int) + assert tiling_factor > 1 + + for path in paths: + img = Image.open(path) + w, h = img.size + + base_tile_w, remainder_w = divmod(w, tiling_factor) + base_tile_h, remainder_h = divmod(h, tiling_factor) + + path_out_base = get_fname(path) + path_out_base.parent.mkdir(parents=True, exist_ok=True) + + for row in range(tiling_factor): + for col in range(tiling_factor): + idx = row * tiling_factor + col + + # Distribute the remainder among the first remainder_w columns and remainder_h rows + tile_w = base_tile_w + int(col < remainder_w) + tile_h = base_tile_h + int(row < remainder_h) + + x_offset = col * base_tile_w + min(col, remainder_w) + y_offset = row * base_tile_h + min(row, remainder_h) + + tile = img.crop( + ( + x_offset, + y_offset, + x_offset + tile_w, + y_offset + tile_h, + ) + ) + + output_path = path_out_base.with_stem(path_out_base.stem + f"_{idx}") + tile.save(output_path) + + CONSOLE.log("[bold green]:tada: Done tiling images.") + + def _setup_tiling(self, image_filenames: List[Path], mask_filenames: List[Path], depth_filenames: List[Path]): + """ + Wrapper around self._tile_images() to handle tiling of image, mask, and depth files. + """ + if self.config.tiling_factor == 1: + return image_filenames, mask_filenames, depth_filenames + + assert self._downscale_factor == 1, "Tiling not supported with downscaling, please set --downscale_factor=1" + + def get_fname(parent: Path, filepath: Path) -> Path: + """Returns transformed file name when tiling factor is applied""" + rel_part = filepath.relative_to(parent) + base_part = parent.parent / (str(parent.name) + f"_tiled_{self.config.tiling_factor}") + return base_part / rel_part + + if not all(get_fname(self.config.data / self.config.images_path, fp).parent.exists() for fp in image_filenames): + self._tile_images( + image_filenames, + partial(get_fname, self.config.data / self.config.images_path), + self.config.tiling_factor, + ) + if len(mask_filenames) > 0: + assert self.config.masks_path is not None + self._tile_images( + mask_filenames, + partial(get_fname, self.config.data / self.config.masks_path), + self.config.tiling_factor, + ) + if len(depth_filenames) > 0: + assert self.config.depths_path is not None + self._tile_images( + depth_filenames, + partial(get_fname, self.config.data / self.config.depths_path), + self.config.tiling_factor, + ) + + num_tiles = self.config.tiling_factor**2 + image_filenames = [ + get_fname(self.config.data / self.config.images_path, fp.with_stem(fp.stem + f"_{i}")) + for fp in image_filenames + for i in range(num_tiles) + ] + if len(mask_filenames) > 0: + assert self.config.masks_path is not None + mask_filenames = [ + get_fname(self.config.data / self.config.masks_path, fp.with_stem(fp.stem + f"_{i}")) + for fp in mask_filenames + for i in range(num_tiles) + ] + if len(depth_filenames) > 0: + assert self.config.depths_path is not None + depth_filenames = [ + get_fname(self.config.data / self.config.depths_path, fp.with_stem(fp.stem + f"_{i}")) + for fp in depth_filenames + for i in range(num_tiles) + ] + + return image_filenames, mask_filenames, depth_filenames + def _downscale_images( self, paths,