Add --backend and --use-v2 support for segmentation references #7743

Merged
merged 8 commits on Jul 27, 2023
3 changes: 2 additions & 1 deletion references/detection/coco_utils.py
@@ -6,7 +6,6 @@
import transforms as T
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO
from torchvision.datasets import wrap_dataset_for_transforms_v2


def convert_coco_poly_to_mask(segmentations, height, width):
@@ -213,6 +212,8 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_m
ann_file = os.path.join(root, ann_file)

if use_v2:
from torchvision.datasets import wrap_dataset_for_transforms_v2

dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
target_keys = ["boxes", "labels", "image_id"]
if with_masks:
24 changes: 15 additions & 9 deletions references/segmentation/coco_utils.py
@@ -68,11 +68,6 @@ def _has_valid_annotation(anno):
# if more than 1k pixels occupied in the image
return sum(obj["area"] for obj in anno) > 1000

if not isinstance(dataset, torchvision.datasets.CocoDetection):
raise TypeError(
f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
)

ids = []
for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
@@ -86,21 +81,32 @@ def _has_valid_annotation(anno):
return dataset


def get_coco(root, image_set, transforms):
def get_coco(root, image_set, transforms, use_v2=False):
PATHS = {
"train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
"val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
# "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
}
CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])

img_folder, ann_file = PATHS[image_set]
img_folder = os.path.join(root, img_folder)
ann_file = os.path.join(root, ann_file)

dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
# The 2 "Compose" below achieve the same thing: converting coco detection
# samples into segmentation-compatible samples. They just do it with
# slightly different implementations. We could refactor and unify, but
# keeping them separate helps keeping the v2 version clean
if use_v2:
import v2_extras
from torchvision.datasets import wrap_dataset_for_transforms_v2

transforms = Compose([v2_extras.CocoDetectionToVOCSegmentation(), transforms])
dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"masks", "labels"})
else:
transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])
dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)

if image_set == "train":
dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)
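
As a usage sketch (not part of the diff): the v2 branch above could be exercised roughly as follows. The dataset root is hypothetical, and the snippet assumes the references/segmentation modules (coco_utils, transforms, v2_extras) are importable.

```python
# Minimal sketch, assuming COCO 2017 lives under /data/coco and that
# references/segmentation is on the Python path.
from torchvision.transforms import v2

from coco_utils import get_coco

# Any callable taking (image, target) works as the user transform; v2 transforms do.
user_transforms = v2.Compose([v2.PILToTensor()])

dataset = get_coco("/data/coco", image_set="val", transforms=user_transforms, use_v2=True)
image, target = dataset[0]
# With target_keys={"masks", "labels"} plus CocoDetectionToVOCSegmentation,
# the target should come out as a single VOC-style segmentation mask.
print(type(image), type(target))
```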
113 changes: 90 additions & 23 deletions references/segmentation/presets.py
@@ -1,39 +1,106 @@
from collections import defaultdict

import torch
import transforms as T


def get_modules(use_v2):
# We need a protected import to avoid the V2 warning in case just V1 is used
if use_v2:
import torchvision.datapoints
import torchvision.transforms.v2
import v2_extras

return torchvision.transforms.v2, torchvision.datapoints, v2_extras
else:
import transforms

return transforms, None, None


class SegmentationPresetTrain:
def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
min_size = int(0.5 * base_size)
max_size = int(2.0 * base_size)
def __init__(
self,
*,
base_size,
crop_size,
hflip_prob=0.5,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
backend="pil",
use_v2=False,
):
T, datapoints, v2_extras = get_modules(use_v2)

transforms = []
backend = backend.lower()
if backend == "datapoint":
transforms.append(T.ToImageTensor())
elif backend == "tensor":
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")

transforms += [T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size))]

trans = [T.RandomResize(min_size, max_size)]
if hflip_prob > 0:
trans.append(T.RandomHorizontalFlip(hflip_prob))
trans.extend(
[
T.RandomCrop(crop_size),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
transforms += [T.RandomHorizontalFlip(hflip_prob)]

if use_v2:
# We need a custom pad transform here, since the padding we want to perform here is fundamentally
# different from the padding in `RandomCrop` if `pad_if_needed=True`.
[Review comment from a Collaborator on lines +49 to +50]
If we at some point drop v1, we should re-evaluate if the custom scheme actually has an effect on the performance or if we can just RandomCrop. Until then, I agree we should keep the implementation as close as possible.

transforms += [v2_extras.PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255}))]

transforms += [T.RandomCrop(crop_size)]

if backend == "pil":
transforms += [T.PILToTensor()]

if use_v2:
img_type = datapoints.Image if backend == "datapoint" else torch.Tensor
transforms += [
T.ToDtype(dtype={img_type: torch.float32, datapoints.Mask: torch.int64, "others": None}, scale=True)
]
)
self.transforms = T.Compose(trans)
else:
# No need to explicitly convert masks as they're magically int64 already
transforms += [T.ConvertImageDtype(torch.float)]

transforms += [T.Normalize(mean=mean, std=std)]

self.transforms = T.Compose(transforms)

def __call__(self, img, target):
return self.transforms(img, target)


class SegmentationPresetEval:
def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
self.transforms = T.Compose(
[
T.RandomResize(base_size, base_size),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
)
def __init__(
self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), backend="pil", use_v2=False
):
T, _, _ = get_modules(use_v2)

transforms = []
backend = backend.lower()
if backend == "tensor":
transforms += [T.PILToTensor()]
elif backend == "datapoint":
transforms += [T.ToImageTensor()]
elif backend != "pil":
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")

if use_v2:
transforms += [T.Resize(size=(base_size, base_size))]
else:
transforms += [T.RandomResize(min_size=base_size, max_size=base_size)]

if backend == "pil":
# Note: we could just convert to pure tensors even in v2?
[Review comment from a Collaborator]
Depends on our decision in #7340. I would drop the comment here, because we'll easily forget about it.

transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]

transforms += [
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
self.transforms = T.Compose(transforms)

def __call__(self, img, target):
return self.transforms(img, target)
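
To illustrate how the new knobs combine (a sketch, not part of the PR): with use_v2=True and backend="datapoint", the train preset converts the input to a datapoints.Image up front, resizes, flips, pads the mask with 255 via PadIfSmaller, crops, and finishes with the ToDtype/Normalize steps shown above. The image size and the dummy inputs below are assumptions.

```python
# Sketch, assuming references/segmentation (presets.py, transforms.py, v2_extras.py)
# is on the Python path.
import PIL.Image
import torch
from torchvision import datapoints

import presets

train_tf = presets.SegmentationPresetTrain(
    base_size=520, crop_size=480, backend="datapoint", use_v2=True
)

# Dummy PIL image and all-background mask, just to exercise the pipeline.
img = PIL.Image.new("RGB", (640, 480))
mask = datapoints.Mask(torch.zeros(480, 640, dtype=torch.uint8))

out_img, out_mask = train_tf(img, mask)
print(out_img.shape, out_img.dtype)   # 3 x 480 x 480 crop, float32 (per ToDtype above)
print(out_mask.shape, out_mask.dtype) # 480 x 480, int64
```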
38 changes: 26 additions & 12 deletions references/segmentation/train.py
@@ -14,24 +14,30 @@
from torchvision.transforms import functional as F, InterpolationMode


def get_dataset(dir_path, name, image_set, transform):
def get_dataset(args, is_train):
def sbd(*args, **kwargs):
kwargs.pop("use_v2")
return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)

def voc(*args, **kwargs):
kwargs.pop("use_v2")
return torchvision.datasets.VOCSegmentation(*args, **kwargs)

paths = {
"voc": (dir_path, torchvision.datasets.VOCSegmentation, 21),
"voc_aug": (dir_path, sbd, 21),
"coco": (dir_path, get_coco, 21),
"voc": (args.data_path, voc, 21),
"voc_aug": (args.data_path, sbd, 21),
"coco": (args.data_path, get_coco, 21),
}
p, ds_fn, num_classes = paths[name]
p, ds_fn, num_classes = paths[args.dataset]

ds = ds_fn(p, image_set=image_set, transforms=transform)
image_set = "train" if is_train else "val"
ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2)
return ds, num_classes


def get_transform(train, args):
if train:
return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
def get_transform(is_train, args):
if is_train:
return presets.SegmentationPresetTrain(base_size=520, crop_size=480, backend=args.backend, use_v2=args.use_v2)
elif args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights)
trans = weights.transforms()
@@ -44,7 +50,7 @@ def preprocessing(img, target):

return preprocessing
else:
return presets.SegmentationPresetEval(base_size=520)
return presets.SegmentationPresetEval(base_size=520, backend=args.backend, use_v2=args.use_v2)


def criterion(inputs, target):
@@ -120,6 +126,12 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi


def main(args):
if args.backend.lower() != "pil" and not args.use_v2:
# TODO: Support tensor backend in V1?
raise ValueError("Use --use-v2 if you want to use the datapoint or tensor backend.")
if args.use_v2 and args.dataset != "coco":
raise ValueError("v2 is only support supported for coco dataset for now.")

if args.output_dir:
utils.mkdir(args.output_dir)

@@ -134,8 +146,8 @@ def main(args):
else:
torch.backends.cudnn.benchmark = True

dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args))
dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args))
dataset, num_classes = get_dataset(args, is_train=True)
dataset_test, _ = get_dataset(args, is_train=False)

if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
@@ -307,6 +319,8 @@ def get_args_parser(add_help=True):
# Mixed precision training parameters
parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
return parser


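
A standalone sketch of how the two new flags interact with the compatibility checks in main(); it mirrors the reference parser rather than importing it, so the parser below is an illustration, not the script's own.

```python
# Illustration only: a stripped-down parser mirroring the new --backend/--use-v2
# flags and the validation performed at the top of main().
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL, tensor or datapoint - case insensitive")
parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
parser.add_argument("--dataset", default="coco", type=str, help="dataset name")

args = parser.parse_args(["--backend", "datapoint", "--use-v2"])

# Same compatibility rules as main():
if args.backend.lower() != "pil" and not args.use_v2:
    raise ValueError("Use --use-v2 if you want to use the datapoint or tensor backend.")
if args.use_v2 and args.dataset != "coco":
    raise ValueError("v2 is only supported for the coco dataset for now.")

print(args.backend, args.use_v2)  # "datapoint" True -- type=str.lower already lower-cased it
```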
2 changes: 1 addition & 1 deletion references/segmentation/transforms.py
@@ -35,7 +35,7 @@ def __init__(self, min_size, max_size=None):

def __call__(self, image, target):
size = random.randint(self.min_size, self.max_size)
image = F.resize(image, size)
image = F.resize(image, size, antialias=True)
target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)
return image, target

6 changes: 3 additions & 3 deletions references/segmentation/utils.py
@@ -267,9 +267,9 @@ def init_distributed_mode(args):
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
# elif "SLURM_PROCID" in os.environ:
# args.rank = int(os.environ["SLURM_PROCID"])
# args.gpu = args.rank % torch.cuda.device_count()
[Review comment from a Collaborator on lines +270 to +272]
?

Suggested change:
- # elif "SLURM_PROCID" in os.environ:
- #     args.rank = int(os.environ["SLURM_PROCID"])
- #     args.gpu = args.rank % torch.cuda.device_count()
+ elif "SLURM_PROCID" in os.environ:
+     args.rank = int(os.environ["SLURM_PROCID"])
+     args.gpu = args.rank % torch.cuda.device_count()

[Reply from the Member Author]
It's just to be able to run those scripts without invoking torchrun from the cluster (it's a bit of a mess). I'll remove it before merging.

elif hasattr(args, "rank"):
pass
else:
83 changes: 83 additions & 0 deletions references/segmentation/v2_extras.py
@@ -0,0 +1,83 @@
"""This file only exists to be lazy-imported and avoid V2-related import warnings when just using V1."""
import torch
from torchvision import datapoints
from torchvision.transforms import v2


class PadIfSmaller(v2.Transform):
def __init__(self, size, fill=0):
super().__init__()
self.size = size
self.fill = v2._geometry._setup_fill_arg(fill)

def _get_params(self, sample):
_, height, width = v2.utils.query_chw(sample)
padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
needs_padding = any(padding)
return dict(padding=padding, needs_padding=needs_padding)

def _transform(self, inpt, params):
if not params["needs_padding"]:
return inpt

fill = self.fill[type(inpt)]
fill = v2._utils._convert_fill_arg(fill)

return v2.functional.pad(inpt, padding=params["padding"], fill=fill)


class CocoDetectionToVOCSegmentation(v2.Transform):
"""Turn samples from datasets.CocoDetection into the same format as VOCSegmentation.

This is achieved in two steps:

1. COCO differentiates between 91 categories while VOC only supports 21, including background for both. Fortunately,
the COCO categories are a superset of the VOC ones and thus can be mapped. Instances of the 70 categories not
present in VOC are dropped and replaced by background.
2. COCO only offers detection masks, i.e. a (N, H, W) bool-ish tensor, where the truthy values in each individual
mask denote the instance. However, a segmentation mask is a (H, W) integer tensor (typically torch.uint8), where
the value of each pixel denotes the category it belongs to. The detection masks are merged into one segmentation
mask while pixels that belong to multiple detection masks are marked as invalid.
"""

COCO_TO_VOC_LABEL_MAP = dict(
zip(
[0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72],
range(21),
)
)
INVALID_VALUE = 255

def _coco_detection_masks_to_voc_segmentation_mask(self, target):
if "masks" not in target:
return None

instance_masks, instance_labels_coco = target["masks"], target["labels"]

valid_labels_voc = [
(idx, label_voc)
for idx, label_coco in enumerate(instance_labels_coco.tolist())
if (label_voc := self.COCO_TO_VOC_LABEL_MAP.get(label_coco)) is not None
]

if not valid_labels_voc:
return None

valid_voc_category_idcs, instance_labels_voc = zip(*valid_labels_voc)

instance_masks = instance_masks[list(valid_voc_category_idcs)].to(torch.uint8)
instance_labels_voc = torch.tensor(instance_labels_voc, dtype=torch.uint8)

# Calling `.max()` on the stacked detection masks works fine to separate background from foreground as long as
# there is at most a single instance per pixel. Overlapping instances will be filtered out in the next step.
segmentation_mask, _ = (instance_masks * instance_labels_voc.reshape(-1, 1, 1)).max(dim=0)
segmentation_mask[instance_masks.sum(dim=0) > 1] = self.INVALID_VALUE

return segmentation_mask

def forward(self, image, target):
segmentation_mask = self._coco_detection_masks_to_voc_segmentation_mask(target)
if segmentation_mask is None:
segmentation_mask = torch.zeros(v2.functional.get_spatial_size(image), dtype=torch.uint8)

return image, datapoints.Mask(segmentation_mask)
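
To make the merging step in _coco_detection_masks_to_voc_segmentation_mask concrete, here is a self-contained toy example (the masks and labels are made up) that reproduces the same max-then-invalidate logic on two overlapping instances:

```python
# Toy reproduction of the mask-merging logic above; values are made up.
import torch

INVALID_VALUE = 255

# Two instance masks on a 4x4 image; the instances overlap at pixel (1, 1).
instance_masks = torch.tensor(
    [
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],
        [[0, 0, 0, 0],
         [0, 1, 1, 0],
         [0, 1, 1, 0],
         [0, 0, 0, 0]],
    ],
    dtype=torch.uint8,
)
instance_labels_voc = torch.tensor([15, 7], dtype=torch.uint8)  # e.g. VOC person and car

segmentation_mask, _ = (instance_masks * instance_labels_voc.reshape(-1, 1, 1)).max(dim=0)
segmentation_mask[instance_masks.sum(dim=0) > 1] = INVALID_VALUE

print(segmentation_mask)
# tensor([[ 15,  15,   0,   0],
#         [ 15, 255,   7,   0],
#         [  0,   7,   7,   0],
#         [  0,   0,   0,   0]], dtype=torch.uint8)
```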