diff --git a/docs/source/models.rst b/docs/source/models.rst index 57eda6d38a5..10618434f9b 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -207,6 +207,7 @@ weights: models/efficientnetv2 models/googlenet models/inception + models/maxvit models/mnasnet models/mobilenetv2 models/mobilenetv3 diff --git a/docs/source/models/maxvit.rst b/docs/source/models/maxvit.rst new file mode 100644 index 00000000000..29aaaaab334 --- /dev/null +++ b/docs/source/models/maxvit.rst @@ -0,0 +1,23 @@ +MaxVit +=============== + +.. currentmodule:: torchvision.models + +The MaxVit transformer models are based on the `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`__ +paper. + + +Model builders +-------------- + +The following model builders can be used to instantiate a MaxVit model with and without pre-trained weights. +All the model builders internally rely on the ``torchvision.models.maxvit.MaxVit`` +base class. Please refer to the `source code +<https://github.com/pytorch/vision/blob/main/torchvision/models/maxvit.py>`_ for +more details about this class. + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + maxvit_t diff --git a/references/classification/README.md b/references/classification/README.md index e8d62134ca2..04db3837016 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -245,6 +245,14 @@ Here `$MODEL` is one of `swin_v2_t`, `swin_v2_s` or `swin_v2_b`. Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value. +### MaxViT +``` +torchrun --nproc_per_node=8 --nnodes=4 train.py\ +--model $MODEL --epochs 400 --batch-size 128 --opt adamw --lr 3e-3 --weight-decay 0.05 --lr-scheduler cosineannealinglr --lr-min 1e-5 --lr-warmup-method linear --lr-warmup-epochs 32 --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 1.0 --interpolation bicubic --auto-augment ta_wide --ra-magnitude 15 --train-center-crop --model-ema --val-resize-size 224\ +--val-crop-size 224 --train-crop-size 224 --amp --model-ema-steps 32 --transformer-embedding-decay 0 --sync-bn +``` +Here `$MODEL` is `maxvit_t`. +Note that `--val-resize-size` was not optimized in a post-training step.
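For a quick sanity check of the resulting model (a minimal sketch, not part of the diff itself; it only uses the `maxvit_t` builder and `MaxVit_T_Weights` enum added below):

```python
import torch
from torchvision.models import MaxVit_T_Weights, maxvit_t

# Build MaxViT-T with the ImageNet-1K weights and the matching inference preprocessing.
weights = MaxVit_T_Weights.IMAGENET1K_V1
model = maxvit_t(weights=weights).eval()
preprocess = weights.transforms()  # resize/center-crop to 224, bicubic, ImageNet normalization

# MaxViT-T is built for 224x224 inputs; the preset resizes arbitrary images accordingly.
batch = preprocess(torch.rand(3, 256, 256)).unsqueeze(0)
with torch.no_grad():
    logits = model(batch)
print(weights.meta["categories"][logits.argmax(dim=1).item()])
```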
### ShuffleNet V2 diff --git a/references/classification/presets.py b/references/classification/presets.py index 6bc38e72953..c6028a3417b 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -13,18 +13,25 @@ def __init__( interpolation=InterpolationMode.BILINEAR, hflip_prob=0.5, auto_augment_policy=None, + ra_magnitude=9, + augmix_severity=3, random_erase_prob=0.0, + center_crop=False, ): - trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] + trans = ( + [transforms.CenterCrop(crop_size)] + if center_crop + else [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] + ) if hflip_prob > 0: trans.append(transforms.RandomHorizontalFlip(hflip_prob)) if auto_augment_policy is not None: if auto_augment_policy == "ra": - trans.append(autoaugment.RandAugment(interpolation=interpolation)) + trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude)) elif auto_augment_policy == "ta_wide": trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation)) elif auto_augment_policy == "augmix": - trans.append(autoaugment.AugMix(interpolation=interpolation)) + trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity)) else: aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation)) diff --git a/references/classification/train.py b/references/classification/train.py index 14360b042ed..f359739b113 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -113,7 +113,12 @@ def _get_cache_path(filepath): def load_data(traindir, valdir, args): # Data loading code print("Loading data") - val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size + val_resize_size, val_crop_size, train_crop_size, center_crop = ( + args.val_resize_size, + args.val_crop_size, + args.train_crop_size, + args.train_center_crop, + ) interpolation = InterpolationMode(args.interpolation) print("Loading training data") @@ -126,13 +131,18 @@ def load_data(traindir, valdir, args): else: auto_augment_policy = getattr(args, "auto_augment", None) random_erase_prob = getattr(args, "random_erase", 0.0) + ra_magnitude = args.ra_magnitude + augmix_severity = args.augmix_severity dataset = torchvision.datasets.ImageFolder( traindir, presets.ClassificationPresetTrain( + center_crop=center_crop, crop_size=train_crop_size, interpolation=interpolation, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob, + ra_magnitude=ra_magnitude, + augmix_severity=augmix_severity, ), ) if args.cache_dataset: @@ -207,7 +217,10 @@ def main(args): mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha)) if mixup_transforms: mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms) - collate_fn = lambda batch: mixupcutmix(*default_collate(batch)) # noqa: E731 + + def collate_fn(batch): + return mixupcutmix(*default_collate(batch)) + data_loader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, @@ -448,6 +461,8 @@ def get_args_parser(add_help=True): action="store_true", ) parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)") + parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy") + parser.add_argument("--augmix-severity", default=3, type=int, help="severity
of augmix policy") parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)") # Mixed precision training parameters @@ -486,13 +501,17 @@ def get_args_parser(add_help=True): parser.add_argument( "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)" ) + parser.add_argument( + "--train-center-crop", + action="store_true", + help="use center crop instead of random crop for training (default: False)", + ) parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)") parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training") parser.add_argument( "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)" ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") - return parser diff --git a/test/expect/ModelTester.test_maxvit_t_expect.pkl b/test/expect/ModelTester.test_maxvit_t_expect.pkl new file mode 100644 index 00000000000..3a93545f614 Binary files /dev/null and b/test/expect/ModelTester.test_maxvit_t_expect.pkl differ diff --git a/test/test_architecture_ops.py b/test/test_architecture_ops.py new file mode 100644 index 00000000000..9f254c7942b --- /dev/null +++ b/test/test_architecture_ops.py @@ -0,0 +1,46 @@ +import unittest + +import pytest +import torch + +from torchvision.models.maxvit import SwapAxes, WindowDepartition, WindowPartition + + +class MaxvitTester(unittest.TestCase): + def test_maxvit_window_partition(self): + input_shape = (1, 3, 224, 224) + partition_size = 7 + n_partitions = input_shape[3] // partition_size + + x = torch.randn(input_shape) + + partition = WindowPartition() + departition = WindowDepartition() + + x_hat = partition(x, partition_size) + x_hat = departition(x_hat, partition_size, n_partitions, n_partitions) + + assert torch.allclose(x, x_hat) + + def test_maxvit_grid_partition(self): + input_shape = (1, 3, 224, 224) + partition_size = 7 + n_partitions = input_shape[3] // partition_size + + x = torch.randn(input_shape) + pre_swap = SwapAxes(-2, -3) + post_swap = SwapAxes(-2, -3) + + partition = WindowPartition() + departition = WindowDepartition() + + x_hat = partition(x, n_partitions) + x_hat = pre_swap(x_hat) + x_hat = post_swap(x_hat) + x_hat = departition(x_hat, n_partitions, partition_size, partition_size) + + assert torch.allclose(x, x_hat) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index f8baa05c1f6..93d96112ba1 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -13,5 +13,6 @@ from .vgg import * from .vision_transformer import * from .swin_transformer import * +from .maxvit import * from . 
import detection, optical_flow, quantization, segmentation, video from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models diff --git a/torchvision/models/maxvit.py b/torchvision/models/maxvit.py new file mode 100644 index 00000000000..7bf92876385 --- /dev/null +++ b/torchvision/models/maxvit.py @@ -0,0 +1,829 @@ +import math +from functools import partial +from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn, Tensor +from torchvision.models._api import register_model, Weights, WeightsEnum +from torchvision.models._meta import _IMAGENET_CATEGORIES +from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface +from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation +from torchvision.ops.stochastic_depth import StochasticDepth +from torchvision.transforms._presets import ImageClassification, InterpolationMode +from torchvision.utils import _log_api_usage_once + +__all__ = [ + "MaxVit", + "MaxVit_T_Weights", + "maxvit_t", +] + + +def _get_conv_output_shape(input_size: Tuple[int, int], kernel_size: int, stride: int, padding: int) -> Tuple[int, int]: + return ( + (input_size[0] - kernel_size + 2 * padding) // stride + 1, + (input_size[1] - kernel_size + 2 * padding) // stride + 1, + ) + + +def _make_block_input_shapes(input_size: Tuple[int, int], n_blocks: int) -> List[Tuple[int, int]]: + """Util function to check that the input size is correct for a MaxVit configuration.""" + shapes = [] + block_input_shape = _get_conv_output_shape(input_size, 3, 2, 1) + for _ in range(n_blocks): + block_input_shape = _get_conv_output_shape(block_input_shape, 3, 2, 1) + shapes.append(block_input_shape) + return shapes + + +def _get_relative_position_index(height: int, width: int) -> torch.Tensor: + coords = torch.stack(torch.meshgrid([torch.arange(height), torch.arange(width)])) + coords_flat = torch.flatten(coords, 1) + relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += height - 1 + relative_coords[:, :, 1] += width - 1 + relative_coords[:, :, 0] *= 2 * width - 1 + return relative_coords.sum(-1) + + +class MBConv(nn.Module): + """MBConv: Mobile Inverted Residual Bottleneck. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + expansion_ratio (float): Expansion ratio in the bottleneck. + squeeze_ratio (float): Squeeze ratio in the SE Layer. + stride (int): Stride of the depthwise convolution. + activation_layer (Callable[..., nn.Module]): Activation function. + norm_layer (Callable[..., nn.Module]): Normalization function. + p_stochastic_dropout (float): Probability of stochastic depth. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + expansion_ratio: float, + squeeze_ratio: float, + stride: int, + activation_layer: Callable[..., nn.Module], + norm_layer: Callable[..., nn.Module], + p_stochastic_dropout: float = 0.0, + ) -> None: + super().__init__() + + proj: Sequence[nn.Module] + self.proj: nn.Module + + should_proj = stride != 1 or in_channels != out_channels + if should_proj: + proj = [nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=True)] + if stride == 2: + proj = [nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)] + proj # type: ignore + self.proj = nn.Sequential(*proj) + else: + self.proj = nn.Identity() # type: ignore + + mid_channels = int(out_channels * expansion_ratio) + sqz_channels = int(out_channels * squeeze_ratio) + + if p_stochastic_dropout: + self.stochastic_depth = StochasticDepth(p_stochastic_dropout, mode="row") # type: ignore + else: + self.stochastic_depth = nn.Identity() # type: ignore + + _layers = OrderedDict() + _layers["pre_norm"] = norm_layer(in_channels) + _layers["conv_a"] = Conv2dNormActivation( + in_channels, + mid_channels, + kernel_size=1, + stride=1, + padding=0, + activation_layer=activation_layer, + norm_layer=norm_layer, + inplace=None, + ) + _layers["conv_b"] = Conv2dNormActivation( + mid_channels, + mid_channels, + kernel_size=3, + stride=stride, + padding=1, + activation_layer=activation_layer, + norm_layer=norm_layer, + groups=mid_channels, + inplace=None, + ) + _layers["squeeze_excitation"] = SqueezeExcitation(mid_channels, sqz_channels, activation=nn.SiLU) + _layers["conv_c"] = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, bias=True) + + self.layers = nn.Sequential(_layers) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Input tensor with expected layout of [B, C, H, W]. + Returns: + Tensor: Output tensor with expected layout of [B, C, H / stride, W / stride]. + """ + res = self.proj(x) + x = self.stochastic_depth(self.layers(x)) + return res + x + + +class RelativePositionalMultiHeadAttention(nn.Module): + """Relative Positional Multi-Head Attention. + + Args: + feat_dim (int): Number of input features. + head_dim (int): Number of features per head. + max_seq_len (int): Maximum sequence length. 
+ """ + + def __init__( + self, + feat_dim: int, + head_dim: int, + max_seq_len: int, + ) -> None: + super().__init__() + + if feat_dim % head_dim != 0: + raise ValueError(f"feat_dim: {feat_dim} must be divisible by head_dim: {head_dim}") + + self.n_heads = feat_dim // head_dim + self.head_dim = head_dim + self.size = int(math.sqrt(max_seq_len)) + self.max_seq_len = max_seq_len + + self.to_qkv = nn.Linear(feat_dim, self.n_heads * self.head_dim * 3) + self.scale_factor = feat_dim**-0.5 + + self.merge = nn.Linear(self.head_dim * self.n_heads, feat_dim) + self.relative_position_bias_table = nn.parameter.Parameter( + torch.empty(((2 * self.size - 1) * (2 * self.size - 1), self.n_heads), dtype=torch.float32), + ) + + self.register_buffer("relative_position_index", _get_relative_position_index(self.size, self.size)) + # initialize with truncated normal the bias + torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) + + def get_relative_positional_bias(self) -> torch.Tensor: + bias_index = self.relative_position_index.view(-1) # type: ignore + relative_bias = self.relative_position_bias_table[bias_index].view(self.max_seq_len, self.max_seq_len, -1) # type: ignore + relative_bias = relative_bias.permute(2, 0, 1).contiguous() + return relative_bias.unsqueeze(0) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Input tensor with expected layout of [B, G, P, D]. + Returns: + Tensor: Output tensor with expected layout of [B, G, P, D]. + """ + B, G, P, D = x.shape + H, DH = self.n_heads, self.head_dim + + qkv = self.to_qkv(x) + q, k, v = torch.chunk(qkv, 3, dim=-1) + + q = q.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4) + k = k.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4) + v = v.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4) + + k = k * self.scale_factor + dot_prod = torch.einsum("B G H I D, B G H J D -> B G H I J", q, k) + pos_bias = self.get_relative_positional_bias() + + dot_prod = F.softmax(dot_prod + pos_bias, dim=-1) + + out = torch.einsum("B G H I J, B G H J D -> B G H I D", dot_prod, v) + out = out.permute(0, 1, 3, 2, 4).reshape(B, G, P, D) + + out = self.merge(out) + return out + + +class SwapAxes(nn.Module): + """Permute the axes of a tensor.""" + + def __init__(self, a: int, b: int) -> None: + super().__init__() + self.a = a + self.b = b + + def forward(self, x: torch.Tensor) -> torch.Tensor: + res = torch.swapaxes(x, self.a, self.b) + return res + + +class WindowPartition(nn.Module): + """ + Partition the input tensor into non-overlapping windows. + """ + + def __init__(self) -> None: + super().__init__() + + def forward(self, x: Tensor, p: int) -> Tensor: + """ + Args: + x (Tensor): Input tensor with expected layout of [B, C, H, W]. + p (int): Number of partitions. + Returns: + Tensor: Output tensor with expected layout of [B, H/P, W/P, P*P, C]. + """ + B, C, H, W = x.shape + P = p + # chunk up H and W dimensions + x = x.reshape(B, C, H // P, P, W // P, P) + x = x.permute(0, 2, 4, 3, 5, 1) + # colapse P * P dimension + x = x.reshape(B, (H // P) * (W // P), P * P, C) + return x + + +class WindowDepartition(nn.Module): + """ + Departition the input tensor of non-overlapping windows into a feature volume of layout [B, C, H, W]. + """ + + def __init__(self) -> None: + super().__init__() + + def forward(self, x: Tensor, p: int, h_partitions: int, w_partitions: int) -> Tensor: + """ + Args: + x (Tensor): Input tensor with expected layout of [B, (H/P * W/P), P*P, C]. + p (int): Number of partitions. + h_partitions (int): Number of vertical partitions. 
+ w_partitions (int): Number of horizontal partitions. + Returns: + Tensor: Output tensor with expected layout of [B, C, H, W]. + """ + B, G, PP, C = x.shape + P = p + HP, WP = h_partitions, w_partitions + # split P * P dimension into 2 P tile dimensions + x = x.reshape(B, HP, WP, P, P, C) + # permute into B, C, HP, P, WP, P + x = x.permute(0, 5, 1, 3, 2, 4) + # reshape into B, C, H, W + x = x.reshape(B, C, HP * P, WP * P) + return x + + +class PartitionAttentionLayer(nn.Module): + """ + Layer for partitioning the input tensor into non-overlapping windows and applying attention to each window. + + Args: + in_channels (int): Number of input channels. + head_dim (int): Dimension of each attention head. + partition_size (int): Size of the partitions. + partition_type (str): Type of partitioning to use. Can be either "grid" or "window". + grid_size (Tuple[int, int]): Size of the grid to partition the input tensor into. + mlp_ratio (int): Ratio of the feature size expansion in the MLP layer. + activation_layer (Callable[..., nn.Module]): Activation function to use. + norm_layer (Callable[..., nn.Module]): Normalization function to use. + attention_dropout (float): Dropout probability for the attention layer. + mlp_dropout (float): Dropout probability for the MLP layer. + p_stochastic_dropout (float): Probability of stochastic depth. + """ + + def __init__( + self, + in_channels: int, + head_dim: int, + # partitioning parameters + partition_size: int, + partition_type: str, + # grid size needs to be known at initialization time + # because we need to know how many relative offsets there are in the grid + grid_size: Tuple[int, int], + mlp_ratio: int, + activation_layer: Callable[..., nn.Module], + norm_layer: Callable[..., nn.Module], + attention_dropout: float, + mlp_dropout: float, + p_stochastic_dropout: float, + ) -> None: + super().__init__() + + self.n_heads = in_channels // head_dim + self.head_dim = head_dim + self.n_partitions = grid_size[0] // partition_size + self.partition_type = partition_type + self.grid_size = grid_size + + if partition_type not in ["grid", "window"]: + raise ValueError("partition_type must be either 'grid' or 'window'") + + if partition_type == "window": + self.p, self.g = partition_size, self.n_partitions + else: + self.p, self.g = self.n_partitions, partition_size + + self.partition_op = WindowPartition() + self.departition_op = WindowDepartition() + self.partition_swap = SwapAxes(-2, -3) if partition_type == "grid" else nn.Identity() + self.departition_swap = SwapAxes(-2, -3) if partition_type == "grid" else nn.Identity() + + self.attn_layer = nn.Sequential( + norm_layer(in_channels), + # it's always going to be partition_size ** 2 because + # of the axis swap in the case of grid partitioning + RelativePositionalMultiHeadAttention(in_channels, head_dim, partition_size**2), + nn.Dropout(attention_dropout), + ) + + # pre-normalization similar to transformer layers + self.mlp_layer = nn.Sequential( + nn.LayerNorm(in_channels), + nn.Linear(in_channels, in_channels * mlp_ratio), + activation_layer(), + nn.Linear(in_channels * mlp_ratio, in_channels), + nn.Dropout(mlp_dropout), + ) + + # stochastic depth + self.stochastic_dropout = StochasticDepth(p_stochastic_dropout, mode="row") + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Input tensor with expected layout of [B, C, H, W]. + Returns: + Tensor: Output tensor with expected layout of [B, C, H, W].
+ """ + + # Undefined behavior if H or W are not divisible by p + # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766 + gh, gw = self.grid_size[0] // self.p, self.grid_size[1] // self.p + torch._assert( + self.grid_size[0] % self.p == 0 and self.grid_size[1] % self.p == 0, + "Grid size must be divisible by partition size. Got grid size of {} and partition size of {}".format( + self.grid_size, self.p + ), + ) + + x = self.partition_op(x, self.p) + x = self.partition_swap(x) + x = x + self.stochastic_dropout(self.attn_layer(x)) + x = x + self.stochastic_dropout(self.mlp_layer(x)) + x = self.departition_swap(x) + x = self.departition_op(x, self.p, gh, gw) + + return x + + +class MaxVitLayer(nn.Module): + """ + MaxVit layer consisting of a MBConv layer followed by a PartitionAttentionLayer with `window` and a PartitionAttentionLayer with `grid`. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + expansion_ratio (float): Expansion ratio in the bottleneck. + squeeze_ratio (float): Squeeze ratio in the SE Layer. + stride (int): Stride of the depthwise convolution. + activation_layer (Callable[..., nn.Module]): Activation function. + norm_layer (Callable[..., nn.Module]): Normalization function. + head_dim (int): Dimension of the attention heads. + mlp_ratio (int): Ratio of the MLP layer. + mlp_dropout (float): Dropout probability for the MLP layer. + attention_dropout (float): Dropout probability for the attention layer. + p_stochastic_dropout (float): Probability of stochastic depth. + partition_size (int): Size of the partitions. + grid_size (Tuple[int, int]): Size of the input feature grid. + """ + + def __init__( + self, + # conv parameters + in_channels: int, + out_channels: int, + squeeze_ratio: float, + expansion_ratio: float, + stride: int, + # conv + transformer parameters + norm_layer: Callable[..., nn.Module], + activation_layer: Callable[..., nn.Module], + # transformer parameters + head_dim: int, + mlp_ratio: int, + mlp_dropout: float, + attention_dropout: float, + p_stochastic_dropout: float, + # partitioning parameters + partition_size: int, + grid_size: Tuple[int, int], + ) -> None: + super().__init__() + + layers: OrderedDict[str, Any] = OrderedDict() # type: ignore + + # convolutional layer + layers["MBconv"] = MBConv( + in_channels=in_channels, + out_channels=out_channels, + expansion_ratio=expansion_ratio, + squeeze_ratio=squeeze_ratio, + stride=stride, + activation_layer=activation_layer, + norm_layer=norm_layer, + p_stochastic_dropout=p_stochastic_dropout, + ) + # attention layers, block -> grid + layers["window_attention"] = PartitionAttentionLayer( + in_channels=out_channels, + head_dim=head_dim, + partition_size=partition_size, + partition_type="window", + grid_size=grid_size, + mlp_ratio=mlp_ratio, + activation_layer=activation_layer, + norm_layer=nn.LayerNorm, + attention_dropout=attention_dropout, + mlp_dropout=mlp_dropout, + p_stochastic_dropout=p_stochastic_dropout, + ) + layers["grid_attention"] = PartitionAttentionLayer( + in_channels=out_channels, + head_dim=head_dim, + partition_size=partition_size, + partition_type="grid", + grid_size=grid_size, + mlp_ratio=mlp_ratio, + activation_layer=activation_layer, + norm_layer=nn.LayerNorm, + attention_dropout=attention_dropout, + mlp_dropout=mlp_dropout, + p_stochastic_dropout=p_stochastic_dropout, + ) + self.layers = nn.Sequential(layers) + + def forward(self, x: Tensor) -> 
+ """ + Args: + x (Tensor): Input tensor of shape (B, C, H, W). + Returns: + Tensor: Output tensor of shape (B, C, H, W). + """ + x = self.layers(x) + return x + + +class MaxVitBlock(nn.Module): + """ + A MaxVit block consisting of `n_layers` MaxVit layers. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + expansion_ratio (float): Expansion ratio in the bottleneck. + squeeze_ratio (float): Squeeze ratio in the SE Layer. + activation_layer (Callable[..., nn.Module]): Activation function. + norm_layer (Callable[..., nn.Module]): Normalization function. + head_dim (int): Dimension of the attention heads. + mlp_ratio (int): Ratio of the MLP layer. + mlp_dropout (float): Dropout probability for the MLP layer. + attention_dropout (float): Dropout probability for the attention layer. + p_stochastic_dropout (float): Probability of stochastic depth. + partition_size (int): Size of the partitions. + input_grid_size (Tuple[int, int]): Size of the input feature grid. + n_layers (int): Number of layers in the block. + p_stochastic (List[float]): List of probabilities for stochastic depth for each layer. + """ + + def __init__( + self, + # conv parameters + in_channels: int, + out_channels: int, + squeeze_ratio: float, + expansion_ratio: float, + # conv + transformer parameters + norm_layer: Callable[..., nn.Module], + activation_layer: Callable[..., nn.Module], + # transformer parameters + head_dim: int, + mlp_ratio: int, + mlp_dropout: float, + attention_dropout: float, + # partitioning parameters + partition_size: int, + input_grid_size: Tuple[int, int], + # number of layers + n_layers: int, + p_stochastic: List[float], + ) -> None: + super().__init__() + if not len(p_stochastic) == n_layers: + raise ValueError(f"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}.") + + self.layers = nn.ModuleList() + # account for the first stride of the first layer + self.grid_size = _get_conv_output_shape(input_grid_size, kernel_size=3, stride=2, padding=1) + + for idx, p in enumerate(p_stochastic): + stride = 2 if idx == 0 else 1 + self.layers += [ + MaxVitLayer( + in_channels=in_channels if idx == 0 else out_channels, + out_channels=out_channels, + squeeze_ratio=squeeze_ratio, + expansion_ratio=expansion_ratio, + stride=stride, + norm_layer=norm_layer, + activation_layer=activation_layer, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + mlp_dropout=mlp_dropout, + attention_dropout=attention_dropout, + partition_size=partition_size, + grid_size=self.grid_size, + p_stochastic_dropout=p, + ), + ] + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, C, H, W). + Returns: + Tensor: Output tensor of shape (B, C, H, W). + """ + for layer in self.layers: + x = layer(x) + return x + + +class MaxVit(nn.Module): + """ + Implements MaxVit Transformer from the `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`_ paper. + Args: + input_size (Tuple[int, int]): Size of the input image. + stem_channels (int): Number of channels in the stem. + partition_size (int): Size of the partitions. + block_channels (List[int]): Number of channels in each block. + block_layers (List[int]): Number of layers in each block. + stochastic_depth_prob (float): Probability of stochastic depth. Expands to a list of probabilities for each layer that scales linearly to the specified value. + squeeze_ratio (float): Squeeze ratio in the SE Layer. Default: 0.25.
+ expansion_ratio (float): Expansion ratio in the MBConv bottleneck. Default: 4. + norm_layer (Callable[..., nn.Module]): Normalization function. Default: None (setting to None will produce a `BatchNorm2d(eps=1e-3, momentum=0.99)`). + activation_layer (Callable[..., nn.Module]): Activation function. Default: nn.GELU. + head_dim (int): Dimension of the attention heads. + mlp_ratio (int): Expansion ratio of the MLP layer. Default: 4. + mlp_dropout (float): Dropout probability for the MLP layer. Default: 0.0. + attention_dropout (float): Dropout probability for the attention layer. Default: 0.0. + num_classes (int): Number of classes. Default: 1000. + """ + + def __init__( + self, + # input size parameters + input_size: Tuple[int, int], + # stem and task parameters + stem_channels: int, + # partitioning parameters + partition_size: int, + # block parameters + block_channels: List[int], + block_layers: List[int], + # attention head dimensions + head_dim: int, + stochastic_depth_prob: float, + # conv + transformer parameters + # norm_layer is applied only to the conv layers + # activation_layer is applied both to conv and transformer layers + norm_layer: Optional[Callable[..., nn.Module]] = None, + activation_layer: Callable[..., nn.Module] = nn.GELU, + # conv parameters + squeeze_ratio: float = 0.25, + expansion_ratio: float = 4, + # transformer parameters + mlp_ratio: int = 4, + mlp_dropout: float = 0.0, + attention_dropout: float = 0.0, + # task parameters + num_classes: int = 1000, + ) -> None: + super().__init__() + _log_api_usage_once(self) + + input_channels = 3 + + # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L1029-L1030 + # for the exact parameters used in batchnorm + if norm_layer is None: + norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.99) + + # Make sure input size will be divisible by the partition size in all blocks + # Undefined behavior if H or W are not divisible by p + # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766 + block_input_sizes = _make_block_input_shapes(input_size, len(block_channels)) + for idx, block_input_size in enumerate(block_input_sizes): + if block_input_size[0] % partition_size != 0 or block_input_size[1] % partition_size != 0: + raise ValueError( + f"Input size {block_input_size} of block {idx} is not divisible by partition size {partition_size}. " + f"Consider changing the partition size or the input size.\n" + f"Current configuration yields the following block input sizes: {block_input_sizes}."
+ ) + + # stem + self.stem = nn.Sequential( + Conv2dNormActivation( + input_channels, + stem_channels, + 3, + stride=2, + norm_layer=norm_layer, + activation_layer=activation_layer, + bias=False, + inplace=None, + ), + Conv2dNormActivation( + stem_channels, stem_channels, 3, stride=1, norm_layer=None, activation_layer=None, bias=True + ), + ) + + # account for stem stride + input_size = _get_conv_output_shape(input_size, kernel_size=3, stride=2, padding=1) + self.partition_size = partition_size + + # blocks + self.blocks = nn.ModuleList() + in_channels = [stem_channels] + block_channels[:-1] + out_channels = block_channels + + # precompute the stochastic depth probabilities from 0 to stochastic_depth_prob + # since we have N blocks with L layers, we will have N * L probabilities evenly spaced + # over the range [0, stochastic_depth_prob] + p_stochastic = np.linspace(0, stochastic_depth_prob, sum(block_layers)).tolist() + + p_idx = 0 + for in_channel, out_channel, num_layers in zip(in_channels, out_channels, block_layers): + self.blocks.append( + MaxVitBlock( + in_channels=in_channel, + out_channels=out_channel, + squeeze_ratio=squeeze_ratio, + expansion_ratio=expansion_ratio, + norm_layer=norm_layer, + activation_layer=activation_layer, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + mlp_dropout=mlp_dropout, + attention_dropout=attention_dropout, + partition_size=partition_size, + input_grid_size=input_size, + n_layers=num_layers, + p_stochastic=p_stochastic[p_idx : p_idx + num_layers], + ), + ) + input_size = self.blocks[-1].grid_size + p_idx += num_layers + + # see https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L1137-L1158 + # for why there is Linear -> Tanh -> Linear + self.classifier = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Flatten(), + nn.LayerNorm(block_channels[-1]), + nn.Linear(block_channels[-1], block_channels[-1]), + nn.Tanh(), + nn.Linear(block_channels[-1], num_classes, bias=False), + ) + + self._init_weights() + + def forward(self, x: Tensor) -> Tensor: + x = self.stem(x) + for block in self.blocks: + x = block(x) + x = self.classifier(x) + return x + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + +def _maxvit( + # stem parameters + stem_channels: int, + # block parameters + block_channels: List[int], + block_layers: List[int], + stochastic_depth_prob: float, + # partitioning parameters + partition_size: int, + # transformer parameters + head_dim: int, + # Weights API + weights: Optional[WeightsEnum] = None, + progress: bool = False, + # kwargs, + **kwargs: Any, +) -> MaxVit: + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + assert weights.meta["min_size"][0] == weights.meta["min_size"][1] + _ovewrite_named_param(kwargs, "input_size", weights.meta["min_size"]) + + input_size = kwargs.pop("input_size", (224, 224)) + + model = MaxVit( + stem_channels=stem_channels, + block_channels=block_channels, + block_layers=block_layers, + stochastic_depth_prob=stochastic_depth_prob, + head_dim=head_dim, + partition_size=partition_size, + input_size=input_size, + **kwargs, + ) + + if weights is not None:
+ model.load_state_dict(weights.get_state_dict(progress=progress)) + + return model + + +class MaxVit_T_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/maxvit_t-bc5ab103.pth", + transforms=partial( + ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC + ), + meta={ + "categories": _IMAGENET_CATEGORIES, + "num_params": 30919624, + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#maxvit", + "_metrics": { + "ImageNet-1K": { + "acc@1": 83.700, + "acc@5": 96.722, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@register_model() +@handle_legacy_interface(weights=("pretrained", MaxVit_T_Weights.IMAGENET1K_V1)) +def maxvit_t(*, weights: Optional[MaxVit_T_Weights] = None, progress: bool = True, **kwargs: Any) -> MaxVit: + """ + Constructs a maxvit_t architecture from + `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`_. + + Args: + weights (:class:`~torchvision.models.MaxVit_T_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.MaxVit_T_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.maxvit.MaxVit`` + base class. Please refer to the `source code + <https://github.com/pytorch/vision/blob/main/torchvision/models/maxvit.py>`_ + for more details about this class. + + .. autoclass:: torchvision.models.MaxVit_T_Weights + :members: + """ + weights = MaxVit_T_Weights.verify(weights) + + return _maxvit( + stem_channels=64, + block_channels=[64, 128, 256, 512], + block_layers=[2, 2, 5, 2], + head_dim=32, + stochastic_depth_prob=0.2, + partition_size=7, + weights=weights, + progress=progress, + **kwargs, + )
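As an illustration of the input-size constraint enforced in `MaxVit.__init__` (a small sketch, not part of the diff; it assumes the module ships as `torchvision.models.maxvit`, and `_make_block_input_shapes` is a private helper, so this is for exploration only):

```python
from torchvision.models.maxvit import _make_block_input_shapes

# maxvit_t has 4 blocks and partition_size=7; check which square inputs are valid.
for side in (192, 224, 256):
    shapes = _make_block_input_shapes((side, side), n_blocks=4)
    ok = all(h % 7 == 0 and w % 7 == 0 for h, w in shapes)
    print(f"{side}x{side}: block input sizes {shapes} -> {'ok' if ok else 'incompatible'}")
# Only 224x224 passes: the block inputs (56, 28, 14, 7) are all divisible by 7.
```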