Skip to content

Commit f15fd92

Browse files
committed
Added maxvit architecture and tests
1 parent b30fa5c commit f15fd92

8 files changed

+643
-0
lines changed
939 Bytes
Binary file not shown.
939 Bytes
Binary file not shown.
939 Bytes
Binary file not shown.
939 Bytes
Binary file not shown.

test/test_architecture_ops.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import unittest
2+
3+
import pytest
4+
import torch
5+
6+
from torchvision.models.maxvit import SwapAxes, WindowDepartition, WindowPartition
7+
8+
9+
class MaxvitTester(unittest.TestCase):
10+
def test_maxvit_window_partition(self):
11+
input_shape = (1, 3, 224, 224)
12+
partition_size = 7
13+
14+
x = torch.randn(input_shape)
15+
16+
partition = WindowPartition(partition_size=7)
17+
departition = WindowDepartition(partition_size=partition_size, n_partitions=(input_shape[3] // partition_size))
18+
19+
assert torch.allclose(x, departition(partition(x)))
20+
21+
def test_maxvit_grid_partition(self):
22+
input_shape = (1, 3, 224, 224)
23+
partition_size = 7
24+
25+
x = torch.randn(input_shape)
26+
partition = torch.nn.Sequential(
27+
WindowPartition(partition_size=(input_shape[3] // partition_size)),
28+
SwapAxes(-2, -3),
29+
)
30+
departition = torch.nn.Sequential(
31+
SwapAxes(-2, -3),
32+
WindowDepartition(partition_size=(input_shape[3] // partition_size), n_partitions=partition_size),
33+
)
34+
35+
assert torch.allclose(x, departition(partition(x)))
36+
37+
38+
if __name__ == "__main__":
39+
pytest.main([__file__])

test/test_models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,14 @@ def test_vitc_models(model_fn, dev):
594594
test_classification_model(model_fn, dev)
595595

596596

597+
@pytest.mark.parametrize(
598+
"model_fn", [models.max_vit_T_224, models.max_vit_S_224, models.max_vit_B_224, models.max_vit_L_224]
599+
)
600+
@pytest.mark.parametrize("dev", cpu_and_gpu())
601+
def test_max_vit(model_fn, dev):
602+
test_classification_model(model_fn, dev)
603+
604+
597605
@pytest.mark.parametrize("model_fn", list_model_fns(models))
598606
@pytest.mark.parametrize("dev", cpu_and_gpu())
599607
def test_classification_model(model_fn, dev):

torchvision/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@
1313
from .vgg import *
1414
from .vision_transformer import *
1515
from .swin_transformer import *
16+
from .maxvit import *
1617
from . import detection, optical_flow, quantization, segmentation, video
1718
from ._api import get_model, get_model_weights, get_weight, list_models

0 commit comments

Comments
 (0)