@@ -1,12 +1,15 @@
 import math
-from typing import Any, Callable, List, OrderedDict, Sequence, Tuple
+from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Tuple
 
 import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn, Tensor
+from torchvision.models._api import register_model, WeightsEnum
+from torchvision.models._utils import _ovewrite_named_param
 from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation
 from torchvision.ops.stochastic_depth import StochasticDepth
+from torchvision.utils import _log_api_usage_once
 
 
 def get_relative_position_index(height: int, width: int) -> torch.Tensor:
@@ -20,20 +23,6 @@ def get_relative_position_index(height: int, width: int) -> torch.Tensor:
     return relative_coords.sum(-1)
 
 
-class GeluWrapper(nn.Module):
-    """
-    Gelu wrapper to make it compatible with `ConvNormActivation2D` which passed inplace=True
-    to the activation function construction.
-    """
-
-    def __init__(self, **kwargs) -> None:
-        super().__init__()
-        self._op = F.gelu
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self._op(x)
-
-
 class MBConv(nn.Module):
     def __init__(
         self,
@@ -65,20 +54,28 @@ def __init__(
         _layers = OrderedDict()
         _layers["pre_norm"] = normalization_fn(in_channels)
         _layers["conv_a"] = Conv2dNormActivation(
-            in_channels, mid_channels, 1, 1, 0, activation_layer=activation_fn, norm_layer=normalization_fn
+            in_channels,
+            mid_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            activation_layer=activation_fn,
+            norm_layer=normalization_fn,
+            inplace=None,
         )
         _layers["conv_b"] = Conv2dNormActivation(
             mid_channels,
             mid_channels,
-            3,
-            stride,
-            1,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
             activation_layer=activation_fn,
             norm_layer=normalization_fn,
             groups=mid_channels,
+            inplace=None,
         )
         _layers["squeeze_excitation"] = SqueezeExcitation(mid_channels, sqz_channels)
-        _layers["conv_c"] = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, bias=False)
+        _layers["conv_c"] = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, bias=True)
 
         self.layers = nn.Sequential(_layers)
 
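Note: the `inplace=None` arguments above are what let plain `nn.GELU` replace the removed `GeluWrapper`, since `Conv2dNormActivation` only forwards an `inplace=` keyword to the activation when the value is not `None`. A minimal sketch of that behaviour, assuming only torch and torchvision; the channel and spatial sizes are illustrative:

import torch
from torch import nn
from torchvision.ops.misc import Conv2dNormActivation

# With inplace=None the activation is built as activation_layer() rather than
# activation_layer(inplace=True), so nn.GELU (which has no `inplace` argument) works directly.
block = Conv2dNormActivation(
    32, 64, kernel_size=1, stride=1, padding=0,
    norm_layer=nn.BatchNorm2d, activation_layer=nn.GELU, inplace=None,
)
out = block(torch.randn(2, 32, 56, 56))  # -> shape (2, 64, 56, 56)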
@@ -116,14 +113,13 @@ def __init__(
         # initialize with truncated normal the bias
         self.positional_bias.data.normal_(mean=0, std=0.02)
 
-    def _get_relative_positional_bias(self) -> torch.Tensor:
+    def get_relative_positional_bias(self) -> torch.Tensor:
         bias_index = self.relative_position_index.view(-1)  # type: ignore
         relative_bias = self.positional_bias[bias_index].view(self.max_seq_len, self.max_seq_len, -1)  # type: ignore
         relative_bias = relative_bias.permute(2, 0, 1).contiguous()
         return relative_bias.unsqueeze(0)
 
     def forward(self, x: Tensor) -> Tensor:
-        # X, Y and stand for X-axis group dim, Y-axis group dim
         B, G, P, D = x.shape
         H, DH = self.n_heads, self.head_dim
 
@@ -135,9 +131,8 @@ def forward(self, x: Tensor) -> Tensor:
         v = v.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
 
         k = k * self.scale_factor
-        # X, Y and stand for X-axis group dim, Y-axis group dim
         dot_prod = torch.einsum("B G H I D, B G H J D -> B G H I J", q, k)
-        pos_bias = self._get_relative_positional_bias()
+        pos_bias = self.get_relative_positional_bias()
 
         dot_prod = F.softmax(dot_prod + pos_bias, dim=-1)
 
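For reference, the einsum above is a batched q·kᵀ, and the relative positional bias broadcasts over the batch and group dimensions. A small shape check with toy sizes (the numbers here are illustrative, not the model defaults):

import torch
import torch.nn.functional as F

B, G, H, P, D = 2, 4, 3, 49, 32      # batch, groups, heads, tokens per partition, head dim
q = torch.randn(B, G, H, P, D)
k = torch.randn(B, G, H, P, D)
pos_bias = torch.randn(1, H, P, P)   # shape produced by get_relative_positional_bias()

dot_prod = torch.einsum("B G H I D, B G H J D -> B G H I J", q, k)
assert torch.allclose(dot_prod, q @ k.transpose(-2, -1), atol=1e-4)

attn = F.softmax(dot_prod + pos_bias, dim=-1)  # bias broadcasts over B and G
print(attn.shape)                              # torch.Size([2, 4, 3, 49, 49])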
@@ -204,34 +199,6 @@ def forward(self, x: Tensor) -> Tensor:
         return x
 
 
-class MLP(nn.Module):
-    def __init__(
-        self,
-        in_dim: int,
-        hidden_dim: int,
-        activation_fn: Callable[..., nn.Module],
-        normalization_fn: Callable[..., nn.Module],
-        dropout: float,
-    ) -> None:
-        super().__init__()
-        self.in_dim = in_dim
-        self.hidden_dim = hidden_dim
-        self.activation_fn = activation_fn
-        self.normalization_fn = normalization_fn
-        self.dropout = dropout
-
-        self.layers = nn.Sequential(
-            self.normalization_fn(in_dim),
-            nn.Linear(in_dim, hidden_dim),
-            self.activation_fn(),
-            nn.Linear(hidden_dim, in_dim),
-            nn.Dropout(dropout),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        return x + self.layers(x)
-
-
 class PartitionAttentionLayer(nn.Module):
     def __init__(
         self,
@@ -282,16 +249,23 @@ def __init__(
             nn.Dropout(attn_dropout),
         )
 
-        self.mlp_layer = MLP(in_channels, in_channels * mlp_ratio, activation_fn, normalization_fn, mlp_dropout)
+        # pre-normalization similar to transformer layers
+        self.mlp_layer = nn.Sequential(
+            nn.LayerNorm(in_channels),
+            nn.Linear(in_channels, in_channels * mlp_ratio),
+            activation_fn(),
+            nn.Linear(in_channels * mlp_ratio, in_channels),
+            nn.Dropout(mlp_dropout),
+        )
 
         # layer scale factors
         self.attn_layer_scale = nn.parameter.Parameter(torch.ones(in_channels) * 1e-6)
         self.mlp_layer_scale = nn.parameter.Parameter(torch.ones(in_channels) * 1e-6)
 
     def forward(self, x: Tensor) -> Tensor:
         x = self.partition_op(x)
-        x = self.attn_layer(x) * self.attn_layer_scale
-        x = self.mlp_layer(x) * self.mlp_layer_scale
+        x = x + self.attn_layer(x) * self.attn_layer_scale
+        x = x + self.mlp_layer(x) * self.mlp_layer_scale
         x = self.departition_op(x)
         return x
 
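The forward pass now follows the standard LayerScale residual form x + f(x) * γ, with γ initialised to 1e-6 so each branch starts close to the identity. A standalone sketch of that pattern (the wrapper name and toy sizes are illustrative, not part of this file):

import torch
from torch import nn

class LayerScaleResidual(nn.Module):
    """Wraps a sub-layer as x + fn(x) * gamma with a small learnable gamma."""

    def __init__(self, fn: nn.Module, dim: int, init_value: float = 1e-6) -> None:
        super().__init__()
        self.fn = fn
        self.gamma = nn.Parameter(torch.ones(dim) * init_value)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.fn(x) * self.gamma

block = LayerScaleResidual(nn.Linear(64, 64), dim=64)
y = block(torch.randn(8, 49, 64))  # approximately x at initialisation, since gamma is tiny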
@@ -386,9 +360,8 @@ def __init__(
         p_stochastic: List[float],
     ) -> None:
         super().__init__()
-        assert (
-            len(p_stochastic) == n_layers
-        ), f"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}."
+        if not len(p_stochastic) == n_layers:
+            raise ValueError(f"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}.")
 
         self.layers = nn.ModuleList()
         # account for the first stride of the first layer
@@ -424,11 +397,12 @@ def forward(self, x: Tensor) -> Tensor:
 class MaxVit(nn.Module):
     def __init__(
         self,
+        # input size parameters
+        input_size: Tuple[int, int],
         # stem and task parameters
         input_channels: int,
         stem_channels: int,
-        input_size: Tuple[int, int],
-        out_classes: int,
+        num_classes: int,
         # block parameters
         block_channels: List[int],
         block_layers: List[int],
@@ -450,6 +424,7 @@ def __init__(
         partition_size: int,
     ) -> None:
         super().__init__()
+        _log_api_usage_once(self)
 
         # stem
         self.stem = nn.Sequential(
@@ -500,7 +475,7 @@ def __init__(
         self.classifier = nn.Sequential(
             nn.AdaptiveAvgPool2d(1),
             nn.Flatten(),
-            nn.Linear(block_channels[-1], out_classes, bias=False),
+            nn.Linear(block_channels[-1], num_classes, bias=False),
         )
 
     def forward(self, x: Tensor) -> Tensor:
@@ -511,85 +486,87 @@ def forward(self, x: Tensor) -> Tensor:
         return x
 
 
-def max_vit_T_224(num_classes: int) -> MaxVit:
-    return MaxVit(
-        input_channels=3,
-        stem_channels=64,
-        input_size=(224, 224),
-        out_classes=num_classes,
-        block_channels=[64, 128, 256, 512],
-        block_layers=[2, 2, 5, 2],
-        stochastic_depth_prob=0.2,
-        squeeze_ratio=0.25,
-        expansion_ratio=4.0,
-        normalization_fn=nn.BatchNorm2d,
-        activation_fn=GeluWrapper,
-        head_dim=32,
-        mlp_ratio=2,
-        mlp_dropout=0.0,
-        attn_dropout=0.0,
-        partition_size=7,
+def _maxvit(
+    # stem and task parameters
+    stem_channels: int,
+    num_classes: int,
+    # block parameters
+    block_channels: List[int],
+    block_layers: List[int],
+    stochastic_depth_prob: float,
+    # conv parameters
+    squeeze_ratio: float,
+    expansion_ratio: float,
+    # conv + transformer parameters
+    # normalization_fn is applied only to the conv layers
+    # activation_fn is applied both to conv and transformer layers
+    normalization_fn: Callable[..., nn.Module],
+    activation_fn: Callable[..., nn.Module],
+    # transformer parameters
+    head_dim: int,
+    mlp_ratio: int,
+    mlp_dropout: float,
+    attn_dropout: float,
+    # partitioning parameters
+    partition_size: int,
+    # Weights API
+    weights: Optional[WeightsEnum],
+    progress: bool,
+    # kwargs,
+    **kwargs,
+) -> MaxVit:
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+        assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
+        _ovewrite_named_param(kwargs, "input_size", weights.meta["min_size"][0])
+        _ovewrite_named_param(kwargs, "input_channels", weights.meta["input_channels"])
+
+    input_size = kwargs.pop("input_size", (224, 224))
+    input_channels = kwargs.pop("input_channels", 3)
+
+    model = MaxVit(
+        input_channels=input_channels,
+        stem_channels=stem_channels,
+        num_classes=num_classes,
+        block_channels=block_channels,
+        block_layers=block_layers,
+        stochastic_depth_prob=stochastic_depth_prob,
+        squeeze_ratio=squeeze_ratio,
+        expansion_ratio=expansion_ratio,
+        normalization_fn=normalization_fn,
+        activation_fn=activation_fn,
+        head_dim=head_dim,
+        mlp_ratio=mlp_ratio,
+        mlp_dropout=mlp_dropout,
+        attn_dropout=attn_dropout,
+        partition_size=partition_size,
+        input_size=input_size,
+        **kwargs,
     )
 
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
 
-def max_vit_S_224(num_classes: int) -> MaxVit:
-    return MaxVit(
-        input_channels=3,
-        stem_channels=64,
-        input_size=(224, 224),
-        out_classes=num_classes,
-        block_channels=[96, 192, 384, 768],
-        block_layers=[2, 2, 5, 2],
-        stochastic_depth_prob=0.3,
-        squeeze_ratio=0.25,
-        expansion_ratio=4.0,
-        normalization_fn=nn.BatchNorm2d,
-        activation_fn=GeluWrapper,
-        head_dim=32,
-        mlp_ratio=2,
-        mlp_dropout=0.0,
-        attn_dropout=0.0,
-        partition_size=7,
-    )
+    return model
 
 
-def max_vit_B_224(num_classes: int) -> MaxVit:
-    return MaxVit(
-        input_channels=3,
+@register_model(name="maxvit_t")
+def maxvit_t(*, weights: Optional[WeightsEnum] = None, progress: bool = True, **kwargs: Any) -> MaxVit:
+    return _maxvit(
         stem_channels=64,
-        input_size=(224, 224),
-        out_classes=num_classes,
-        block_channels=[96, 192, 384, 768],
-        block_layers=[2, 6, 14, 2],
-        stochastic_depth_prob=0.4,
-        squeeze_ratio=0.25,
-        expansion_ratio=4.0,
-        normalization_fn=nn.BatchNorm2d,
-        activation_fn=GeluWrapper,
-        head_dim=32,
-        mlp_ratio=2,
-        mlp_dropout=0.0,
-        attn_dropout=0.0,
-        partition_size=7,
-    )
-
-
-def max_vit_L_224(num_classes: int) -> MaxVit:
-    return MaxVit(
-        input_channels=3,
-        stem_channels=128,
-        input_size=(224, 224),
-        out_classes=num_classes,
-        block_channels=[128, 256, 512, 1024],
-        block_layers=[2, 6, 14, 2],
-        stochastic_depth_prob=0.6,
+        block_channels=[64, 128, 256, 512],
+        block_layers=[2, 2, 5, 2],
+        stochastic_depth_prob=0.2,
         squeeze_ratio=0.25,
         expansion_ratio=4.0,
         normalization_fn=nn.BatchNorm2d,
-        activation_fn=GeluWrapper,
+        activation_fn=nn.GELU,
         head_dim=32,
         mlp_ratio=2,
         mlp_dropout=0.0,
         attn_dropout=0.0,
         partition_size=7,
+        weights=weights,
+        progress=progress,
+        **kwargs,
     )
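A quick usage sketch of the new builder as it stands in this diff, assuming it runs in the same module (no import shown); no pretrained weights enum is defined yet, so weights stays at its default of None, and num_classes reaches the MaxVit constructor through **kwargs:

import torch

model = maxvit_t(num_classes=1000)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # input_size defaults to (224, 224)
print(logits.shape)                              # torch.Size([1, 1000])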