Commit 030d27a

Merge branch 'main' into fix-make-image
2 parents de54057 + b9b7cfc commit 030d27a

File tree: 7 files changed (+220, -49 lines)

references/detection/coco_utils.py

Lines changed: 2 additions & 1 deletion

@@ -6,7 +6,6 @@
 import transforms as T
 from pycocotools import mask as coco_mask
 from pycocotools.coco import COCO
-from torchvision.datasets import wrap_dataset_for_transforms_v2


 def convert_coco_poly_to_mask(segmentations, height, width):
@@ -213,6 +212,8 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_m
     ann_file = os.path.join(root, ann_file)

     if use_v2:
+        from torchvision.datasets import wrap_dataset_for_transforms_v2
+
         dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
         target_keys = ["boxes", "labels", "image_id"]
         if with_masks:

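The only change here is that `wrap_dataset_for_transforms_v2` is now imported lazily, inside the `use_v2` branch, so importing this module in a v1-only run never touches the v2 API. A minimal sketch of the same pattern, using an illustrative helper name and placeholder arguments (not part of the diff):

    import torchvision

    def build_coco_detection(img_folder, ann_file, transforms=None, use_v2=False):
        # Hypothetical helper mirroring the hunk above: the v2 wrapper is only
        # imported when use_v2=True, so v1-only code paths avoid the v2 import.
        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
        if use_v2:
            from torchvision.datasets import wrap_dataset_for_transforms_v2

            dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=["boxes", "labels", "image_id"])
        return dataset
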
references/segmentation/coco_utils.py

Lines changed: 15 additions & 9 deletions

@@ -68,11 +68,6 @@ def _has_valid_annotation(anno):
         # if more than 1k pixels occupied in the image
         return sum(obj["area"] for obj in anno) > 1000

-    if not isinstance(dataset, torchvision.datasets.CocoDetection):
-        raise TypeError(
-            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
-        )
-
     ids = []
     for ds_idx, img_id in enumerate(dataset.ids):
         ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
@@ -86,21 +81,32 @@ def _has_valid_annotation(anno):
     return dataset


-def get_coco(root, image_set, transforms):
+def get_coco(root, image_set, transforms, use_v2=False):
     PATHS = {
         "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
         "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
         # "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
     }
     CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

-    transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])
-
     img_folder, ann_file = PATHS[image_set]
     img_folder = os.path.join(root, img_folder)
     ann_file = os.path.join(root, ann_file)

-    dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
+    # The two "Compose" calls below achieve the same thing: converting COCO detection
+    # samples into segmentation-compatible samples. They just do it with slightly
+    # different implementations. We could refactor and unify them, but keeping them
+    # separate helps keep the v2 version clean.
+    if use_v2:
+        import v2_extras
+        from torchvision.datasets import wrap_dataset_for_transforms_v2
+
+        transforms = Compose([v2_extras.CocoDetectionToVOCSegmentation(), transforms])
+        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
+        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"masks", "labels"})
+    else:
+        transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])
+        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)

     if image_set == "train":
         dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)

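For context, a hedged sketch of how the new v2 path would be exercised end to end; the dataset root is a placeholder and the preset arguments mirror the defaults used by train.py:

    import presets
    from coco_utils import get_coco

    # With use_v2=True, the wrapper exposes the raw COCO annotations as
    # {"masks", "labels"} tensors, and CocoDetectionToVOCSegmentation then
    # collapses them into a single VOC-style segmentation mask before the
    # remaining transforms run.
    transforms = presets.SegmentationPresetTrain(base_size=520, crop_size=480, use_v2=True)
    dataset = get_coco("/path/to/coco", image_set="train", transforms=transforms, use_v2=True)
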
references/segmentation/presets.py

Lines changed: 90 additions & 23 deletions

@@ -1,39 +1,106 @@
+from collections import defaultdict
+
 import torch
-import transforms as T
+
+
+def get_modules(use_v2):
+    # We need a protected import to avoid the V2 warning in case just V1 is used
+    if use_v2:
+        import torchvision.datapoints
+        import torchvision.transforms.v2
+        import v2_extras
+
+        return torchvision.transforms.v2, torchvision.datapoints, v2_extras
+    else:
+        import transforms
+
+        return transforms, None, None


 class SegmentationPresetTrain:
-    def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
-        min_size = int(0.5 * base_size)
-        max_size = int(2.0 * base_size)
+    def __init__(
+        self,
+        *,
+        base_size,
+        crop_size,
+        hflip_prob=0.5,
+        mean=(0.485, 0.456, 0.406),
+        std=(0.229, 0.224, 0.225),
+        backend="pil",
+        use_v2=False,
+    ):
+        T, datapoints, v2_extras = get_modules(use_v2)
+
+        transforms = []
+        backend = backend.lower()
+        if backend == "datapoint":
+            transforms.append(T.ToImageTensor())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
+
+        transforms += [T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size))]

-        trans = [T.RandomResize(min_size, max_size)]
         if hflip_prob > 0:
-            trans.append(T.RandomHorizontalFlip(hflip_prob))
-        trans.extend(
-            [
-                T.RandomCrop(crop_size),
-                T.PILToTensor(),
-                T.ConvertImageDtype(torch.float),
-                T.Normalize(mean=mean, std=std),
+            transforms += [T.RandomHorizontalFlip(hflip_prob)]
+
+        if use_v2:
+            # We need a custom pad transform here, since the padding we want to perform here is fundamentally
+            # different from the padding in `RandomCrop` if `pad_if_needed=True`.
+            transforms += [v2_extras.PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255}))]
+
+        transforms += [T.RandomCrop(crop_size)]
+
+        if backend == "pil":
+            transforms += [T.PILToTensor()]
+
+        if use_v2:
+            img_type = datapoints.Image if backend == "datapoint" else torch.Tensor
+            transforms += [
+                T.ToDtype(dtype={img_type: torch.float32, datapoints.Mask: torch.int64, "others": None}, scale=True)
             ]
-        )
-        self.transforms = T.Compose(trans)
+        else:
+            # No need to explicitly convert masks as they're magically int64 already
+            transforms += [T.ConvertImageDtype(torch.float)]
+
+        transforms += [T.Normalize(mean=mean, std=std)]
+
+        self.transforms = T.Compose(transforms)

     def __call__(self, img, target):
         return self.transforms(img, target)


 class SegmentationPresetEval:
-    def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
-        self.transforms = T.Compose(
-            [
-                T.RandomResize(base_size, base_size),
-                T.PILToTensor(),
-                T.ConvertImageDtype(torch.float),
-                T.Normalize(mean=mean, std=std),
-            ]
-        )
+    def __init__(
+        self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), backend="pil", use_v2=False
+    ):
+        T, _, _ = get_modules(use_v2)
+
+        transforms = []
+        backend = backend.lower()
+        if backend == "tensor":
+            transforms += [T.PILToTensor()]
+        elif backend == "datapoint":
+            transforms += [T.ToImageTensor()]
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
+
+        if use_v2:
+            transforms += [T.Resize(size=(base_size, base_size))]
+        else:
+            transforms += [T.RandomResize(min_size=base_size, max_size=base_size)]
+
+        if backend == "pil":
+            # Note: we could just convert to pure tensors even in v2?
+            transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+
+        transforms += [
+            T.ConvertImageDtype(torch.float),
+            T.Normalize(mean=mean, std=std),
+        ]
+        self.transforms = T.Compose(transforms)

     def __call__(self, img, target):
         return self.transforms(img, target)

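Illustrative instantiations of the reworked presets (a sketch only; the variable names are arbitrary and the size values are the ones train.py passes):

    # v1 behaviour is unchanged: PIL input, v1 transforms.
    v1_train = SegmentationPresetTrain(base_size=520, crop_size=480)

    # v2 with datapoint inputs: images are converted up front via ToImageTensor().
    v2_train = SegmentationPresetTrain(base_size=520, crop_size=480, backend="datapoint", use_v2=True)

    # v2 eval with plain tensors: PILToTensor() runs first, and Resize replaces RandomResize.
    v2_eval = SegmentationPresetEval(base_size=520, backend="tensor", use_v2=True)
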
references/segmentation/train.py

Lines changed: 26 additions & 12 deletions

@@ -14,24 +14,30 @@
 from torchvision.transforms import functional as F, InterpolationMode


-def get_dataset(dir_path, name, image_set, transform):
+def get_dataset(args, is_train):
     def sbd(*args, **kwargs):
+        kwargs.pop("use_v2")
         return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)

+    def voc(*args, **kwargs):
+        kwargs.pop("use_v2")
+        return torchvision.datasets.VOCSegmentation(*args, **kwargs)
+
     paths = {
-        "voc": (dir_path, torchvision.datasets.VOCSegmentation, 21),
-        "voc_aug": (dir_path, sbd, 21),
-        "coco": (dir_path, get_coco, 21),
+        "voc": (args.data_path, voc, 21),
+        "voc_aug": (args.data_path, sbd, 21),
+        "coco": (args.data_path, get_coco, 21),
     }
-    p, ds_fn, num_classes = paths[name]
+    p, ds_fn, num_classes = paths[args.dataset]

-    ds = ds_fn(p, image_set=image_set, transforms=transform)
+    image_set = "train" if is_train else "val"
+    ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2)
     return ds, num_classes


-def get_transform(train, args):
-    if train:
-        return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
+def get_transform(is_train, args):
+    if is_train:
+        return presets.SegmentationPresetTrain(base_size=520, crop_size=480, backend=args.backend, use_v2=args.use_v2)
     elif args.weights and args.test_only:
         weights = torchvision.models.get_weight(args.weights)
         trans = weights.transforms()
@@ -44,7 +50,7 @@ def preprocessing(img, target):

         return preprocessing
     else:
-        return presets.SegmentationPresetEval(base_size=520)
+        return presets.SegmentationPresetEval(base_size=520, backend=args.backend, use_v2=args.use_v2)


 def criterion(inputs, target):
@@ -120,6 +126,12 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi


 def main(args):
+    if args.backend.lower() != "pil" and not args.use_v2:
+        # TODO: Support tensor backend in V1?
+        raise ValueError("Use --use-v2 if you want to use the datapoint or tensor backend.")
+    if args.use_v2 and args.dataset != "coco":
+        raise ValueError("v2 is only supported for the coco dataset for now.")
+
     if args.output_dir:
         utils.mkdir(args.output_dir)

@@ -134,8 +146,8 @@ def main(args):
     else:
         torch.backends.cudnn.benchmark = True

-    dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args))
-    dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args))
+    dataset, num_classes = get_dataset(args, is_train=True)
+    dataset_test, _ = get_dataset(args, is_train=False)

     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
@@ -307,6 +319,8 @@ def get_args_parser(add_help=True):
     # Mixed precision training parameters
     parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

+    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
+    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
     return parser

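A quick sanity check of how the new flags interact with the validation added to main() (a hypothetical snippet; the dataset path is a placeholder):

    from train import get_args_parser

    args = get_args_parser().parse_args(
        ["--dataset", "coco", "--data-path", "/path/to/coco", "--use-v2", "--backend", "datapoint"]
    )
    # --backend is lower-cased on the command line (type=str.lower), while the default
    # stays "PIL", which is why main() compares args.backend.lower(). The new checks
    # reject a non-PIL backend without --use-v2, and --use-v2 with any dataset
    # other than coco.
    assert args.use_v2 and args.backend == "datapoint"
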
references/segmentation/transforms.py

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ def __init__(self, min_size, max_size=None):

     def __call__(self, image, target):
         size = random.randint(self.min_size, self.max_size)
-        image = F.resize(image, size)
+        image = F.resize(image, size, antialias=True)
         target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)
         return image, target

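The antialias flag only applies to the image branch; the target keeps nearest interpolation so integer class labels are never blended. A small standalone illustration of that asymmetry (not part of the diff):

    import torch
    from torchvision.transforms import functional as F, InterpolationMode

    img = torch.rand(3, 64, 64)                    # float image tensor
    mask = torch.randint(0, 21, (1, 64, 64))       # integer class labels

    img_small = F.resize(img, 32, antialias=True)  # smoother downscaling for the image
    mask_small = F.resize(mask, 32, interpolation=InterpolationMode.NEAREST)  # labels stay exact
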
references/segmentation/utils.py

Lines changed: 3 additions & 3 deletions

@@ -267,9 +267,9 @@ def init_distributed_mode(args):
         args.rank = int(os.environ["RANK"])
         args.world_size = int(os.environ["WORLD_SIZE"])
         args.gpu = int(os.environ["LOCAL_RANK"])
-    elif "SLURM_PROCID" in os.environ:
-        args.rank = int(os.environ["SLURM_PROCID"])
-        args.gpu = args.rank % torch.cuda.device_count()
+    # elif "SLURM_PROCID" in os.environ:
+    #     args.rank = int(os.environ["SLURM_PROCID"])
+    #     args.gpu = args.rank % torch.cuda.device_count()
     elif hasattr(args, "rank"):
         pass
     else:

references/segmentation/v2_extras.py

Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
+"""This file only exists to be lazy-imported and avoid V2-related import warnings when just using V1."""
+import torch
+from torchvision import datapoints
+from torchvision.transforms import v2
+
+
+class PadIfSmaller(v2.Transform):
+    def __init__(self, size, fill=0):
+        super().__init__()
+        self.size = size
+        self.fill = v2._geometry._setup_fill_arg(fill)
+
+    def _get_params(self, sample):
+        _, height, width = v2.utils.query_chw(sample)
+        padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
+        needs_padding = any(padding)
+        return dict(padding=padding, needs_padding=needs_padding)
+
+    def _transform(self, inpt, params):
+        if not params["needs_padding"]:
+            return inpt
+
+        fill = self.fill[type(inpt)]
+        fill = v2._utils._convert_fill_arg(fill)
+
+        return v2.functional.pad(inpt, padding=params["padding"], fill=fill)
+
+
+class CocoDetectionToVOCSegmentation(v2.Transform):
+    """Turn samples from datasets.CocoDetection into the same format as VOCSegmentation.
+
+    This is achieved in two steps:
+
+    1. COCO differentiates between 91 categories while VOC only supports 21, including background for both. Fortunately,
+       the COCO categories are a superset of the VOC ones and thus can be mapped. Instances of the 70 categories not
+       present in VOC are dropped and replaced by background.
+    2. COCO only offers detection masks, i.e. a (N, H, W) bool-ish tensor, where the truthy values in each individual
+       mask denote the instance. However, a segmentation mask is a (H, W) integer tensor (typically torch.uint8), where
+       the value of each pixel denotes the category it belongs to. The detection masks are merged into one segmentation
+       mask while pixels that belong to multiple detection masks are marked as invalid.
+    """
+
+    COCO_TO_VOC_LABEL_MAP = dict(
+        zip(
+            [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72],
+            range(21),
+        )
+    )
+    INVALID_VALUE = 255
+
+    def _coco_detection_masks_to_voc_segmentation_mask(self, target):
+        if "masks" not in target:
+            return None
+
+        instance_masks, instance_labels_coco = target["masks"], target["labels"]
+
+        valid_labels_voc = [
+            (idx, label_voc)
+            for idx, label_coco in enumerate(instance_labels_coco.tolist())
+            if (label_voc := self.COCO_TO_VOC_LABEL_MAP.get(label_coco)) is not None
+        ]
+
+        if not valid_labels_voc:
+            return None
+
+        valid_voc_category_idcs, instance_labels_voc = zip(*valid_labels_voc)
+
+        instance_masks = instance_masks[list(valid_voc_category_idcs)].to(torch.uint8)
+        instance_labels_voc = torch.tensor(instance_labels_voc, dtype=torch.uint8)
+
+        # Calling `.max()` on the stacked detection masks works fine to separate background from foreground as long as
+        # there is at most a single instance per pixel. Overlapping instances will be filtered out in the next step.
+        segmentation_mask, _ = (instance_masks * instance_labels_voc.reshape(-1, 1, 1)).max(dim=0)
+        segmentation_mask[instance_masks.sum(dim=0) > 1] = self.INVALID_VALUE
+
+        return segmentation_mask
+
+    def forward(self, image, target):
+        segmentation_mask = self._coco_detection_masks_to_voc_segmentation_mask(target)
+        if segmentation_mask is None:
+            segmentation_mask = torch.zeros(v2.functional.get_spatial_size(image), dtype=torch.uint8)
+
+        return image, datapoints.Mask(segmentation_mask)

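To make the mask-merging step concrete, a tiny worked example (the values are illustrative; 255 plays the role of INVALID_VALUE):

    import torch

    # Two 2x2 instance masks with VOC labels 3 and 5 that overlap in the top-left pixel.
    instance_masks = torch.tensor(
        [[[1, 1],
          [0, 0]],
         [[1, 0],
          [0, 1]]],
        dtype=torch.uint8,
    )
    instance_labels_voc = torch.tensor([3, 5], dtype=torch.uint8)

    segmentation_mask, _ = (instance_masks * instance_labels_voc.reshape(-1, 1, 1)).max(dim=0)
    segmentation_mask[instance_masks.sum(dim=0) > 1] = 255  # overlapping pixels become invalid

    print(segmentation_mask)
    # tensor([[255,   3],
    #         [  0,   5]], dtype=torch.uint8)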