
Commit de54057

Merge branch 'main' into fix-make-image

2 parents 955c596 + 8233c9c

17 files changed: +325 -305 lines

docs/source/transforms.rst

Lines changed: 0 additions & 1 deletion

@@ -234,7 +234,6 @@ Conversion
     v2.PILToTensor
     v2.ToImageTensor
     ConvertImageDtype
-    v2.ConvertDtype
     v2.ConvertImageDtype
     v2.ToDtype
     v2.ConvertBoundingBoxFormat

gallery/plot_transforms_v2_e2e.py

Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@ def show(sample):
     image, target = sample
     if isinstance(image, PIL.Image.Image):
         image = F.to_image_tensor(image)
-        image = F.convert_dtype(image, torch.uint8)
+        image = F.to_dtype(image, torch.uint8, scale=True)
     annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)

     fig, ax = plt.subplots()
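Note on the change above: the deprecated F.convert_dtype call is replaced by F.to_dtype with an explicit scale flag. A minimal sketch of the new idiom, assuming only the v2 functional API as it appears in this diff (the random image is a placeholder):

import torch
from torchvision.transforms.v2 import functional as F

img_float = torch.rand(3, 4, 4)  # float image, values in [0, 1]
# scale=True rescales values into the target dtype's range, so [0, 1] -> [0, 255]
img_uint8 = F.to_dtype(img_float, torch.uint8, scale=True)
assert img_uint8.dtype == torch.uint8

Without scale=True the call would only cast, which is why the gallery code passes the flag before handing the image to draw_bounding_boxes.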

references/detection/coco_utils.py

Lines changed: 6 additions & 28 deletions

@@ -1,4 +1,3 @@
-import copy
 import os

 import torch
@@ -10,24 +9,6 @@
 from torchvision.datasets import wrap_dataset_for_transforms_v2


-class FilterAndRemapCocoCategories:
-    def __init__(self, categories, remap=True):
-        self.categories = categories
-        self.remap = remap
-
-    def __call__(self, image, target):
-        anno = target["annotations"]
-        anno = [obj for obj in anno if obj["category_id"] in self.categories]
-        if not self.remap:
-            target["annotations"] = anno
-            return image, target
-        anno = copy.deepcopy(anno)
-        for obj in anno:
-            obj["category_id"] = self.categories.index(obj["category_id"])
-        target["annotations"] = anno
-        return image, target
-
-
 def convert_coco_poly_to_mask(segmentations, height, width):
     masks = []
     for polygons in segmentations:
@@ -219,7 +200,7 @@ def __getitem__(self, idx):
         return img, target


-def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
+def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_masks=False):
     anno_file_template = "{}_{}2017.json"
     PATHS = {
         "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
@@ -233,9 +214,12 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False):

     if use_v2:
         dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
-        # TODO: need to update target_keys to handle masks for segmentation!
-        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"})
+        target_keys = ["boxes", "labels", "image_id"]
+        if with_masks:
+            target_keys += ["masks"]
+        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
     else:
+        # TODO: handle with_masks for V1?
         t = [ConvertCocoPolysToMask()]
         if transforms is not None:
             t.append(transforms)
@@ -249,9 +233,3 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
     # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])

     return dataset
-
-
-def get_coco_kp(root, image_set, transforms, use_v2=False):
-    if use_v2:
-        raise ValueError("KeyPoints aren't supported by transforms V2 yet.")
-    return get_coco(root, image_set, transforms, mode="person_keypoints")
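A hypothetical call sketch for the updated get_coco signature, showing how with_masks feeds the v2 wrapper's target_keys. It assumes the references/detection directory is on the import path; the dataset root is a placeholder:

from coco_utils import get_coco

dataset = get_coco(
    root="/path/to/coco",   # placeholder dataset root
    image_set="train",
    transforms=None,
    mode="instances",
    use_v2=True,
    with_masks=True,        # adds "masks" to target_keys, for segmentation models
)

With get_coco_kp gone, keypoint training goes through the same entry point via mode="person_keypoints".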

references/detection/train.py

Lines changed: 23 additions & 6 deletions

@@ -28,7 +28,7 @@
 import torchvision.models.detection
 import torchvision.models.detection.mask_rcnn
 import utils
-from coco_utils import get_coco, get_coco_kp
+from coco_utils import get_coco
 from engine import evaluate, train_one_epoch
 from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler
 from torchvision.transforms import InterpolationMode
@@ -42,10 +42,16 @@ def copypaste_collate_fn(batch):

 def get_dataset(is_train, args):
     image_set = "train" if is_train else "val"
-    paths = {"coco": (args.data_path, get_coco, 91), "coco_kp": (args.data_path, get_coco_kp, 2)}
-    p, ds_fn, num_classes = paths[args.dataset]
-
-    ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2)
+    num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}[args.dataset]
+    with_masks = "mask" in args.model
+    ds = get_coco(
+        root=args.data_path,
+        image_set=image_set,
+        transforms=get_transform(is_train, args),
+        mode=mode,
+        use_v2=args.use_v2,
+        with_masks=with_masks,
+    )
     return ds, num_classes


@@ -68,7 +74,12 @@ def get_args_parser(add_help=True):
     parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help)

     parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
-    parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
+    parser.add_argument(
+        "--dataset",
+        default="coco",
+        type=str,
+        help="dataset name. Use coco for object detection and instance segmentation and coco_kp for Keypoint detection",
+    )
     parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name")
     parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
     parser.add_argument(
@@ -171,6 +182,12 @@ def get_args_parser(add_help=True):
 def main(args):
     if args.backend.lower() == "datapoint" and not args.use_v2:
         raise ValueError("Use --use-v2 if you want to use the datapoint backend.")
+    if args.dataset not in ("coco", "coco_kp"):
+        raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}")
+    if "keypoint" in args.model and args.dataset != "coco_kp":
+        raise ValueError("Oops, if you want Keypoint detection, set --dataset coco_kp")
+    if args.dataset == "coco_kp" and args.use_v2:
+        raise ValueError("KeyPoint detection doesn't support V2 transforms yet")

     if args.output_dir:
         utils.mkdir(args.output_dir)
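A small sketch of how the rewritten get_dataset resolves its arguments; Args is a hypothetical stand-in for the parsed argparse namespace, using this script's defaults:

class Args:  # hypothetical stand-in for argparse.Namespace
    dataset = "coco"
    model = "maskrcnn_resnet50_fpn"

# Mirrors the lookup in get_dataset above
num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}[Args.dataset]
with_masks = "mask" in Args.model  # Mask R-CNN variants request mask targets
print(num_classes, mode, with_masks)  # -> 91 instances True

The design choice here is that mask targets are derived from the model name rather than a separate flag, so plain Faster R-CNN runs skip decoding COCO polygons entirely.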

test/common_utils.py

Lines changed: 2 additions & 2 deletions

@@ -27,7 +27,7 @@
 from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
 from torchvision import datapoints, io
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
-from torchvision.transforms.v2.functional import convert_dtype_image_tensor, to_image_pil, to_image_tensor
+from torchvision.transforms.v2.functional import to_dtype_image_tensor, to_image_pil, to_image_tensor


 IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
@@ -602,7 +602,7 @@ def fn(shape, dtype, device, memory_format):
         image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True)
     else:
         image_tensor = image_tensor.to(device=device)
-    image_tensor = convert_dtype_image_tensor(image_tensor, dtype=dtype)
+    image_tensor = to_dtype_image_tensor(image_tensor, dtype=dtype, scale=True)

     return datapoints.Image(image_tensor)
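To illustrate the kernel swapped in above: with scale=True, to_dtype_image_tensor rescales values between dtype ranges rather than merely casting. A minimal sketch, assuming only the import shown in this diff:

import torch
from torchvision.transforms.v2.functional import to_dtype_image_tensor

img = torch.full((3, 2, 2), 255, dtype=torch.uint8)
out = to_dtype_image_tensor(img, dtype=torch.float32, scale=True)
assert out.max().item() == 1.0  # 255 (uint8 max) maps to 1.0 in float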

test/test_transforms_v2.py

Lines changed: 2 additions & 58 deletions

@@ -1,7 +1,6 @@
 import itertools
 import pathlib
 import random
-import re
 import textwrap
 import warnings
 from collections import defaultdict
@@ -105,7 +104,7 @@ def normalize_adapter(transform, input, device):
             continue
         elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)):
             # normalize doesn't support integer images
-            value = F.convert_dtype(value, torch.float32)
+            value = F.to_dtype(value, torch.float32, scale=True)
         adapted_input[key] = value
     return adapted_input

@@ -146,7 +145,7 @@ class TestSmoke:
         (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None),
         (transforms.ClampBoundingBox(), None),
         (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
-        (transforms.ConvertDtype(), None),
+        (transforms.ConvertImageDtype(), None),
         (transforms.GaussianBlur(kernel_size=3), None),
         (
             transforms.LinearTransformation(
@@ -1326,61 +1325,6 @@ def test__transform(self, mocker):
         )


-class TestToDtype:
-    @pytest.mark.parametrize(
-        ("dtype", "expected_dtypes"),
-        [
-            (
-                torch.float64,
-                {
-                    datapoints.Video: torch.float64,
-                    datapoints.Image: torch.float64,
-                    datapoints.BoundingBox: torch.float64,
-                },
-            ),
-            (
-                {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
-                {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
-            ),
-        ],
-    )
-    def test_call(self, dtype, expected_dtypes):
-        sample = dict(
-            video=make_video(dtype=torch.int64),
-            image=make_image(dtype=torch.uint8),
-            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32),
-            str="str",
-            int=0,
-        )
-
-        transform = transforms.ToDtype(dtype)
-        transformed_sample = transform(sample)
-
-        for key, value in sample.items():
-            value_type = type(value)
-            transformed_value = transformed_sample[key]
-
-            # make sure the transformation retains the type
-            assert isinstance(transformed_value, value_type)
-
-            if isinstance(value, torch.Tensor):
-                assert transformed_value.dtype is expected_dtypes[value_type]
-            else:
-                assert transformed_value is value
-
-    @pytest.mark.filterwarnings("error")
-    def test_plain_tensor_call(self):
-        tensor = torch.empty((), dtype=torch.float32)
-        transform = transforms.ToDtype({torch.Tensor: torch.float64})
-
-        assert transform(tensor).dtype is torch.float64
-
-    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
-    def test_plain_tensor_warning(self, other_type):
-        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
-            transforms.ToDtype(dtype={torch.Tensor: torch.float32, other_type: torch.float64})
-
-
 class TestUniformTemporalSubsample:
     @pytest.mark.parametrize(
         "inpt",

test/test_transforms_v2_consistency.py

Lines changed: 1 addition & 1 deletion

@@ -191,7 +191,7 @@ def __init__(
             closeness_kwargs=dict(rtol=None, atol=None),
         ),
         ConsistencyConfig(
-            v2_transforms.ConvertDtype,
+            v2_transforms.ConvertImageDtype,
             legacy_transforms.ConvertImageDtype,
             [
                 ArgsKwargs(torch.float16),
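The config above pins the renamed v2 transform to its v1 namesake. A minimal sketch of the equivalence being asserted, assuming both transforms accept a plain float tensor:

import torch
from torchvision import transforms as legacy_transforms
from torchvision.transforms import v2 as v2_transforms

img = torch.rand(3, 4, 4)
out_v1 = legacy_transforms.ConvertImageDtype(torch.float16)(img)
out_v2 = v2_transforms.ConvertImageDtype(torch.float16)(img)
torch.testing.assert_close(out_v1, out_v2)  # v2 must match v1 exactly here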

test/test_transforms_v2_functional.py

Lines changed: 2 additions & 21 deletions

@@ -283,12 +283,12 @@ def test_float32_vs_uint8(self, test_id, info, args_kwargs):
         adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)

         actual = info.kernel(
-            F.convert_dtype_image_tensor(input, dtype=torch.float32),
+            F.to_dtype_image_tensor(input, dtype=torch.float32, scale=True),
             *adapted_other_args,
             **adapted_kwargs,
         )

-        expected = F.convert_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32)
+        expected = F.to_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)

         assert_close(
             actual,
@@ -538,7 +538,6 @@ def test_bounding_box_format_consistency(self, info, args_kwargs):
         (F.get_image_num_channels, F.get_num_channels),
         (F.to_pil_image, F.to_image_pil),
         (F.elastic_transform, F.elastic),
-        (F.convert_image_dtype, F.convert_dtype_image_tensor),
         (F.to_grayscale, F.rgb_to_grayscale),
     ]
 ],
@@ -547,24 +546,6 @@ def test_alias(alias, target):
     assert alias is target


-@pytest.mark.parametrize(
-    ("info", "args_kwargs"),
-    make_info_args_kwargs_params(
-        KERNEL_INFOS_MAP[F.convert_dtype_image_tensor],
-        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
-    ),
-)
-@pytest.mark.parametrize("device", cpu_and_cuda())
-def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device):
-    (input, *other_args), kwargs = args_kwargs.load(device)
-    dtype = other_args[0] if other_args else kwargs.get("dtype", torch.float32)
-
-    output = info.kernel(input, dtype)
-
-    assert output.dtype == dtype
-    assert output.device == input.device
-
-
 @pytest.mark.parametrize("device", cpu_and_cuda())
 @pytest.mark.parametrize("num_channels", [1, 3])
 def test_normalize_image_tensor_stats(device, num_channels):
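The updated test_float32_vs_uint8 checks a commutation property: running a kernel on a scaled float32 copy of a uint8 image should match scaling the kernel's uint8 output. A self-contained sketch of that property, using horizontal_flip as the kernel (an exact, elementwise-preserving op, so the comparison holds without tolerances):

import torch
from torchvision.transforms.v2 import functional as F

img_u8 = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)
# scale first, then apply the kernel ...
actual = F.horizontal_flip(F.to_dtype(img_u8, torch.float32, scale=True))
# ... versus apply the kernel, then scale
expected = F.to_dtype(F.horizontal_flip(img_u8), torch.float32, scale=True)
torch.testing.assert_close(actual, expected)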
