@@ -10,29 +10,36 @@ class MaxvitTester(unittest.TestCase):
1010 def test_maxvit_window_partition (self ):
1111 input_shape = (1 , 3 , 224 , 224 )
1212 partition_size = 7
13+ n_partitions = input_shape [3 ] // partition_size
1314
1415 x = torch .randn (input_shape )
1516
16- partition = WindowPartition (partition_size = 7 )
17- departition = WindowDepartition (partition_size = partition_size , n_partitions = ( input_shape [ 3 ] // partition_size ) )
17+ partition = WindowPartition ()
18+ departition = WindowDepartition ()
1819
19- assert torch .allclose (x , departition (partition (x )))
20+ x_hat = partition (x , partition_size )
21+ x_hat = departition (x_hat , partition_size , n_partitions , n_partitions )
22+
23+ assert torch .allclose (x , x_hat )
2024
2125 def test_maxvit_grid_partition (self ):
2226 input_shape = (1 , 3 , 224 , 224 )
2327 partition_size = 7
28+ n_partitions = input_shape [3 ] // partition_size
2429
2530 x = torch .randn (input_shape )
26- partition = torch .nn .Sequential (
27- WindowPartition (partition_size = (input_shape [3 ] // partition_size )),
28- SwapAxes (- 2 , - 3 ),
29- )
30- departition = torch .nn .Sequential (
31- SwapAxes (- 2 , - 3 ),
32- WindowDepartition (partition_size = (input_shape [3 ] // partition_size ), n_partitions = partition_size ),
33- )
34-
35- assert torch .allclose (x , departition (partition (x )))
31+ pre_swap = SwapAxes (- 2 , - 3 )
32+ post_swap = SwapAxes (- 2 , - 3 )
33+
34+ partition = WindowPartition ()
35+ departition = WindowDepartition ()
36+
37+ x_hat = partition (x , n_partitions )
38+ x_hat = pre_swap (x_hat )
39+ x_hat = post_swap (x_hat )
40+ x_hat = departition (x_hat , n_partitions , partition_size , partition_size )
41+
42+ assert torch .allclose (x , x_hat )
3643
3744
3845if __name__ == "__main__" :
0 commit comments