@@ -2220,3 +2220,61 @@ def resize_my_datapoint():
22202220 _register_kernel_internal (F .resize , MyDatapoint , datapoint_wrapper = False )(resize_my_datapoint )
22212221
22222222 assert _get_kernel (F .resize , MyDatapoint ) is resize_my_datapoint
2223+
2224+
2225+ class TestPermuteChannels :
2226+ _DEFAULT_PERMUTATION = [2 , 0 , 1 ]
2227+
2228+ @pytest .mark .parametrize (
2229+ ("kernel" , "make_input" ),
2230+ [
2231+ (F .permute_channels_image_tensor , make_image_tensor ),
2232+ # FIXME
2233+ # check_kernel does not support PIL kernel, but it should
2234+ (F .permute_channels_image_tensor , make_image ),
2235+ (F .permute_channels_video , make_video ),
2236+ ],
2237+ )
2238+ @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .uint8 ])
2239+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
2240+ def test_kernel (self , kernel , make_input , dtype , device ):
2241+ check_kernel (kernel , make_input (dtype = dtype , device = device ), permutation = self ._DEFAULT_PERMUTATION )
2242+
2243+ @pytest .mark .parametrize (
2244+ ("kernel" , "make_input" ),
2245+ [
2246+ (F .permute_channels_image_tensor , make_image_tensor ),
2247+ (F .permute_channels_image_pil , make_image_pil ),
2248+ (F .permute_channels_image_tensor , make_image ),
2249+ (F .permute_channels_video , make_video ),
2250+ ],
2251+ )
2252+ def test_dispatcher (self , kernel , make_input ):
2253+ check_dispatcher (F .permute_channels , kernel , make_input (), permutation = self ._DEFAULT_PERMUTATION )
2254+
2255+ @pytest .mark .parametrize (
2256+ ("kernel" , "input_type" ),
2257+ [
2258+ (F .permute_channels_image_tensor , torch .Tensor ),
2259+ (F .permute_channels_image_pil , PIL .Image .Image ),
2260+ (F .permute_channels_image_tensor , datapoints .Image ),
2261+ (F .permute_channels_video , datapoints .Video ),
2262+ ],
2263+ )
2264+ def test_dispatcher_signature (self , kernel , input_type ):
2265+ check_dispatcher_kernel_signature_match (F .permute_channels , kernel = kernel , input_type = input_type )
2266+
2267+ def reference_image_correctness (self , image , permutation ):
2268+ channel_images = image .split (1 , dim = - 3 )
2269+ permuted_channel_images = [channel_images [channel_idx ] for channel_idx in permutation ]
2270+ return datapoints .Image (torch .concat (permuted_channel_images , dim = - 3 ))
2271+
2272+ @pytest .mark .parametrize ("permutation" , [[2 , 0 , 1 ], [1 , 2 , 0 ], [2 , 0 , 1 ], [0 , 1 , 2 ]])
2273+ @pytest .mark .parametrize ("batch_dims" , [(), (2 ,), (2 , 1 )])
2274+ def test_image_correctness (self , permutation , batch_dims ):
2275+ image = make_image (batch_dims = batch_dims )
2276+
2277+ actual = F .permute_channels (image , permutation = permutation )
2278+ expected = self .reference_image_correctness (image , permutation = permutation )
2279+
2280+ torch .testing .assert_close (actual , expected )
0 commit comments