Skip to content

Commit f561edf

Browse files
committed
updated old tests
1 parent 872f40f commit f561edf

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

test/test_architecture_ops.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,36 @@ class MaxvitTester(unittest.TestCase):
1010
def test_maxvit_window_partition(self):
1111
input_shape = (1, 3, 224, 224)
1212
partition_size = 7
13+
n_partitions = input_shape[3] // partition_size
1314

1415
x = torch.randn(input_shape)
1516

16-
partition = WindowPartition(partition_size=7)
17-
departition = WindowDepartition(partition_size=partition_size, n_partitions=(input_shape[3] // partition_size))
17+
partition = WindowPartition()
18+
departition = WindowDepartition()
1819

19-
assert torch.allclose(x, departition(partition(x)))
20+
x_hat = partition(x, partition_size)
21+
x_hat = departition(x_hat, partition_size, n_partitions, n_partitions)
22+
23+
assert torch.allclose(x, x_hat)
2024

2125
def test_maxvit_grid_partition(self):
2226
input_shape = (1, 3, 224, 224)
2327
partition_size = 7
28+
n_partitions = input_shape[3] // partition_size
2429

2530
x = torch.randn(input_shape)
26-
partition = torch.nn.Sequential(
27-
WindowPartition(partition_size=(input_shape[3] // partition_size)),
28-
SwapAxes(-2, -3),
29-
)
30-
departition = torch.nn.Sequential(
31-
SwapAxes(-2, -3),
32-
WindowDepartition(partition_size=(input_shape[3] // partition_size), n_partitions=partition_size),
33-
)
34-
35-
assert torch.allclose(x, departition(partition(x)))
31+
pre_swap = SwapAxes(-2, -3)
32+
post_swap = SwapAxes(-2, -3)
33+
34+
partition = WindowPartition()
35+
departition = WindowDepartition()
36+
37+
x_hat = partition(x, n_partitions)
38+
x_hat = pre_swap(x_hat)
39+
x_hat = post_swap(x_hat)
40+
x_hat = departition(x_hat, n_partitions, partition_size, partition_size)
41+
42+
assert torch.allclose(x, x_hat)
3643

3744

3845
if __name__ == "__main__":

0 commit comments

Comments
 (0)