@@ -84,13 +84,17 @@ def forward(self, x):
84
84
85
85
86
86
class ShuffleNetV2 (nn .Module ):
87
- def __init__ (self , stage_out_channels , num_classes = 1000 ):
87
+ def __init__ (self , stages_repeats , stages_out_channels , num_classes = 1000 ):
88
88
super (ShuffleNetV2 , self ).__init__ ()
89
89
90
- self .stage_out_channels = stage_out_channels
91
- input_channels = 3
92
- output_channels = self .stage_out_channels [0 ]
90
+ if len (stages_repeats ) != 3 :
91
+ raise ValueError ('expected stages_repeats as list of 3 positive ints' )
92
+ if len (stages_out_channels ) != 5 :
93
+ raise ValueError ('expected stages_out_channels as list of 5 positive ints' )
94
+ self ._stage_out_channels = stages_out_channels
93
95
96
+ input_channels = 3
97
+ output_channels = self ._stage_out_channels [0 ]
94
98
self .conv1 = nn .Sequential (
95
99
nn .Conv2d (input_channels , output_channels , 3 , 2 , 1 , bias = False ),
96
100
nn .BatchNorm2d (output_channels ),
@@ -101,16 +105,15 @@ def __init__(self, stage_out_channels, num_classes=1000):
101
105
self .maxpool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 )
102
106
103
107
stage_names = ['stage{}' .format (i ) for i in [2 , 3 , 4 ]]
104
- stage_repeats = [4 , 8 , 4 ]
105
108
for name , repeats , output_channels in zip (
106
- stage_names , stage_repeats , self .stage_out_channels [1 :]):
109
+ stage_names , stages_repeats , self ._stage_out_channels [1 :]):
107
110
seq = [InvertedResidual (input_channels , output_channels , 2 )]
108
111
for i in range (repeats - 1 ):
109
112
seq .append (InvertedResidual (output_channels , output_channels , 1 ))
110
113
setattr (self , name , nn .Sequential (* seq ))
111
114
input_channels = output_channels
112
115
113
- output_channels = self .stage_out_channels [- 1 ]
116
+ output_channels = self ._stage_out_channels [- 1 ]
114
117
self .conv5 = nn .Sequential (
115
118
nn .Conv2d (input_channels , output_channels , 1 , 1 , 0 , bias = False ),
116
119
nn .BatchNorm2d (output_channels ),
@@ -131,8 +134,8 @@ def forward(self, x):
131
134
return x
132
135
133
136
134
- def _shufflenetv2 (arch , pretrained , progress , stage_out_channels , ** kwargs ):
135
- model = ShuffleNetV2 (stage_out_channels = stage_out_channels , ** kwargs )
137
+ def _shufflenetv2 (arch , pretrained , progress , * args , ** kwargs ):
138
+ model = ShuffleNetV2 (* args , ** kwargs )
136
139
137
140
if pretrained :
138
141
model_url = model_urls [arch ]
@@ -146,16 +149,20 @@ def _shufflenetv2(arch, pretrained, progress, stage_out_channels, **kwargs):
146
149
147
150
148
151
def shufflenetv2_x0_5 (pretrained = False , progress = True , ** kwargs ):
149
- return _shufflenetv2 ('shufflenetv2_x0.5' , pretrained , progress , [24 , 48 , 96 , 192 , 1024 ], ** kwargs )
152
+ return _shufflenetv2 ('shufflenetv2_x0.5' , pretrained , progress ,
153
+ [4 , 8 , 4 ], [24 , 48 , 96 , 192 , 1024 ], ** kwargs )
150
154
151
155
152
156
def shufflenetv2_x1_0 (pretrained = False , progress = True , ** kwargs ):
153
- return _shufflenetv2 ('shufflenetv2_x1.0' , pretrained , progress , [24 , 116 , 232 , 464 , 1024 ], ** kwargs )
157
+ return _shufflenetv2 ('shufflenetv2_x1.0' , pretrained , progress ,
158
+ [4 , 8 , 4 ], [24 , 116 , 232 , 464 , 1024 ], ** kwargs )
154
159
155
160
156
161
def shufflenetv2_x1_5 (pretrained = False , progress = True , ** kwargs ):
157
- return _shufflenetv2 ('shufflenetv2_x1.5' , pretrained , progress , [24 , 176 , 352 , 704 , 1024 ], ** kwargs )
162
+ return _shufflenetv2 ('shufflenetv2_x1.5' , pretrained , progress ,
163
+ [4 , 8 , 4 ], [24 , 176 , 352 , 704 , 1024 ], ** kwargs )
158
164
159
165
160
166
def shufflenetv2_x2_0 (pretrained = False , progress = True , ** kwargs ):
161
- return _shufflenetv2 ('shufflenetv2_x2.0' , pretrained , progress , [24 , 244 , 488 , 976 , 2048 ], ** kwargs )
167
+ return _shufflenetv2 ('shufflenetv2_x2.0' , pretrained , progress ,
168
+ [4 , 8 , 4 ], [24 , 244 , 488 , 976 , 2048 ], ** kwargs )
0 commit comments