Skip to content

IDTReeS - Clip boxes outside of image bounds #760

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Sep 7, 2022
Merged
67 changes: 63 additions & 4 deletions torchgeo/datasets/idtrees.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import glob
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
from typing import Any, Callable, Dict, List, Optional, Tuple, cast, overload

import fiona
import matplotlib.pyplot as plt
Expand All @@ -14,6 +14,7 @@
import torch
from rasterio.enums import Resampling
from torch import Tensor
from torchvision.ops import clip_boxes_to_image, remove_small_boxes
from torchvision.utils import draw_bounding_boxes

from .geo import NonGeoDataset
Expand Down Expand Up @@ -211,10 +212,22 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:
if self.split == "test":
if self.task == "task2":
sample["boxes"] = self._load_boxes(path)
h, w = sample["image"].shape[1:]
sample["boxes"], _ = self._filter_boxes(
image_size=(h, w), min_size=1, boxes=sample["boxes"], labels=None
)
else:
sample["boxes"] = self._load_boxes(path)
sample["label"] = self._load_target(path)

h, w = sample["image"].shape[1:]
sample["boxes"], sample["label"] = self._filter_boxes(
image_size=(h, w),
min_size=1,
boxes=sample["boxes"],
labels=sample["label"],
)

if self.transforms is not None:
sample = self.transforms(sample)

Expand Down Expand Up @@ -271,11 +284,15 @@ def _load_boxes(self, path: str) -> Tensor:
geometries = cast(Dict[int, Dict[str, Any]], self.geometries)

# Find object ids and geometries
# The train set geometry->image mapping is contained
# in the train/Field/itc_rsFile.csv file
if self.split == "train":
indices = self.labels["rsFile"] == base_path
ids = self.labels[indices]["id"].tolist()
geoms = [geometries[i]["geometry"]["coordinates"][0][:4] for i in ids]
# Test set - Task 2 has no mapping csv. Mapping is inside of geometry
# The test set has no mapping csv. The mapping is inside of the geometry
# properties i.e. geom["properties"]["plotID"] contains the RGB image filename
# Return all geometries with the matching RGB image filename of the sample
else:
ids = [
k
Expand Down Expand Up @@ -380,17 +397,59 @@ def _load_geometries(self, directory: str) -> Dict[int, Dict[str, Any]]:
"""
filepaths = glob.glob(os.path.join(directory, "ITC", "*.shp"))

i = 0
features: Dict[int, Dict[str, Any]] = {}
for path in filepaths:
with fiona.open(path) as src:
for i, feature in enumerate(src):
for feature in src:
# The train set has a unique id for each geometry in the properties
if self.split == "train":
features[feature["properties"]["id"]] = feature
# Test set task 2 has no id
# The test set has no unique id so create a dummy id
else:
features[i] = feature
i += 1
return features

@overload
def _filter_boxes(
    self, image_size: Tuple[int, int], min_size: int, boxes: Tensor, labels: Tensor
) -> Tuple[Tensor, Tensor]:
    ...

@overload
def _filter_boxes(
    self, image_size: Tuple[int, int], min_size: int, boxes: Tensor, labels: None
) -> Tuple[Tensor, None]:
    ...

def _filter_boxes(
    self,
    image_size: Tuple[int, int],
    min_size: int,
    boxes: Tensor,
    labels: Optional[Tensor],
) -> Tuple[Tensor, Optional[Tensor]]:
    """Clamp boxes to the image extent, then drop boxes that became too small.

    Boxes are first clipped with ``clip_boxes_to_image`` and the clipped result
    is passed to ``remove_small_boxes``; the indices it returns select the
    surviving boxes and, when labels are given, the matching labels.

    Args:
        image_size: tuple of (height, width) of image
        min_size: filter boxes that have any side less than min_size
        boxes: [N, 4] shape tensor of xyxy bounding box coordinates
        labels: (Optional) [N,] shape tensor of bounding box labels

    Returns:
        a tuple of filtered boxes and labels; the labels entry is ``None``
        when no labels were passed in
    """
    clipped = clip_boxes_to_image(boxes=boxes, size=image_size)
    # Size filtering runs on the clipped boxes, so a box that is reduced to
    # less than ``min_size`` by clipping is removed as well.
    keep = remove_small_boxes(boxes=clipped, min_size=min_size)
    kept_boxes = clipped[keep]

    # Without labels there is nothing to keep in sync with the boxes.
    if labels is None:
        return kept_boxes, None

    return kept_boxes, labels[keep]

def _verify(self) -> None:
"""Verify the integrity of the dataset.

Expand Down