1
- import copy
2
1
import os
3
2
4
3
import torch
10
9
from torchvision .datasets import wrap_dataset_for_transforms_v2
11
10
12
11
13
class FilterAndRemapCocoCategories:
    """Filter COCO annotations down to a whitelist of category ids.

    When ``remap`` is true, each surviving annotation's ``category_id`` is
    replaced by its position within ``categories``, yielding a contiguous
    zero-based id space.
    """

    def __init__(self, categories, remap=True):
        self.categories = categories
        self.remap = remap

    def __call__(self, image, target):
        # Keep only annotations whose category is in the whitelist.
        kept = [ann for ann in target["annotations"] if ann["category_id"] in self.categories]
        if not self.remap:
            target["annotations"] = kept
            return image, target
        # Deep-copy before mutating so the caller's annotation dicts are untouched.
        kept = copy.deepcopy(kept)
        for ann in kept:
            ann["category_id"] = self.categories.index(ann["category_id"])
        target["annotations"] = kept
        return image, target
29
-
30
-
31
12
def convert_coco_poly_to_mask (segmentations , height , width ):
32
13
masks = []
33
14
for polygons in segmentations :
@@ -219,7 +200,7 @@ def __getitem__(self, idx):
219
200
return img , target
220
201
221
202
222
- def get_coco (root , image_set , transforms , mode = "instances" , use_v2 = False ):
203
+ def get_coco (root , image_set , transforms , mode = "instances" , use_v2 = False , with_masks = False ):
223
204
anno_file_template = "{}_{}2017.json"
224
205
PATHS = {
225
206
"train" : ("train2017" , os .path .join ("annotations" , anno_file_template .format (mode , "train" ))),
@@ -233,9 +214,12 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
233
214
234
215
if use_v2 :
235
216
dataset = torchvision .datasets .CocoDetection (img_folder , ann_file , transforms = transforms )
236
- # TODO: need to update target_keys to handle masks for segmentation!
237
- dataset = wrap_dataset_for_transforms_v2 (dataset , target_keys = {"boxes" , "labels" , "image_id" })
217
+ target_keys = ["boxes" , "labels" , "image_id" ]
218
+ if with_masks :
219
+ target_keys += ["masks" ]
220
+ dataset = wrap_dataset_for_transforms_v2 (dataset , target_keys = target_keys )
238
221
else :
222
+ # TODO: handle with_masks for V1?
239
223
t = [ConvertCocoPolysToMask ()]
240
224
if transforms is not None :
241
225
t .append (transforms )
@@ -249,9 +233,3 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
249
233
# dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])
250
234
251
235
return dataset
252
-
253
-
254
def get_coco_kp(root, image_set, transforms, use_v2=False):
    """Build the COCO person-keypoints dataset (transforms V1 only).

    Args:
        root: dataset root directory.
        image_set: split name, e.g. "train" or "val".
        transforms: transform pipeline applied to each sample.
        use_v2: must be ``False``; keypoints are not supported by transforms V2.

    Returns:
        The dataset produced by ``get_coco`` with ``mode="person_keypoints"``.

    Raises:
        ValueError: if ``use_v2`` is true.
    """
    if use_v2:
        raise ValueError("KeyPoints aren't supported by transforms V2 yet.")
    # Forward use_v2 explicitly: the original dropped it, so the flag would be
    # silently lost if the guard above were ever relaxed.
    return get_coco(root, image_set, transforms, mode="person_keypoints", use_v2=use_v2)
0 commit comments