Skip to content

Commit d91e35f

Browse files
committed
Enhance ShufflenetV2
Class shufflenetv2 receives `stages_repeats` and `stages_out_channels` arguments.
1 parent 78ed423 commit d91e35f

File tree

1 file changed

+20
-24
lines changed

1 file changed

+20
-24
lines changed

torchvision/models/shufflenetv2.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
import torch
44
import torch.nn as nn
55

6-
__all__ = ['ShuffleNetV2', 'shufflenetv2',
7-
'shufflenetv2_x0_5', 'shufflenetv2_x1_0',
8-
'shufflenetv2_x1_5', 'shufflenetv2_x2_0']
6+
__all__ = ['ShuffleNetV2', 'shufflenetv2_x0_5', 'shufflenetv2_x1_0', 'shufflenetv2_x1_5', 'shufflenetv2_x2_0']
97

108
model_urls = {
119
'shufflenetv2_x0.5':
@@ -85,16 +83,17 @@ def forward(self, x):
8583

8684

8785
class ShuffleNetV2(nn.Module):
88-
def __init__(self, num_classes=1000, width_mult=1):
86+
def __init__(self, stages_repeats, stages_out_channels, num_classes=1000):
8987
super(ShuffleNetV2, self).__init__()
9088

91-
try:
92-
self.stage_out_channels = self._getStages(float(width_mult))
93-
except KeyError:
94-
raise ValueError('width_mult {} is not supported'.format(width_mult))
89+
if len(stages_repeats) != 3:
90+
raise ValueError('expected stages_repeats as list of 3 positive ints')
91+
if len(stages_out_channels) != 5:
92+
raise ValueError('expected stages_out_channels as list of 5 positive ints')
93+
self._stage_out_channels = stages_out_channels
9594

9695
input_channels = 3
97-
output_channels = self.stage_out_channels[0]
96+
output_channels = self._stage_out_channels[0]
9897
self.conv1 = nn.Sequential(
9998
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
10099
nn.BatchNorm2d(output_channels),
@@ -105,16 +104,15 @@ def __init__(self, num_classes=1000, width_mult=1):
105104
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
106105

107106
stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
108-
stage_repeats = [4, 8, 4]
109107
for name, repeats, output_channels in zip(
110-
stage_names, stage_repeats, self.stage_out_channels[1:]):
108+
stage_names, stages_repeats, self._stage_out_channels[1:]):
111109
seq = [InvertedResidual(input_channels, output_channels, 2)]
112110
for i in range(repeats - 1):
113111
seq.append(InvertedResidual(output_channels, output_channels, 1))
114112
setattr(self, name, nn.Sequential(*seq))
115113
input_channels = output_channels
116114

117-
output_channels = self.stage_out_channels[-1]
115+
output_channels = self._stage_out_channels[-1]
118116
self.conv5 = nn.Sequential(
119117
nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
120118
nn.BatchNorm2d(output_channels),
@@ -135,24 +133,22 @@ def forward(self, x):
135133
return x
136134

137135
@staticmethod
138-
def _getStages(mult):
136+
def getPaperParams(width_mult):
139137
stages = {
140138
'0.5': [24, 48, 96, 192, 1024],
141139
'1.0': [24, 116, 232, 464, 1024],
142140
'1.5': [24, 176, 352, 704, 1024],
143141
'2.0': [24, 244, 488, 976, 2048],
144142
}
145-
return stages[str(mult)]
143+
stage_repeats = [4, 8, 4]
144+
return (stage_repeats, stages[width_mult])
146145

147146

148-
def shufflenetv2(pretrained=False, num_classes=1000, width_mult=1, **kwargs):
149-
model = ShuffleNetV2(num_classes=num_classes, width_mult=width_mult)
147+
def _shufflenetv2(pretrained=False, num_classes=1000, width_mult='1.0', **kwargs):
148+
model = ShuffleNetV2(*ShuffleNetV2.getPaperParams(width_mult), num_classes=num_classes)
150149

151150
if pretrained:
152-
# change width_mult to float
153-
if isinstance(width_mult, int):
154-
width_mult = float(width_mult)
155-
model_type = ('_'.join([ShuffleNetV2.__name__, 'x' + str(width_mult)]))
151+
model_type = ('_'.join([ShuffleNetV2.__name__, 'x' + width_mult]))
156152
try:
157153
model_url = model_urls[model_type.lower()]
158154
except KeyError:
@@ -165,16 +161,16 @@ def shufflenetv2(pretrained=False, num_classes=1000, width_mult=1, **kwargs):
165161

166162

167163
def shufflenetv2_x0_5(pretrained=False, num_classes=1000, **kwargs):
168-
return shufflenetv2(pretrained, num_classes, 0.5)
164+
return _shufflenetv2(pretrained, num_classes, '0.5')
169165

170166

171167
def shufflenetv2_x1_0(pretrained=False, num_classes=1000, **kwargs):
172-
return shufflenetv2(pretrained, num_classes, 1)
168+
return _shufflenetv2(pretrained, num_classes, '1.0')
173169

174170

175171
def shufflenetv2_x1_5(pretrained=False, num_classes=1000, **kwargs):
176-
return shufflenetv2(pretrained, num_classes, 1.5)
172+
return _shufflenetv2(pretrained, num_classes, '1.5')
177173

178174

179175
def shufflenetv2_x2_0(pretrained=False, num_classes=1000, **kwargs):
180-
return shufflenetv2(pretrained, num_classes, 2)
176+
return _shufflenetv2(pretrained, num_classes, '2.0')

0 commit comments

Comments
 (0)