Commit 37820ef

NicolasHug authored and facebook-github-bot committed
[fbsync] port horizontal flip tests (#7703)
Reviewed By: vmoens
Differential Revision: D47186579
fbshipit-source-id: 5077d10522cf36ba99aac3863cd55bb967eb8c89
1 parent a05a13b commit 37820ef

5 files changed (+166, -167 lines)


test/test_transforms_v2.py

Lines changed: 0 additions & 53 deletions
@@ -406,59 +406,6 @@ def was_applied(output, inpt):
         assert transform.was_applied(output, input)
 
 
-@pytest.mark.parametrize("p", [0.0, 1.0])
-class TestRandomHorizontalFlip:
-    def input_expected_image_tensor(self, p, dtype=torch.float32):
-        input = torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype)
-        expected = torch.tensor([[[1, 0], [1, 0]], [[0, 1], [0, 1]]], dtype=dtype)
-
-        return input, expected if p == 1 else input
-
-    def test_simple_tensor(self, p):
-        input, expected = self.input_expected_image_tensor(p)
-        transform = transforms.RandomHorizontalFlip(p=p)
-
-        actual = transform(input)
-
-        assert_equal(expected, actual)
-
-    def test_pil_image(self, p):
-        input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
-        transform = transforms.RandomHorizontalFlip(p=p)
-
-        actual = transform(to_pil_image(input))
-
-        assert_equal(expected, pil_to_tensor(actual))
-
-    def test_datapoints_image(self, p):
-        input, expected = self.input_expected_image_tensor(p)
-        transform = transforms.RandomHorizontalFlip(p=p)
-
-        actual = transform(datapoints.Image(input))
-
-        assert_equal(datapoints.Image(expected), actual)
-
-    def test_datapoints_mask(self, p):
-        input, expected = self.input_expected_image_tensor(p)
-        transform = transforms.RandomHorizontalFlip(p=p)
-
-        actual = transform(datapoints.Mask(input))
-
-        assert_equal(datapoints.Mask(expected), actual)
-
-    def test_datapoints_bounding_box(self, p):
-        input = datapoints.BoundingBox([0, 0, 5, 5], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10))
-        transform = transforms.RandomHorizontalFlip(p=p)
-
-        actual = transform(input)
-
-        expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input
-        expected = datapoints.BoundingBox.wrap_like(input, expected_image_tensor)
-        assert_equal(expected, actual)
-        assert actual.format == expected.format
-        assert actual.spatial_size == expected.spatial_size
-
-
 @pytest.mark.parametrize("p", [0.0, 1.0])
 class TestRandomVerticalFlip:
     def input_expected_image_tensor(self, p, dtype=torch.float32):

test/test_transforms_v2_refactored.py

Lines changed: 154 additions & 20 deletions
@@ -295,9 +295,9 @@ def check_transform(transform_cls, input, *args, **kwargs):
     _check_transform_v1_compatibility(transform, input)
 
 
-def transform_cls_to_functional(transform_cls):
+def transform_cls_to_functional(transform_cls, **transform_specific_kwargs):
     def wrapper(input, *args, **kwargs):
-        transform = transform_cls(*args, **kwargs)
+        transform = transform_cls(*args, **transform_specific_kwargs, **kwargs)
         return transform(input)
 
     wrapper.__name__ = transform_cls.__name__
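Note on the hunk above: the new **transform_specific_kwargs lets a wrapper bind constructor-only arguments (e.g. p=1) when it is created, which is how the flip tests later in this diff build a deterministic functional from transforms.RandomHorizontalFlip. A minimal self-contained sketch of the idea, not the test suite's helper verbatim (the trailing "return wrapper" is implied by how the helper is used but lies outside the hunk):

import torch
import torchvision.transforms.v2 as transforms


def transform_cls_to_functional(transform_cls, **transform_specific_kwargs):
    # Constructor-only kwargs (e.g. p=1) are bound when the wrapper is created;
    # anything passed at call time is still forwarded to the constructor.
    def wrapper(input, *args, **kwargs):
        transform = transform_cls(*args, **transform_specific_kwargs, **kwargs)
        return transform(input)

    wrapper.__name__ = transform_cls.__name__
    return wrapper


# With p=1 the transform always flips, so the wrapper behaves like F.horizontal_flip.
hflip = transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)
image = torch.arange(2 * 4 * 4, dtype=torch.uint8).reshape(2, 4, 4)
assert torch.equal(hflip(image), torch.flip(image, dims=(-1,)))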
@@ -321,14 +321,14 @@ def assert_warns_antialias_default_value():
 
 
 def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size, affine_matrix):
-    def transform(bbox, affine_matrix_, format_, spatial_size_):
+    def transform(bbox):
         # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
         in_dtype = bbox.dtype
         if not torch.is_floating_point(bbox):
            bbox = bbox.float()
        bbox_xyxy = F.convert_format_bounding_box(
            bbox.as_subclass(torch.Tensor),
-            old_format=format_,
+            old_format=format,
            new_format=datapoints.BoundingBoxFormat.XYXY,
            inplace=True,
        )
@@ -340,7 +340,7 @@ def transform(bbox, affine_matrix_, format_, spatial_size_):
                 [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
             ]
         )
-        transformed_points = np.matmul(points, affine_matrix_.T)
+        transformed_points = np.matmul(points, affine_matrix.T)
         out_bbox = torch.tensor(
             [
                 np.min(transformed_points[:, 0]).item(),
@@ -351,23 +351,14 @@ def transform(bbox, affine_matrix_, format_, spatial_size_):
             dtype=bbox_xyxy.dtype,
         )
         out_bbox = F.convert_format_bounding_box(
-            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
+            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
         )
         # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
-        out_bbox = F.clamp_bounding_box(out_bbox, format=format_, spatial_size=spatial_size_)
+        out_bbox = F.clamp_bounding_box(out_bbox, format=format, spatial_size=spatial_size)
         out_bbox = out_bbox.to(dtype=in_dtype)
         return out_bbox
 
-    if bounding_box.ndim < 2:
-        bounding_box = [bounding_box]
-
-    expected_bboxes = [transform(bbox, affine_matrix, format, spatial_size) for bbox in bounding_box]
-    if len(expected_bboxes) > 1:
-        expected_bboxes = torch.stack(expected_bboxes)
-    else:
-        expected_bboxes = expected_bboxes[0]
-
-    return expected_bboxes
+    return torch.stack([transform(b) for b in bounding_box.reshape(-1, 4).unbind()]).reshape(bounding_box.shape)
 
 
 class TestResize:
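For reference, the affine matrix the flip tests feed into this helper encodes x -> width - x with y unchanged, i.e. [[-1, 0, W], [0, 1, 0]]. The sketch below is not the test helper itself, just the corner-transform math it applies per box; it reproduces the value the removed test_datapoints_bounding_box hard-coded for an XYXY box [0, 0, 5, 5] on a 10x10 image:

import numpy as np

height, width = 10, 10
x1, y1, x2, y2 = 0.0, 0.0, 5.0, 5.0  # XYXY box on the 10x10 image

# Horizontal flip as an affine map: x -> width - x, y unchanged.
affine_matrix = np.array([[-1, 0, width], [0, 1, 0]], dtype="float32")

# Transform all four corners, then take the enclosing axis-aligned box,
# which is what reference_affine_bounding_box_helper does for each box.
corners = np.array([[x1, y1, 1.0], [x1, y2, 1.0], [x2, y1, 1.0], [x2, y2, 1.0]])
flipped = corners @ affine_matrix.T

out = [flipped[:, 0].min(), flipped[:, 1].min(), flipped[:, 0].max(), flipped[:, 1].max()]
assert [float(v) for v in out] == [5.0, 0.0, 10.0, 5.0]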
@@ -493,7 +484,7 @@ def test_kernel_video(self):
 
     @pytest.mark.parametrize("size", OUTPUT_SIZES)
     @pytest.mark.parametrize(
-        "input_type_and_kernel",
+        ("input_type", "kernel"),
         [
             (torch.Tensor, F.resize_image_tensor),
             (PIL.Image.Image, F.resize_image_pil),
@@ -503,8 +494,7 @@ def test_kernel_video(self):
             (datapoints.Video, F.resize_video),
         ],
     )
-    def test_dispatcher(self, size, input_type_and_kernel):
-        input_type, kernel = input_type_and_kernel
+    def test_dispatcher(self, size, input_type, kernel):
         check_dispatcher(
             F.resize,
             kernel,
@@ -726,3 +716,147 @@ def test_no_regression_5405(self, input_type):
         output = F.resize(input, size=size, max_size=max_size, antialias=True)
 
         assert max(F.get_spatial_size(output)) == max_size
+
+
+class TestHorizontalFlip:
+    def _make_input(self, input_type, *, dtype=None, device="cpu", spatial_size=(17, 11), **kwargs):
+        if input_type in {torch.Tensor, PIL.Image.Image, datapoints.Image}:
+            input = make_image(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
+            if input_type is torch.Tensor:
+                input = input.as_subclass(torch.Tensor)
+            elif input_type is PIL.Image.Image:
+                input = F.to_image_pil(input)
+        elif input_type is datapoints.BoundingBox:
+            kwargs.setdefault("format", datapoints.BoundingBoxFormat.XYXY)
+            input = make_bounding_box(
+                dtype=dtype or torch.float32,
+                device=device,
+                spatial_size=spatial_size,
+                **kwargs,
+            )
+        elif input_type is datapoints.Mask:
+            input = make_segmentation_mask(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
+        elif input_type is datapoints.Video:
+            input = make_video(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
+
+        return input
+
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_kernel_image_tensor(self, dtype, device):
+        check_kernel(F.horizontal_flip_image_tensor, self._make_input(torch.Tensor))
+
+    @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_kernel_bounding_box(self, format, dtype, device):
+        bounding_box = self._make_input(datapoints.BoundingBox, dtype=dtype, device=device, format=format)
+        check_kernel(
+            F.horizontal_flip_bounding_box,
+            bounding_box,
+            format=format,
+            spatial_size=bounding_box.spatial_size,
+        )
+
+    @pytest.mark.parametrize(
+        "dtype_and_make_mask", [(torch.uint8, make_segmentation_mask), (torch.bool, make_detection_mask)]
+    )
+    def test_kernel_mask(self, dtype_and_make_mask):
+        dtype, make_mask = dtype_and_make_mask
+        check_kernel(F.horizontal_flip_mask, make_mask(dtype=dtype))
+
+    def test_kernel_video(self):
+        check_kernel(F.horizontal_flip_video, self._make_input(datapoints.Video))
+
+    @pytest.mark.parametrize(
+        ("input_type", "kernel"),
+        [
+            (torch.Tensor, F.horizontal_flip_image_tensor),
+            (PIL.Image.Image, F.horizontal_flip_image_pil),
+            (datapoints.Image, F.horizontal_flip_image_tensor),
+            (datapoints.BoundingBox, F.horizontal_flip_bounding_box),
+            (datapoints.Mask, F.horizontal_flip_mask),
+            (datapoints.Video, F.horizontal_flip_video),
+        ],
+    )
+    def test_dispatcher(self, kernel, input_type):
+        check_dispatcher(F.horizontal_flip, kernel, self._make_input(input_type))
+
+    @pytest.mark.parametrize(
+        ("input_type", "kernel"),
+        [
+            (torch.Tensor, F.resize_image_tensor),
+            (PIL.Image.Image, F.resize_image_pil),
+            (datapoints.Image, F.resize_image_tensor),
+            (datapoints.BoundingBox, F.resize_bounding_box),
+            (datapoints.Mask, F.resize_mask),
+            (datapoints.Video, F.resize_video),
+        ],
+    )
+    def test_dispatcher_signature(self, kernel, input_type):
+        check_dispatcher_signatures_match(F.resize, kernel=kernel, input_type=input_type)
+
+    @pytest.mark.parametrize(
+        "input_type",
+        [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video],
+    )
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_transform(self, input_type, device):
+        input = self._make_input(input_type, device=device)
+
+        check_transform(transforms.RandomHorizontalFlip, input, p=1)
+
+    @pytest.mark.parametrize(
+        "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
+    )
+    def test_image_correctness(self, fn):
+        image = self._make_input(torch.Tensor, dtype=torch.uint8, device="cpu")
+
+        actual = fn(image)
+        expected = F.to_image_tensor(F.horizontal_flip(F.to_image_pil(image)))
+
+        torch.testing.assert_close(actual, expected)
+
+    def _reference_horizontal_flip_bounding_box(self, bounding_box):
+        affine_matrix = np.array(
+            [
+                [-1, 0, bounding_box.spatial_size[1]],
+                [0, 1, 0],
+            ],
+            dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
+        )
+
+        expected_bboxes = reference_affine_bounding_box_helper(
+            bounding_box,
+            format=bounding_box.format,
+            spatial_size=bounding_box.spatial_size,
+            affine_matrix=affine_matrix,
+        )
+
+        return datapoints.BoundingBox.wrap_like(bounding_box, expected_bboxes)
+
+    @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
+    @pytest.mark.parametrize(
+        "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
+    )
+    def test_bounding_box_correctness(self, format, fn):
+        bounding_box = self._make_input(datapoints.BoundingBox)
+
+        actual = fn(bounding_box)
+        expected = self._reference_horizontal_flip_bounding_box(bounding_box)
+
+        torch.testing.assert_close(actual, expected)
+
+    @pytest.mark.parametrize(
+        "input_type",
+        [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video],
+    )
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_transform_noop(self, input_type, device):
+        input = self._make_input(input_type, device=device)
+
+        transform = transforms.RandomHorizontalFlip(p=0)
+
+        output = transform(input)
+
+        assert_equal(output, input)
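The correctness tests above use PIL as the reference for images and the affine helper as the reference for bounding boxes. A standalone sketch of the image check, assuming a torchvision build that provides the v2 functionals used in this file (F.horizontal_flip, F.to_image_pil, F.to_image_tensor):

import torch
from torchvision.transforms.v2 import functional as F

# Random uint8 CHW image with the (17, 11) spatial size the tests default to.
image = torch.randint(0, 256, (3, 17, 11), dtype=torch.uint8)

actual = F.horizontal_flip(image)
expected = F.to_image_tensor(F.horizontal_flip(F.to_image_pil(image)))

torch.testing.assert_close(actual, expected)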

test/transforms_v2_dispatcher_infos.py

Lines changed: 0 additions & 10 deletions
@@ -138,16 +138,6 @@ def fill_sequence_needs_broadcast(args_kwargs):
 
 
 DISPATCHER_INFOS = [
-    DispatcherInfo(
-        F.horizontal_flip,
-        kernels={
-            datapoints.Image: F.horizontal_flip_image_tensor,
-            datapoints.Video: F.horizontal_flip_video,
-            datapoints.BoundingBox: F.horizontal_flip_bounding_box,
-            datapoints.Mask: F.horizontal_flip_mask,
-        },
-        pil_kernel_info=PILKernelInfo(F.horizontal_flip_image_pil, kernel_name="horizontal_flip_image_pil"),
-    ),
     DispatcherInfo(
         F.affine,
         kernels={
