Skip to content

Commit decfc97

Browse files
committed
filter boxes outside of bounds in idtrees
1 parent bf443c0 commit decfc97

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

torchgeo/datasets/idtrees.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from torchvision.utils import draw_bounding_boxes
1818

1919
from .geo import NonGeoDataset
20-
from .utils import download_url, extract_archive
20+
from .utils import download_url, extract_archive, filter_boxes
2121

2222

2323
class IDTReeS(NonGeoDataset):
@@ -211,10 +211,20 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:
211211
if self.split == "test":
212212
if self.task == "task2":
213213
sample["boxes"] = self._load_boxes(path)
214+
w, h = sample["image"].shape[1:]
215+
sample["boxes"], _ = filter_boxes(
216+
image_size=(h, w), boxes=sample["boxes"]
217+
)
214218
else:
215219
sample["boxes"] = self._load_boxes(path)
216220
sample["label"] = self._load_target(path)
217221

222+
w, h = sample["image"].shape[1:]
223+
sample["boxes"], sample["label"] = filter_boxes( # type:ignore[assignment]
224+
image_size=(h, w), boxes=sample["boxes"], labels=sample["label"]
225+
)
226+
227+
# Filter boxes
218228
if self.transforms is not None:
219229
sample = self.transforms(sample)
220230

@@ -296,7 +306,6 @@ def _load_boxes(self, path: str) -> Tensor:
296306
boxes.append([xmin, ymin, xmax, ymax])
297307

298308
tensor = torch.tensor(boxes)
299-
tensor = torch.clamp(tensor, min=0, max=self.image_size[0])
300309
return tensor
301310

302311
def _load_target(self, path: str) -> Tensor:

0 commit comments

Comments
 (0)