3
3
import torch
4
4
import torch .nn as nn
5
5
6
- __all__ = ['ShuffleNetV2' , 'shufflenetv2' ,
7
- 'shufflenetv2_x0_5' , 'shufflenetv2_x1_0' ,
8
- 'shufflenetv2_x1_5' , 'shufflenetv2_x2_0' ]
6
+ __all__ = ['ShuffleNetV2' , 'shufflenetv2_x0_5' , 'shufflenetv2_x1_0' , 'shufflenetv2_x1_5' , 'shufflenetv2_x2_0' ]
9
7
10
8
model_urls = {
11
9
'shufflenetv2_x0.5' :
@@ -85,16 +83,17 @@ def forward(self, x):
85
83
86
84
87
85
class ShuffleNetV2 (nn .Module ):
88
- def __init__ (self , num_classes = 1000 , width_mult = 1 ):
86
+ def __init__ (self , stages_repeats , stages_out_channels , num_classes = 1000 ):
89
87
super (ShuffleNetV2 , self ).__init__ ()
90
88
91
- try :
92
- self .stage_out_channels = self ._getStages (float (width_mult ))
93
- except KeyError :
94
- raise ValueError ('width_mult {} is not supported' .format (width_mult ))
89
+ if len (stages_repeats ) != 3 :
90
+ raise ValueError ('expected stages_repeats as list of 3 positive ints' )
91
+ if len (stages_out_channels ) != 5 :
92
+ raise ValueError ('expected stages_out_channels as list of 5 positive ints' )
93
+ self ._stage_out_channels = stages_out_channels
95
94
96
95
input_channels = 3
97
- output_channels = self .stage_out_channels [0 ]
96
+ output_channels = self ._stage_out_channels [0 ]
98
97
self .conv1 = nn .Sequential (
99
98
nn .Conv2d (input_channels , output_channels , 3 , 2 , 1 , bias = False ),
100
99
nn .BatchNorm2d (output_channels ),
@@ -105,16 +104,15 @@ def __init__(self, num_classes=1000, width_mult=1):
105
104
self .maxpool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 )
106
105
107
106
stage_names = ['stage{}' .format (i ) for i in [2 , 3 , 4 ]]
108
- stage_repeats = [4 , 8 , 4 ]
109
107
for name , repeats , output_channels in zip (
110
- stage_names , stage_repeats , self .stage_out_channels [1 :]):
108
+ stage_names , stages_repeats , self ._stage_out_channels [1 :]):
111
109
seq = [InvertedResidual (input_channels , output_channels , 2 )]
112
110
for i in range (repeats - 1 ):
113
111
seq .append (InvertedResidual (output_channels , output_channels , 1 ))
114
112
setattr (self , name , nn .Sequential (* seq ))
115
113
input_channels = output_channels
116
114
117
- output_channels = self .stage_out_channels [- 1 ]
115
+ output_channels = self ._stage_out_channels [- 1 ]
118
116
self .conv5 = nn .Sequential (
119
117
nn .Conv2d (input_channels , output_channels , 1 , 1 , 0 , bias = False ),
120
118
nn .BatchNorm2d (output_channels ),
@@ -135,24 +133,22 @@ def forward(self, x):
135
133
return x
136
134
137
135
@staticmethod
138
- def _getStages ( mult ):
136
+ def getPaperParams ( width_mult ):
139
137
stages = {
140
138
'0.5' : [24 , 48 , 96 , 192 , 1024 ],
141
139
'1.0' : [24 , 116 , 232 , 464 , 1024 ],
142
140
'1.5' : [24 , 176 , 352 , 704 , 1024 ],
143
141
'2.0' : [24 , 244 , 488 , 976 , 2048 ],
144
142
}
145
- return stages [str (mult )]
143
+ stage_repeats = [4 , 8 , 4 ]
144
+ return (stage_repeats , stages [width_mult ])
146
145
147
146
148
- def shufflenetv2 (pretrained = False , num_classes = 1000 , width_mult = 1 , ** kwargs ):
149
- model = ShuffleNetV2 (num_classes = num_classes , width_mult = width_mult )
147
+ def _shufflenetv2 (pretrained = False , num_classes = 1000 , width_mult = '1.0' , ** kwargs ):
148
+ model = ShuffleNetV2 (* ShuffleNetV2 . getPaperParams ( width_mult ), num_classes = num_classes )
150
149
151
150
if pretrained :
152
- # change width_mult to float
153
- if isinstance (width_mult , int ):
154
- width_mult = float (width_mult )
155
- model_type = ('_' .join ([ShuffleNetV2 .__name__ , 'x' + str (width_mult )]))
151
+ model_type = ('_' .join ([ShuffleNetV2 .__name__ , 'x' + width_mult ]))
156
152
try :
157
153
model_url = model_urls [model_type .lower ()]
158
154
except KeyError :
@@ -165,16 +161,16 @@ def shufflenetv2(pretrained=False, num_classes=1000, width_mult=1, **kwargs):
165
161
166
162
167
163
def shufflenetv2_x0_5 (pretrained = False , num_classes = 1000 , ** kwargs ):
168
- return shufflenetv2 (pretrained , num_classes , 0.5 )
164
+ return _shufflenetv2 (pretrained , num_classes , '0.5' )
169
165
170
166
171
167
def shufflenetv2_x1_0 (pretrained = False , num_classes = 1000 , ** kwargs ):
172
- return shufflenetv2 (pretrained , num_classes , 1 )
168
+ return _shufflenetv2 (pretrained , num_classes , '1.0' )
173
169
174
170
175
171
def shufflenetv2_x1_5 (pretrained = False , num_classes = 1000 , ** kwargs ):
176
- return shufflenetv2 (pretrained , num_classes , 1.5 )
172
+ return _shufflenetv2 (pretrained , num_classes , '1.5' )
177
173
178
174
179
175
def shufflenetv2_x2_0 (pretrained = False , num_classes = 1000 , ** kwargs ):
180
- return shufflenetv2 (pretrained , num_classes , 2 )
176
+ return _shufflenetv2 (pretrained , num_classes , '2.0' )
0 commit comments