@@ -295,9 +295,9 @@ def check_transform(transform_cls, input, *args, **kwargs):
     _check_transform_v1_compatibility(transform, input)


-def transform_cls_to_functional(transform_cls):
+def transform_cls_to_functional(transform_cls, **transform_specific_kwargs):
     def wrapper(input, *args, **kwargs):
-        transform = transform_cls(*args, **kwargs)
+        transform = transform_cls(*args, **transform_specific_kwargs, **kwargs)
         return transform(input)

     wrapper.__name__ = transform_cls.__name__
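
For context: with the new `**transform_specific_kwargs`, constructor arguments are bound once when the functional wrapper is created, while per-call arguments still flow through `wrapper`. A minimal sketch of the usage this enables (the horizontal-flip tests below bind `p=1` in exactly this way; `image` stands for any input the transform accepts):

```python
# Bind p=1 at wrapping time so the returned functional always flips;
# calling flip_fn(image) is then equivalent to
# transforms.RandomHorizontalFlip(p=1)(image).
flip_fn = transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)
output = flip_fn(image)
```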
@@ -321,14 +321,14 @@ def assert_warns_antialias_default_value():


 def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size, affine_matrix):
-    def transform(bbox, affine_matrix_, format_, spatial_size_):
+    def transform(bbox):
         # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
         in_dtype = bbox.dtype
         if not torch.is_floating_point(bbox):
             bbox = bbox.float()
         bbox_xyxy = F.convert_format_bounding_box(
             bbox.as_subclass(torch.Tensor),
-            old_format=format_,
+            old_format=format,
             new_format=datapoints.BoundingBoxFormat.XYXY,
             inplace=True,
         )
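
This refactor stops threading `_`-suffixed copies of the outer arguments through `transform`; the inner function now closes over `format`, `spatial_size`, and `affine_matrix` directly. A toy, self-contained illustration of the closure pattern being adopted:

```python
# The inner function captures `scale` from the enclosing scope instead of
# receiving a shadowed `scale_` parameter on every call.
def make_scaler(scale):
    def transform(x):
        return x * scale

    return transform

assert make_scaler(2)(21) == 42
```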
@@ -340,7 +340,7 @@ def transform(bbox, affine_matrix_, format_, spatial_size_):
                 [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
             ]
         )
-        transformed_points = np.matmul(points, affine_matrix_.T)
+        transformed_points = np.matmul(points, affine_matrix.T)
         out_bbox = torch.tensor(
             [
                 np.min(transformed_points[:, 0]).item(),
@@ -351,23 +351,14 @@ def transform(bbox, affine_matrix_, format_, spatial_size_):
             dtype=bbox_xyxy.dtype,
         )
         out_bbox = F.convert_format_bounding_box(
-            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
+            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
         )
         # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
-        out_bbox = F.clamp_bounding_box(out_bbox, format=format_, spatial_size=spatial_size_)
+        out_bbox = F.clamp_bounding_box(out_bbox, format=format, spatial_size=spatial_size)
         out_bbox = out_bbox.to(dtype=in_dtype)
         return out_bbox

-    if bounding_box.ndim < 2:
-        bounding_box = [bounding_box]
-
-    expected_bboxes = [transform(bbox, affine_matrix, format, spatial_size) for bbox in bounding_box]
-    if len(expected_bboxes) > 1:
-        expected_bboxes = torch.stack(expected_bboxes)
-    else:
-        expected_bboxes = expected_bboxes[0]
-
-    return expected_bboxes
+    return torch.stack([transform(b) for b in bounding_box.reshape(-1, 4).unbind()]).reshape(bounding_box.shape)


 class TestResize:
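
The new one-liner above replaces the `ndim < 2` branching: any `(..., 4)` bounding box tensor is flattened into single boxes, mapped through `transform`, stacked, and reshaped back, which works uniformly for a single box or a batch. A self-contained sketch of the same pattern with a stand-in per-box function:

```python
import torch

# Stand-in for `transform`: any function mapping a (4,) box to a (4,) box.
def per_box(b):
    return b.flip(0)

boxes = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)  # arbitrary leading dims
out = torch.stack([per_box(b) for b in boxes.reshape(-1, 4).unbind()]).reshape(boxes.shape)
assert out.shape == boxes.shape
```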
@@ -493,7 +484,7 @@ def test_kernel_video(self):

     @pytest.mark.parametrize("size", OUTPUT_SIZES)
     @pytest.mark.parametrize(
-        "input_type_and_kernel",
+        ("input_type", "kernel"),
         [
             (torch.Tensor, F.resize_image_tensor),
             (PIL.Image.Image, F.resize_image_pil),
@@ -503,8 +494,7 @@ def test_kernel_video(self):
             (datapoints.Video, F.resize_video),
         ],
     )
-    def test_dispatcher(self, size, input_type_and_kernel):
-        input_type, kernel = input_type_and_kernel
+    def test_dispatcher(self, size, input_type, kernel):
         check_dispatcher(
             F.resize,
             kernel,
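
Passing a tuple of names to `pytest.mark.parametrize` makes pytest unpack each parameter set into separate test arguments, removing the manual `input_type, kernel = input_type_and_kernel` line. A minimal standalone example of the idiom:

```python
import pytest

# Each 2-tuple is unpacked into the `a` and `b` arguments directly.
@pytest.mark.parametrize(("a", "b"), [(1, 2), (3, 4)])
def test_pairs(a, b):
    assert b == a + 1
```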
@@ -726,3 +716,147 @@ def test_no_regression_5405(self, input_type):
         output = F.resize(input, size=size, max_size=max_size, antialias=True)

         assert max(F.get_spatial_size(output)) == max_size
+
+
+class TestHorizontalFlip:
+    def _make_input(self, input_type, *, dtype=None, device="cpu", spatial_size=(17, 11), **kwargs):
+        if input_type in {torch.Tensor, PIL.Image.Image, datapoints.Image}:
+            input = make_image(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
+            if input_type is torch.Tensor:
+                input = input.as_subclass(torch.Tensor)
+            elif input_type is PIL.Image.Image:
+                input = F.to_image_pil(input)
+        elif input_type is datapoints.BoundingBox:
+            kwargs.setdefault("format", datapoints.BoundingBoxFormat.XYXY)
+            input = make_bounding_box(
+                dtype=dtype or torch.float32,
+                device=device,
+                spatial_size=spatial_size,
+                **kwargs,
+            )
+        elif input_type is datapoints.Mask:
+            input = make_segmentation_mask(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
+        elif input_type is datapoints.Video:
+            input = make_video(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
+
+        return input
+
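
`_make_input` gives every test in this class a single entry point for building an input of any supported type. Hypothetical calls from inside a test method, assuming the helper as defined above; dtype and device default per input type, and extra kwargs such as `format` are forwarded:

```python
# Illustrative only: values and kwargs are examples, not from the PR.
image = self._make_input(datapoints.Image, dtype=torch.float32, device="cuda")
boxes = self._make_input(datapoints.BoundingBox, format=datapoints.BoundingBoxFormat.CXCYWH)
```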
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_kernel_image_tensor(self, dtype, device):
+        check_kernel(F.horizontal_flip_image_tensor, self._make_input(torch.Tensor, dtype=dtype, device=device))
+
+    @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_kernel_bounding_box(self, format, dtype, device):
+        bounding_box = self._make_input(datapoints.BoundingBox, dtype=dtype, device=device, format=format)
+        check_kernel(
+            F.horizontal_flip_bounding_box,
+            bounding_box,
+            format=format,
+            spatial_size=bounding_box.spatial_size,
+        )
+
+    @pytest.mark.parametrize(
+        "dtype_and_make_mask", [(torch.uint8, make_segmentation_mask), (torch.bool, make_detection_mask)]
+    )
+    def test_kernel_mask(self, dtype_and_make_mask):
+        dtype, make_mask = dtype_and_make_mask
+        check_kernel(F.horizontal_flip_mask, make_mask(dtype=dtype))
+
+    def test_kernel_video(self):
+        check_kernel(F.horizontal_flip_video, self._make_input(datapoints.Video))
+
+    @pytest.mark.parametrize(
+        ("input_type", "kernel"),
+        [
+            (torch.Tensor, F.horizontal_flip_image_tensor),
+            (PIL.Image.Image, F.horizontal_flip_image_pil),
+            (datapoints.Image, F.horizontal_flip_image_tensor),
+            (datapoints.BoundingBox, F.horizontal_flip_bounding_box),
+            (datapoints.Mask, F.horizontal_flip_mask),
+            (datapoints.Video, F.horizontal_flip_video),
+        ],
+    )
+    def test_dispatcher(self, kernel, input_type):
+        check_dispatcher(F.horizontal_flip, kernel, self._make_input(input_type))
+
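
Informally, what the dispatcher test verifies: `F.horizontal_flip` inspects the input type and forwards to the matching kernel from the table above. A hedged sketch of that equivalence, reusing this file's existing imports (`torch`, `datapoints`, `F`) and constructing the image directly rather than via `make_image`:

```python
# Sketch only: the dispatcher and the type-specific kernel should agree.
image = datapoints.Image(torch.rand(3, 17, 11))
torch.testing.assert_close(
    F.horizontal_flip(image).as_subclass(torch.Tensor),
    F.horizontal_flip_image_tensor(image.as_subclass(torch.Tensor)),
)
```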
+    @pytest.mark.parametrize(
+        ("input_type", "kernel"),
+        [
+            (torch.Tensor, F.horizontal_flip_image_tensor),
+            (PIL.Image.Image, F.horizontal_flip_image_pil),
+            (datapoints.Image, F.horizontal_flip_image_tensor),
+            (datapoints.BoundingBox, F.horizontal_flip_bounding_box),
+            (datapoints.Mask, F.horizontal_flip_mask),
+            (datapoints.Video, F.horizontal_flip_video),
+        ],
+    )
+    def test_dispatcher_signature(self, kernel, input_type):
+        check_dispatcher_signatures_match(F.horizontal_flip, kernel=kernel, input_type=input_type)
+
+    @pytest.mark.parametrize(
+        "input_type",
+        [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video],
+    )
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_transform(self, input_type, device):
+        input = self._make_input(input_type, device=device)
+
+        check_transform(transforms.RandomHorizontalFlip, input, p=1)
+
+    @pytest.mark.parametrize(
+        "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
+    )
+    def test_image_correctness(self, fn):
+        image = self._make_input(torch.Tensor, dtype=torch.uint8, device="cpu")
+
+        actual = fn(image)
+        expected = F.to_image_tensor(F.horizontal_flip(F.to_image_pil(image)))
+
+        torch.testing.assert_close(actual, expected)
+
+    def _reference_horizontal_flip_bounding_box(self, bounding_box):
+        affine_matrix = np.array(
+            [
+                [-1, 0, bounding_box.spatial_size[1]],
+                [0, 1, 0],
+            ],
+            dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
+        )
+
+        expected_bboxes = reference_affine_bounding_box_helper(
+            bounding_box,
+            format=bounding_box.format,
+            spatial_size=bounding_box.spatial_size,
+            affine_matrix=affine_matrix,
+        )
+
+        return datapoints.BoundingBox.wrap_like(bounding_box, expected_bboxes)
+
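
The matrix above encodes the horizontal flip x ↦ W − x, y ↦ y, where `spatial_size` is `(H, W)`. A small self-contained check of that mapping on one homogeneous point (values are illustrative):

```python
import numpy as np

W = 11  # width, i.e. spatial_size[1]
affine_matrix = np.array([[-1, 0, W], [0, 1, 0]], dtype="float32")

x, y = affine_matrix @ np.array([3.0, 5.0, 1.0])  # homogeneous (x, y, 1)
assert (x, y) == (W - 3.0, 5.0)
```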
+    @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
+    @pytest.mark.parametrize(
+        "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
+    )
+    def test_bounding_box_correctness(self, format, fn):
+        bounding_box = self._make_input(datapoints.BoundingBox, format=format)
+
+        actual = fn(bounding_box)
+        expected = self._reference_horizontal_flip_bounding_box(bounding_box)
+
+        torch.testing.assert_close(actual, expected)
+
+    @pytest.mark.parametrize(
+        "input_type",
+        [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video],
+    )
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_transform_noop(self, input_type, device):
+        input = self._make_input(input_type, device=device)
+
+        transform = transforms.RandomHorizontalFlip(p=0)
+
+        output = transform(input)
+
+        assert_equal(output, input)