  5 |   5 | 
  6 |   6 | import glob
  7 |   7 | import os
  8 |     | -from typing import Any, Callable, Dict, List, Optional, Tuple, cast
    |   8 | +from typing import Any, Callable, Dict, List, Optional, Tuple, cast, overload
  9 |   9 | 
 10 |  10 | import fiona
 11 |  11 | import matplotlib.pyplot as plt
 14 |  14 | import torch
 15 |  15 | from rasterio.enums import Resampling
 16 |  16 | from torch import Tensor
    |  17 | +from torchvision.ops import clip_boxes_to_image, remove_small_boxes
 17 |  18 | from torchvision.utils import draw_bounding_boxes
 18 |  19 | 
 19 |  20 | from .geo import NonGeoDataset
@@ -211,10 +212,22 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:
211 | 212 |         if self.split == "test":
212 | 213 |             if self.task == "task2":
213 | 214 |                 sample["boxes"] = self._load_boxes(path)
    | 215 | +                h, w = sample["image"].shape[1:]
    | 216 | +                sample["boxes"], _ = self._filter_boxes(
    | 217 | +                    image_size=(h, w), min_size=1, boxes=sample["boxes"], labels=None
    | 218 | +                )
214 | 219 |         else:
215 | 220 |             sample["boxes"] = self._load_boxes(path)
216 | 221 |             sample["label"] = self._load_target(path)
217 | 222 | 
    | 223 | +            h, w = sample["image"].shape[1:]
    | 224 | +            sample["boxes"], sample["label"] = self._filter_boxes(
    | 225 | +                image_size=(h, w),
    | 226 | +                min_size=1,
    | 227 | +                boxes=sample["boxes"],
    | 228 | +                labels=sample["label"],
    | 229 | +            )
    | 230 | +
218 | 231 |         if self.transforms is not None:
219 | 232 |             sample = self.transforms(sample)
220 | 233 | 
@@ -271,11 +284,15 @@ def _load_boxes(self, path: str) -> Tensor:
271 | 284 |         geometries = cast(Dict[int, Dict[str, Any]], self.geometries)
272 | 285 | 
273 | 286 |         # Find object ids and geometries
    | 287 | +        # The train set geometry->image mapping is contained
    | 288 | +        # in the train/Field/itc_rsFile.csv file
274 | 289 |         if self.split == "train":
275 | 290 |             indices = self.labels["rsFile"] == base_path
276 | 291 |             ids = self.labels[indices]["id"].tolist()
277 | 292 |             geoms = [geometries[i]["geometry"]["coordinates"][0][:4] for i in ids]
278 |     | -        # Test set - Task 2 has no mapping csv. Mapping is inside of geometry
    | 293 | +        # The test set has no mapping csv. The mapping is inside of the geometry
    | 294 | +        # properties, i.e. geom["properties"]["plotID"] contains the RGB image filename.
    | 295 | +        # Return all geometries with the matching RGB image filename of the sample.
279 | 296 |         else:
280 | 297 |             ids = [
281 | 298 |                 k
@@ -380,17 +397,59 @@ def _load_geometries(self, directory: str) -> Dict[int, Dict[str, Any]]:
380 | 397 |         """
381 | 398 |         filepaths = glob.glob(os.path.join(directory, "ITC", "*.shp"))
382 | 399 | 
    | 400 | +        i = 0
383 | 401 |         features: Dict[int, Dict[str, Any]] = {}
384 | 402 |         for path in filepaths:
385 | 403 |             with fiona.open(path) as src:
386 |     | -                for i, feature in enumerate(src):
    | 404 | +                for feature in src:
    | 405 | +                    # The train set has a unique id for each geometry in the properties
387 | 406 |                     if self.split == "train":
388 | 407 |                         features[feature["properties"]["id"]] = feature
389 |     | -                    # Test set task 2 has no id
    | 408 | +                    # The test set has no unique id so create a dummy id
390 | 409 |                     else:
391 | 410 |                         features[i] = feature
    | 411 | +                        i += 1
392 | 412 |         return features
393 | 413 | 
    | 414 | +    @overload
    | 415 | +    def _filter_boxes(
    | 416 | +        self, image_size: Tuple[int, int], min_size: int, boxes: Tensor, labels: Tensor
    | 417 | +    ) -> Tuple[Tensor, Tensor]:
    | 418 | +        ...
    | 419 | +
    | 420 | +    @overload
    | 421 | +    def _filter_boxes(
    | 422 | +        self, image_size: Tuple[int, int], min_size: int, boxes: Tensor, labels: None
    | 423 | +    ) -> Tuple[Tensor, None]:
    | 424 | +        ...
    | 425 | +
    | 426 | +    def _filter_boxes(
    | 427 | +        self,
    | 428 | +        image_size: Tuple[int, int],
    | 429 | +        min_size: int,
    | 430 | +        boxes: Tensor,
    | 431 | +        labels: Optional[Tensor],
    | 432 | +    ) -> Tuple[Tensor, Optional[Tensor]]:
    | 433 | +        """Clip boxes to image size and filter boxes with sides less than ``min_size``.
    | 434 | +
    | 435 | +        Args:
    | 436 | +            image_size: tuple of (height, width) of image
    | 437 | +            min_size: filter boxes that have any side less than min_size
    | 438 | +            boxes: [N, 4] shape tensor of xyxy bounding box coordinates
    | 439 | +            labels: (Optional) [N,] shape tensor of bounding box labels
    | 440 | +
    | 441 | +        Returns:
    | 442 | +            a tuple of filtered boxes and labels
    | 443 | +        """
    | 444 | +        boxes = clip_boxes_to_image(boxes=boxes, size=image_size)
    | 445 | +        indices = remove_small_boxes(boxes=boxes, min_size=min_size)
    | 446 | +
    | 447 | +        boxes = boxes[indices]
    | 448 | +        if labels is not None:
    | 449 | +            labels = labels[indices]
    | 450 | +
    | 451 | +        return boxes, labels
    | 452 | +
394 | 453 |     def _verify(self) -> None:
395 | 454 |         """Verify the integrity of the dataset.
396 | 455 | 