Skip to content

Commit 212d9b9

Browse files
authored
IDTReeS - Clip boxes outside of image bounds (#760)
* reverse idtrees coords
* clip boxes to image bounds
* undo coordinate reversal
* add filter_boxes function
* filter boxes outside of bounds in idtrees
* format
* Revert utils.py
* flake8 fixes
* fix mypy errors
* fix bug overriding some labels
* fix image size
* Remove version added line
* add function overloads
* add comments for clarity
* use id counter for test set
1 parent f6a9f75 commit 212d9b9

File tree

1 file changed

+63
-4
lines changed

1 file changed

+63
-4
lines changed

torchgeo/datasets/idtrees.py

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import glob
77
import os
8-
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
8+
from typing import Any, Callable, Dict, List, Optional, Tuple, cast, overload
99

1010
import fiona
1111
import matplotlib.pyplot as plt
@@ -14,6 +14,7 @@
1414
import torch
1515
from rasterio.enums import Resampling
1616
from torch import Tensor
17+
from torchvision.ops import clip_boxes_to_image, remove_small_boxes
1718
from torchvision.utils import draw_bounding_boxes
1819

1920
from .geo import NonGeoDataset
@@ -211,10 +212,22 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:
211212
if self.split == "test":
212213
if self.task == "task2":
213214
sample["boxes"] = self._load_boxes(path)
215+
h, w = sample["image"].shape[1:]
216+
sample["boxes"], _ = self._filter_boxes(
217+
image_size=(h, w), min_size=1, boxes=sample["boxes"], labels=None
218+
)
214219
else:
215220
sample["boxes"] = self._load_boxes(path)
216221
sample["label"] = self._load_target(path)
217222

223+
h, w = sample["image"].shape[1:]
224+
sample["boxes"], sample["label"] = self._filter_boxes(
225+
image_size=(h, w),
226+
min_size=1,
227+
boxes=sample["boxes"],
228+
labels=sample["label"],
229+
)
230+
218231
if self.transforms is not None:
219232
sample = self.transforms(sample)
220233

@@ -271,11 +284,15 @@ def _load_boxes(self, path: str) -> Tensor:
271284
geometries = cast(Dict[int, Dict[str, Any]], self.geometries)
272285

273286
# Find object ids and geometries
287+
# The train set geometry->image mapping is contained
288+
# in the train/Field/itc_rsFile.csv file
274289
if self.split == "train":
275290
indices = self.labels["rsFile"] == base_path
276291
ids = self.labels[indices]["id"].tolist()
277292
geoms = [geometries[i]["geometry"]["coordinates"][0][:4] for i in ids]
278-
# Test set - Task 2 has no mapping csv. Mapping is inside of geometry
293+
# The test set has no mapping csv. The mapping is inside of the geometry
294+
# properties i.e. geom["properties"]["plotID"] contains the RGB image filename
295+
# Return all geometries with the matching RGB image filename of the sample
279296
else:
280297
ids = [
281298
k
@@ -380,17 +397,59 @@ def _load_geometries(self, directory: str) -> Dict[int, Dict[str, Any]]:
380397
"""
381398
filepaths = glob.glob(os.path.join(directory, "ITC", "*.shp"))
382399

400+
i = 0
383401
features: Dict[int, Dict[str, Any]] = {}
384402
for path in filepaths:
385403
with fiona.open(path) as src:
386-
for i, feature in enumerate(src):
404+
for feature in src:
405+
# The train set has a unique id for each geometry in the properties
387406
if self.split == "train":
388407
features[feature["properties"]["id"]] = feature
389-
# Test set task 2 has no id
408+
# The test set has no unique id so create a dummy id
390409
else:
391410
features[i] = feature
411+
i += 1
392412
return features
393413

414+
@overload
415+
def _filter_boxes(
416+
self, image_size: Tuple[int, int], min_size: int, boxes: Tensor, labels: Tensor
417+
) -> Tuple[Tensor, Tensor]:
418+
...
419+
420+
@overload
421+
def _filter_boxes(
422+
self, image_size: Tuple[int, int], min_size: int, boxes: Tensor, labels: None
423+
) -> Tuple[Tensor, None]:
424+
...
425+
426+
def _filter_boxes(
427+
self,
428+
image_size: Tuple[int, int],
429+
min_size: int,
430+
boxes: Tensor,
431+
labels: Optional[Tensor],
432+
) -> Tuple[Tensor, Optional[Tensor]]:
433+
"""Clip boxes to image size and filter boxes with sides less than ``min_size``.
434+
435+
Args:
436+
image_size: tuple of (height, width) of image
437+
min_size: filter boxes that have any side less than min_size
438+
boxes: [N, 4] shape tensor of xyxy bounding box coordinates
439+
labels: (Optional) [N,] shape tensor of bounding box labels
440+
441+
Returns:
442+
a tuple of clipped, filtered boxes and matching labels (or None)
443+
"""
444+
boxes = clip_boxes_to_image(boxes=boxes, size=image_size)
445+
indices = remove_small_boxes(boxes=boxes, min_size=min_size)
446+
447+
boxes = boxes[indices]
448+
if labels is not None:
449+
labels = labels[indices]
450+
451+
return boxes, labels
452+
394453
def _verify(self) -> None:
395454
"""Verify the integrity of the dataset.
396455

0 commit comments

Comments
 (0)