
Commit 6b1646c

MaxVit model (#6342)

* Added maxvit architecture and tests
* rebased + addressed comments
* Revert "rebased + addressed comments" (this reverts commit c5b2839)
* Re-added model changes after revert
* aligned with partial original implementation
* removed submitit script, fixed lint
* mypy fix for too many arguments
* updated old tests
* removed per-batch lr scheduler and seed setting
* removed ontap
* added docs, validated weights
* fixed test expect, moved shape assertions to the beginning for torch.fx compatibility
* mypy fix
* lint fix
* added legacy interface
* added weight link
* updated docs
* Update references/classification/train.py (Co-authored-by: Vasilis Vryniotis <[email protected]>)
* Update torchvision/models/maxvit.py (Co-authored-by: Vasilis Vryniotis <[email protected]>)
* addressed comments
* updated ra_magnitude and augmix_severity default values
* addressed some comments
* removed input_channels parameter

Co-authored-by: Vasilis Vryniotis <[email protected]>

Parent: d65e286

9 files changed: 940 additions, 6 deletions

docs/source/models.rst (1 addition, 0 deletions)

@@ -207,6 +207,7 @@ weights:
    models/efficientnetv2
    models/googlenet
    models/inception
+   models/maxvit
    models/mnasnet
    models/mobilenetv2
    models/mobilenetv3

docs/source/models/maxvit.rst (new file, 23 additions)

@@ -0,0 +1,23 @@
+MaxVit
+===============
+
+.. currentmodule:: torchvision.models
+
+The MaxVit transformer models are based on the `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`__
+paper.
+
+
+Model builders
+--------------
+
+The following model builders can be used to instantiate a MaxVit model, with or without pre-trained weights.
+All the model builders internally rely on the ``torchvision.models.maxvit.MaxVit``
+base class. Please refer to the `source code
+<https://github.com/pytorch/vision/blob/main/torchvision/models/maxvit.py>`_ for
+more details about this class.
+
+.. autosummary::
+    :toctree: generated/
+    :template: function.rst
+
+    maxvit_t
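As a quick orientation, here is a minimal usage sketch of the builder documented above. The diff only names `maxvit_t`; the `MaxVit_T_Weights` enum and its `IMAGENET1K_V1` entry are assumed from the commit notes about validated weights and may differ from the final API.

```python
# Hedged sketch: instantiating the new MaxViT-T builder.
# `MaxVit_T_Weights.IMAGENET1K_V1` is an assumed weight enum, not taken from this diff.
import torch
from torchvision.models import maxvit_t, MaxVit_T_Weights

model = maxvit_t(weights=MaxVit_T_Weights.IMAGENET1K_V1)  # or weights=None for random init
model.eval()

with torch.inference_mode():
    logits = model(torch.randn(1, 3, 224, 224))  # the training recipe below uses 224x224 crops
print(logits.shape)  # expected: torch.Size([1, 1000])
```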

references/classification/README.md (8 additions, 0 deletions)

@@ -245,6 +245,14 @@ Here `$MODEL` is one of `swin_v2_t`, `swin_v2_s` or `swin_v2_b`.
 Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value.
 
 
+### MaxViT
+```
+torchrun --nproc_per_node=8 --n_nodes=4 train.py\
+--model $MODEL --epochs 400 --batch-size 128 --opt adamw --lr 3e-3 --weight-decay 0.05 --lr-scheduler cosineannealinglr --lr-min 1e-5 --lr-warmup-method linear --lr-warmup-epochs 32 --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 1.0 --interpolation bicubic --auto-augment ta_wide --policy-magnitude 15 --train-center-crop --model-ema --val-resize-size 224\
+--val-crop-size 224 --train-crop-size 224 --amp --model-ema-steps 32 --transformer-embedding-decay 0 --sync-bn
+```
+Here `$MODEL` is `maxvit_t`.
+Note that `--val-resize-size` was not optimized in a post-training step.
 
 
 ### ShuffleNet V2

references/classification/presets.py (10 additions, 3 deletions)

@@ -13,18 +13,25 @@ def __init__(
         interpolation=InterpolationMode.BILINEAR,
         hflip_prob=0.5,
         auto_augment_policy=None,
+        ra_magnitude=9,
+        augmix_severity=3,
         random_erase_prob=0.0,
+        center_crop=False,
     ):
-        trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
+        trans = (
+            [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
+            if center_crop
+            else [transforms.CenterCrop(crop_size)]
+        )
         if hflip_prob > 0:
             trans.append(transforms.RandomHorizontalFlip(hflip_prob))
         if auto_augment_policy is not None:
             if auto_augment_policy == "ra":
-                trans.append(autoaugment.RandAugment(interpolation=interpolation))
+                trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
             elif auto_augment_policy == "ta_wide":
                 trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
             elif auto_augment_policy == "augmix":
-                trans.append(autoaugment.AugMix(interpolation=interpolation))
+                trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
             else:
                 aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
                 trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
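For context, a hedged sketch of how the updated preset might be constructed. Only the parameter names visible in this diff (and in the `load_data` call further down) are taken from the commit; the values and the import path are illustrative.

```python
# Illustrative only: exercising the preset parameters added above.
# Parameter names come from the diff; the values are made up.
from torchvision.transforms import InterpolationMode

from presets import ClassificationPresetTrain  # references/classification/presets.py

train_transform = ClassificationPresetTrain(
    crop_size=224,
    interpolation=InterpolationMode.BICUBIC,
    auto_augment_policy="ra",   # routes to RandAugment(..., magnitude=ra_magnitude)
    ra_magnitude=15,
    augmix_severity=3,          # only consulted when auto_augment_policy == "augmix"
    random_erase_prob=0.0,
    center_crop=True,           # wired to the new --train-center-crop flag in train.py
)
```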

references/classification/train.py (22 additions, 3 deletions)

@@ -113,7 +113,12 @@ def _get_cache_path(filepath):
 def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
-    val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
+    val_resize_size, val_crop_size, train_crop_size, center_crop = (
+        args.val_resize_size,
+        args.val_crop_size,
+        args.train_crop_size,
+        args.train_center_crop,
+    )
     interpolation = InterpolationMode(args.interpolation)
 
     print("Loading training data")
@@ -126,13 +131,18 @@ def load_data(traindir, valdir, args):
     else:
         auto_augment_policy = getattr(args, "auto_augment", None)
         random_erase_prob = getattr(args, "random_erase", 0.0)
+        ra_magnitude = args.ra_magnitude
+        augmix_severity = args.augmix_severity
         dataset = torchvision.datasets.ImageFolder(
             traindir,
             presets.ClassificationPresetTrain(
+                center_crop=center_crop,
                 crop_size=train_crop_size,
                 interpolation=interpolation,
                 auto_augment_policy=auto_augment_policy,
                 random_erase_prob=random_erase_prob,
+                ra_magnitude=ra_magnitude,
+                augmix_severity=augmix_severity,
             ),
         )
         if args.cache_dataset:
@@ -207,7 +217,10 @@ def main(args):
         mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
     if mixup_transforms:
         mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
-        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
+
+        def collate_fn(batch):
+            return mixupcutmix(*default_collate(batch))
+
     data_loader = torch.utils.data.DataLoader(
         dataset,
         batch_size=args.batch_size,
@@ -448,6 +461,8 @@ def get_args_parser(add_help=True):
         action="store_true",
     )
     parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
+    parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
+    parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
     parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
 
     # Mixed precision training parameters
@@ -486,13 +501,17 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
+    parser.add_argument(
+        "--train-center-crop",
+        action="store_true",
+        help="use center crop instead of random crop for training (default: False)",
+    )
     parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
     parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
     parser.add_argument(
         "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
     )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
-
     return parser

Binary file (1.05 KB) not shown.
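Tying the train.py changes above together: a hedged, illustrative check that the new flags parse with the defaults shown in the diff. It assumes train.py and its sibling modules are importable from the references/classification directory; the flag values are made up.

```python
# Illustrative sketch, run from references/classification/ (assumption, not part of the diff).
from train import get_args_parser

args = get_args_parser().parse_args(
    ["--model", "maxvit_t", "--ra-magnitude", "15", "--train-center-crop"]
)
# Defaults below come from the add_argument calls added in this commit.
assert args.ra_magnitude == 15
assert args.augmix_severity == 3       # default from the diff
assert args.train_center_crop is True  # store_true flag
```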

test/test_architecture_ops.py (new file, 46 additions)

@@ -0,0 +1,46 @@
+import unittest
+
+import pytest
+import torch
+
+from torchvision.models.maxvit import SwapAxes, WindowDepartition, WindowPartition
+
+
+class MaxvitTester(unittest.TestCase):
+    def test_maxvit_window_partition(self):
+        input_shape = (1, 3, 224, 224)
+        partition_size = 7
+        n_partitions = input_shape[3] // partition_size
+
+        x = torch.randn(input_shape)
+
+        partition = WindowPartition()
+        departition = WindowDepartition()
+
+        x_hat = partition(x, partition_size)
+        x_hat = departition(x_hat, partition_size, n_partitions, n_partitions)
+
+        assert torch.allclose(x, x_hat)
+
+    def test_maxvit_grid_partition(self):
+        input_shape = (1, 3, 224, 224)
+        partition_size = 7
+        n_partitions = input_shape[3] // partition_size
+
+        x = torch.randn(input_shape)
+        pre_swap = SwapAxes(-2, -3)
+        post_swap = SwapAxes(-2, -3)
+
+        partition = WindowPartition()
+        departition = WindowDepartition()
+
+        x_hat = partition(x, n_partitions)
+        x_hat = pre_swap(x_hat)
+        x_hat = post_swap(x_hat)
+        x_hat = departition(x_hat, n_partitions, partition_size, partition_size)
+
+        assert torch.allclose(x, x_hat)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
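The round trip these tests rely on can be illustrated with plain tensor ops. The sketch below is not torchvision's WindowPartition/WindowDepartition code (the real modules live in torchvision/models/maxvit.py and may use a different intermediate layout); it only demonstrates the invariant being asserted: splitting a feature map into non-overlapping windows and merging them back recovers the input exactly.

```python
# Plain-torch illustration of the window partition/departition round trip.
# Not the torchvision implementation; the layout choices here are assumptions.
import torch


def window_partition(x: torch.Tensor, p: int) -> torch.Tensor:
    # (B, C, H, W) -> (B, (H//p)*(W//p), p*p, C): one row per window, one entry per pixel.
    b, c, h, w = x.shape
    x = x.reshape(b, c, h // p, p, w // p, p)
    x = x.permute(0, 2, 4, 3, 5, 1)
    return x.reshape(b, (h // p) * (w // p), p * p, c)


def window_departition(x: torch.Tensor, p: int, h_parts: int, w_parts: int) -> torch.Tensor:
    # Inverse mapping: (B, h_parts*w_parts, p*p, C) -> (B, C, h_parts*p, w_parts*p).
    b, _, _, c = x.shape
    x = x.reshape(b, h_parts, w_parts, p, p, c)
    x = x.permute(0, 5, 1, 3, 2, 4)
    return x.reshape(b, c, h_parts * p, w_parts * p)


x = torch.randn(1, 3, 224, 224)
y = window_departition(window_partition(x, 7), 7, 224 // 7, 224 // 7)
assert torch.equal(x, y)  # pure reshapes/permutes, so the round trip is exact
```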

torchvision/models/__init__.py (1 addition, 0 deletions)

@@ -13,5 +13,6 @@
 from .vgg import *
 from .vision_transformer import *
 from .swin_transformer import *
+from .maxvit import *
 from . import detection, optical_flow, quantization, segmentation, video
 from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models
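With the star import in place, the builder defined in maxvit.py should also be reachable through the registration helpers imported on the last context line above. A hedged sketch; the exact list contents depend on the final set of registered builders:

```python
# Assumes the registration helpers shown in the diff; output is indicative only.
import torchvision.models as models

print([name for name in models.list_models() if "maxvit" in name])  # e.g. ['maxvit_t']
model = models.get_model("maxvit_t", weights=None)  # equivalent to calling maxvit_t(weights=None)
```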
