|
17 | 17 | from torchvision.utils import draw_bounding_boxes
|
18 | 18 |
|
19 | 19 | from .geo import NonGeoDataset
|
20 |
| -from .utils import download_url, extract_archive |
| 20 | +from .utils import download_url, extract_archive, filter_boxes |
21 | 21 |
|
22 | 22 |
|
23 | 23 | class IDTReeS(NonGeoDataset):
|
@@ -211,10 +211,20 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
211 | 211 | if self.split == "test":
|
212 | 212 | if self.task == "task2":
|
213 | 213 | sample["boxes"] = self._load_boxes(path)
|
| 214 | + w, h = sample["image"].shape[1:] |
| 215 | + sample["boxes"], _ = filter_boxes( |
| 216 | + image_size=(h, w), boxes=sample["boxes"] |
| 217 | + ) |
214 | 218 | else:
|
215 | 219 | sample["boxes"] = self._load_boxes(path)
|
216 | 220 | sample["label"] = self._load_target(path)
|
217 | 221 |
|
| 222 | + w, h = sample["image"].shape[1:] |
| 223 | + sample["boxes"], sample["label"] = filter_boxes( # type:ignore[assignment] |
| 224 | + image_size=(h, w), boxes=sample["boxes"], labels=sample["label"] |
| 225 | + ) |
| 226 | + |
| 227 | + # Filter boxes |
218 | 228 | if self.transforms is not None:
|
219 | 229 | sample = self.transforms(sample)
|
220 | 230 |
|
@@ -296,7 +306,6 @@ def _load_boxes(self, path: str) -> Tensor:
|
296 | 306 | boxes.append([xmin, ymin, xmax, ymax])
|
297 | 307 |
|
298 | 308 | tensor = torch.tensor(boxes)
|
299 |
| - tensor = torch.clamp(tensor, min=0, max=self.image_size[0]) |
300 | 309 | return tensor
|
301 | 310 |
|
302 | 311 | def _load_target(self, path: str) -> Tensor:
|
|
0 commit comments