
Commit aa95139

Re-added model changes after revert

1 parent 5e8a222

File tree: 6 files changed, +105 / -136 lines
Binary files: four changed (contents not shown), three at -939 bytes each and one at 939 bytes.

test/test_models.py (0 additions, 8 deletions)

@@ -594,14 +594,6 @@ def test_vitc_models(model_fn, dev):
     test_classification_model(model_fn, dev)
 
 
-@pytest.mark.parametrize(
-    "model_fn", [models.max_vit_T_224, models.max_vit_S_224, models.max_vit_B_224, models.max_vit_L_224]
-)
-@pytest.mark.parametrize("dev", cpu_and_gpu())
-def test_max_vit(model_fn, dev):
-    test_classification_model(model_fn, dev)
-
-
 @pytest.mark.parametrize("model_fn", list_model_fns(models))
 @pytest.mark.parametrize("dev", cpu_and_gpu())
 def test_classification_model(model_fn, dev):
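
The dedicated max_vit test is dropped because the generic test_classification_model above runs over list_model_fns(models), so the newly registered maxvit_t builder (added in torchvision/models/maxvit.py below) is exercised there instead. For illustration only, a minimal standalone smoke test for the new builder could look like the following sketch; it is a hypothetical test, not part of this commit, and assumes num_classes is forwarded through **kwargs as in the diff that follows.

import pytest
import torch

from torchvision.models.maxvit import maxvit_t


@pytest.mark.parametrize("dev", ["cpu"])
def test_maxvit_t_smoke(dev):
    # Build the tiny variant without pretrained weights; num_classes is
    # forwarded to _maxvit via **kwargs in this version of the builder.
    model = maxvit_t(num_classes=10).to(dev).eval()
    with torch.no_grad():
        out = model(torch.rand(1, 3, 224, 224, device=dev))
    assert out.shape == (1, 10)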

torchvision/models/maxvit.py (105 additions, 128 deletions)

@@ -1,12 +1,15 @@
 import math
-from typing import Any, Callable, List, OrderedDict, Sequence, Tuple
+from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Tuple
 
 import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn, Tensor
+from torchvision.models._api import register_model, WeightsEnum
+from torchvision.models._utils import _ovewrite_named_param
 from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation
 from torchvision.ops.stochastic_depth import StochasticDepth
+from torchvision.utils import _log_api_usage_once
 
 
 def get_relative_position_index(height: int, width: int) -> torch.Tensor:
@@ -20,20 +23,6 @@ def get_relative_position_index(height: int, width: int) -> torch.Tensor:
     return relative_coords.sum(-1)
 
 
-class GeluWrapper(nn.Module):
-    """
-    Gelu wrapper to make it compatible with `ConvNormActivation2D` which passed inplace=True
-    to the activation function construction.
-    """
-
-    def __init__(self, **kwargs) -> None:
-        super().__init__()
-        self._op = F.gelu
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self._op(x)
-
-
 class MBConv(nn.Module):
     def __init__(
         self,
@@ -65,20 +54,28 @@ def __init__(
         _layers = OrderedDict()
         _layers["pre_norm"] = normalization_fn(in_channels)
         _layers["conv_a"] = Conv2dNormActivation(
-            in_channels, mid_channels, 1, 1, 0, activation_layer=activation_fn, norm_layer=normalization_fn
+            in_channels,
+            mid_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            activation_layer=activation_fn,
+            norm_layer=normalization_fn,
+            inplace=None,
         )
         _layers["conv_b"] = Conv2dNormActivation(
             mid_channels,
             mid_channels,
-            3,
-            stride,
-            1,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
             activation_layer=activation_fn,
             norm_layer=normalization_fn,
             groups=mid_channels,
+            inplace=None,
         )
         _layers["squeeze_excitation"] = SqueezeExcitation(mid_channels, sqz_channels)
-        _layers["conv_c"] = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, bias=False)
+        _layers["conv_c"] = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, bias=True)
 
         self.layers = nn.Sequential(_layers)
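
The inplace=None arguments above are what make the GeluWrapper shim removable: when inplace is None, Conv2dNormActivation does not pass an inplace keyword to the activation constructor, so plain nn.GELU (which accepts no such argument) can be used directly. A small sketch of that behavior, with illustrative channel sizes:

import torch
from torch import nn
from torchvision.ops.misc import Conv2dNormActivation

# With inplace=None the activation is constructed as nn.GELU() rather than
# nn.GELU(inplace=True), which would raise a TypeError.
block = Conv2dNormActivation(
    16,
    64,
    kernel_size=1,
    stride=1,
    padding=0,
    activation_layer=nn.GELU,
    norm_layer=nn.BatchNorm2d,
    inplace=None,
)
out = block(torch.rand(2, 16, 28, 28))  # shape: (2, 64, 28, 28)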

@@ -116,14 +113,13 @@ def __init__(
         # initialize with truncated normal the bias
         self.positional_bias.data.normal_(mean=0, std=0.02)
 
-    def _get_relative_positional_bias(self) -> torch.Tensor:
+    def get_relative_positional_bias(self) -> torch.Tensor:
         bias_index = self.relative_position_index.view(-1)  # type: ignore
         relative_bias = self.positional_bias[bias_index].view(self.max_seq_len, self.max_seq_len, -1)  # type: ignore
         relative_bias = relative_bias.permute(2, 0, 1).contiguous()
         return relative_bias.unsqueeze(0)
 
     def forward(self, x: Tensor) -> Tensor:
-        # X, Y and stand for X-axis group dim, Y-axis group dim
         B, G, P, D = x.shape
         H, DH = self.n_heads, self.head_dim
 
@@ -135,9 +131,8 @@ def forward(self, x: Tensor) -> Tensor:
         v = v.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
 
         k = k * self.scale_factor
-        # X, Y and stand for X-axis group dim, Y-axis group dim
         dot_prod = torch.einsum("B G H I D, B G H J D -> B G H I J", q, k)
-        pos_bias = self._get_relative_positional_bias()
+        pos_bias = self.get_relative_positional_bias()
 
         dot_prod = F.softmax(dot_prod + pos_bias, dim=-1)
 
@@ -204,34 +199,6 @@ def forward(self, x: Tensor) -> Tensor:
         return x
 
 
-class MLP(nn.Module):
-    def __init__(
-        self,
-        in_dim: int,
-        hidden_dim: int,
-        activation_fn: Callable[..., nn.Module],
-        normalization_fn: Callable[..., nn.Module],
-        dropout: float,
-    ) -> None:
-        super().__init__()
-        self.in_dim = in_dim
-        self.hidden_dim = hidden_dim
-        self.activation_fn = activation_fn
-        self.normalization_fn = normalization_fn
-        self.dropout = dropout
-
-        self.layers = nn.Sequential(
-            self.normalization_fn(in_dim),
-            nn.Linear(in_dim, hidden_dim),
-            self.activation_fn(),
-            nn.Linear(hidden_dim, in_dim),
-            nn.Dropout(dropout),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        return x + self.layers(x)
-
-
 class PartitionAttentionLayer(nn.Module):
     def __init__(
         self,
@@ -282,16 +249,23 @@ def __init__(
             nn.Dropout(attn_dropout),
         )
 
-        self.mlp_layer = MLP(in_channels, in_channels * mlp_ratio, activation_fn, normalization_fn, mlp_dropout)
+        # pre-normalization similar to transformer layers
+        self.mlp_layer = nn.Sequential(
+            nn.LayerNorm(in_channels),
+            nn.Linear(in_channels, in_channels * mlp_ratio),
+            activation_fn(),
+            nn.Linear(in_channels * mlp_ratio, in_channels),
+            nn.Dropout(mlp_dropout),
+        )
 
         # layer scale factors
         self.attn_layer_scale = nn.parameter.Parameter(torch.ones(in_channels) * 1e-6)
         self.mlp_layer_scale = nn.parameter.Parameter(torch.ones(in_channels) * 1e-6)
 
     def forward(self, x: Tensor) -> Tensor:
         x = self.partition_op(x)
-        x = self.attn_layer(x) * self.attn_layer_scale
-        x = self.mlp_layer(x) * self.mlp_layer_scale
+        x = x + self.attn_layer(x) * self.attn_layer_scale
+        x = x + self.mlp_layer(x) * self.mlp_layer_scale
         x = self.departition_op(x)
         return x
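
The forward change above restores the residual connections: the attention and MLP outputs are scaled by the learnable layer-scale parameters and added back onto the input instead of replacing it. A generic sketch of that pattern (illustrative code, not part of torchvision):

import torch
from torch import nn


class LayerScaleResidual(nn.Module):
    # Wraps a pre-normalized sub-block so that its output, scaled by a small
    # learnable per-channel factor, is added back onto the input.
    def __init__(self, dim: int, block: nn.Module, init_scale: float = 1e-6) -> None:
        super().__init__()
        self.block = block
        self.scale = nn.Parameter(torch.ones(dim) * init_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.block(x) * self.scale


mlp = nn.Sequential(nn.LayerNorm(64), nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64))
layer = LayerScaleResidual(64, mlp)
y = layer(torch.rand(4, 49, 64))  # the identity path is preserved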

@@ -386,9 +360,8 @@ def __init__(
         p_stochastic: List[float],
     ) -> None:
         super().__init__()
-        assert (
-            len(p_stochastic) == n_layers
-        ), f"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}."
+        if not len(p_stochastic) == n_layers:
+            raise ValueError(f"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}.")
 
         self.layers = nn.ModuleList()
         # account for the first stride of the first layer
@@ -424,11 +397,12 @@ def forward(self, x: Tensor) -> Tensor:
 class MaxVit(nn.Module):
     def __init__(
         self,
+        # input size parameters
+        input_size: Tuple[int, int],
         # stem and task parameters
         input_channels: int,
         stem_channels: int,
-        input_size: Tuple[int, int],
-        out_classes: int,
+        num_classes: int,
         # block parameters
         block_channels: List[int],
         block_layers: List[int],
@@ -450,6 +424,7 @@ def __init__(
         partition_size: int,
     ) -> None:
         super().__init__()
+        _log_api_usage_once(self)
 
         # stem
         self.stem = nn.Sequential(
@@ -500,7 +475,7 @@ def __init__(
         self.classifier = nn.Sequential(
             nn.AdaptiveAvgPool2d(1),
             nn.Flatten(),
-            nn.Linear(block_channels[-1], out_classes, bias=False),
+            nn.Linear(block_channels[-1], num_classes, bias=False),
         )
 
     def forward(self, x: Tensor) -> Tensor:
@@ -511,85 +486,87 @@ def forward(self, x: Tensor) -> Tensor:
         return x
 
 
-def max_vit_T_224(num_classes: int) -> MaxVit:
-    return MaxVit(
-        input_channels=3,
-        stem_channels=64,
-        input_size=(224, 224),
-        out_classes=num_classes,
-        block_channels=[64, 128, 256, 512],
-        block_layers=[2, 2, 5, 2],
-        stochastic_depth_prob=0.2,
-        squeeze_ratio=0.25,
-        expansion_ratio=4.0,
-        normalization_fn=nn.BatchNorm2d,
-        activation_fn=GeluWrapper,
-        head_dim=32,
-        mlp_ratio=2,
-        mlp_dropout=0.0,
-        attn_dropout=0.0,
-        partition_size=7,
+def _maxvit(
+    # stem and task parameters
+    stem_channels: int,
+    num_classes: int,
+    # block parameters
+    block_channels: List[int],
+    block_layers: List[int],
+    stochastic_depth_prob: float,
+    # conv parameters
+    squeeze_ratio: float,
+    expansion_ratio: float,
+    # conv + transformer parameters
+    # normalization_fn is applied only to the conv layers
+    # activation_fn is applied both to conv and transformer layers
+    normalization_fn: Callable[..., nn.Module],
+    activation_fn: Callable[..., nn.Module],
+    # transformer parameters
+    head_dim: int,
+    mlp_ratio: int,
+    mlp_dropout: float,
+    attn_dropout: float,
+    # partitioning parameters
+    partition_size: int,
+    # Weights API
+    weights: Optional[WeightsEnum],
+    progress: bool,
+    # kwargs,
+    **kwargs,
+) -> MaxVit:
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+        assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
+        _ovewrite_named_param(kwargs, "input_size", weights.meta["min_size"][0])
+        _ovewrite_named_param(kwargs, "input_channels", weights.meta["input_channels"])
+
+    input_size = kwargs.pop("input_size", (224, 224))
+    input_channels = kwargs.pop("input_channels", 3)
+
+    model = MaxVit(
+        input_channels=input_channels,
+        stem_channels=stem_channels,
+        num_classes=num_classes,
+        block_channels=block_channels,
+        block_layers=block_layers,
+        stochastic_depth_prob=stochastic_depth_prob,
+        squeeze_ratio=squeeze_ratio,
+        expansion_ratio=expansion_ratio,
+        normalization_fn=normalization_fn,
+        activation_fn=activation_fn,
+        head_dim=head_dim,
+        mlp_ratio=mlp_ratio,
+        mlp_dropout=mlp_dropout,
+        attn_dropout=attn_dropout,
+        partition_size=partition_size,
+        input_size=input_size,
+        **kwargs,
     )
 
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
 
-def max_vit_S_224(num_classes: int) -> MaxVit:
-    return MaxVit(
-        input_channels=3,
-        stem_channels=64,
-        input_size=(224, 224),
-        out_classes=num_classes,
-        block_channels=[96, 192, 384, 768],
-        block_layers=[2, 2, 5, 2],
-        stochastic_depth_prob=0.3,
-        squeeze_ratio=0.25,
-        expansion_ratio=4.0,
-        normalization_fn=nn.BatchNorm2d,
-        activation_fn=GeluWrapper,
-        head_dim=32,
-        mlp_ratio=2,
-        mlp_dropout=0.0,
-        attn_dropout=0.0,
-        partition_size=7,
-    )
+    return model
 
 
-def max_vit_B_224(num_classes: int) -> MaxVit:
-    return MaxVit(
-        input_channels=3,
+@register_model(name="maxvit_t")
+def maxvit_t(*, weights: Optional[WeightsEnum] = None, progress: bool = True, **kwargs: Any) -> MaxVit:
+    return _maxvit(
         stem_channels=64,
-        input_size=(224, 224),
-        out_classes=num_classes,
-        block_channels=[96, 192, 384, 768],
-        block_layers=[2, 6, 14, 2],
-        stochastic_depth_prob=0.4,
-        squeeze_ratio=0.25,
-        expansion_ratio=4.0,
-        normalization_fn=nn.BatchNorm2d,
-        activation_fn=GeluWrapper,
-        head_dim=32,
-        mlp_ratio=2,
-        mlp_dropout=0.0,
-        attn_dropout=0.0,
-        partition_size=7,
-    )
-
-
-def max_vit_L_224(num_classes: int) -> MaxVit:
-    return MaxVit(
-        input_channels=3,
-        stem_channels=128,
-        input_size=(224, 224),
-        out_classes=num_classes,
-        block_channels=[128, 256, 512, 1024],
-        block_layers=[2, 6, 14, 2],
-        stochastic_depth_prob=0.6,
+        block_channels=[64, 128, 256, 512],
+        block_layers=[2, 2, 5, 2],
+        stochastic_depth_prob=0.2,
         squeeze_ratio=0.25,
         expansion_ratio=4.0,
         normalization_fn=nn.BatchNorm2d,
-        activation_fn=GeluWrapper,
+        activation_fn=nn.GELU,
         head_dim=32,
         mlp_ratio=2,
         mlp_dropout=0.0,
         attn_dropout=0.0,
         partition_size=7,
+        weights=weights,
+        progress=progress,
+        **kwargs,
    )
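
For reference, the MaxVit constructor can also be called directly with the renamed keywords: input_size now leads the signature and out_classes has become num_classes. The values below mirror the tiny configuration that maxvit_t wires up and are shown only as a sketch of the updated signature, not as an official usage recipe.

import torch
from torch import nn
from torchvision.models.maxvit import MaxVit

# Direct construction with the updated __init__ signature; equivalent to the
# configuration that maxvit_t passes through _maxvit.
model = MaxVit(
    input_size=(224, 224),
    input_channels=3,
    stem_channels=64,
    num_classes=1000,
    block_channels=[64, 128, 256, 512],
    block_layers=[2, 2, 5, 2],
    stochastic_depth_prob=0.2,
    squeeze_ratio=0.25,
    expansion_ratio=4.0,
    normalization_fn=nn.BatchNorm2d,
    activation_fn=nn.GELU,
    head_dim=32,
    mlp_ratio=2,
    mlp_dropout=0.0,
    attn_dropout=0.0,
    partition_size=7,
)
logits = model(torch.rand(1, 3, 224, 224))  # expected shape: (1, 1000)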
