Skip to content

Commit 43ab2fe

Browse files
barrh authored and fmassa committed
Enhance ShufflenetV2 (#892)
* Enhance ShufflenetV2

  Class shufflenetv2 receives `stages_repeats` and `stages_out_channels` arguments.

* remove explicit num_classes argument from utility functions
1 parent dc3ac29 commit 43ab2fe

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

torchvision/models/shufflenetv2.py

Lines changed: 20 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -84,13 +84,17 @@ def forward(self, x):
8484

8585

8686
class ShuffleNetV2(nn.Module):
87-
def __init__(self, stage_out_channels, num_classes=1000):
87+
def __init__(self, stages_repeats, stages_out_channels, num_classes=1000):
8888
super(ShuffleNetV2, self).__init__()
8989

90-
self.stage_out_channels = stage_out_channels
91-
input_channels = 3
92-
output_channels = self.stage_out_channels[0]
90+
if len(stages_repeats) != 3:
91+
raise ValueError('expected stages_repeats as list of 3 positive ints')
92+
if len(stages_out_channels) != 5:
93+
raise ValueError('expected stages_out_channels as list of 5 positive ints')
94+
self._stage_out_channels = stages_out_channels
9395

96+
input_channels = 3
97+
output_channels = self._stage_out_channels[0]
9498
self.conv1 = nn.Sequential(
9599
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
96100
nn.BatchNorm2d(output_channels),
@@ -101,16 +105,15 @@ def __init__(self, stage_out_channels, num_classes=1000):
101105
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
102106

103107
stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
104-
stage_repeats = [4, 8, 4]
105108
for name, repeats, output_channels in zip(
106-
stage_names, stage_repeats, self.stage_out_channels[1:]):
109+
stage_names, stages_repeats, self._stage_out_channels[1:]):
107110
seq = [InvertedResidual(input_channels, output_channels, 2)]
108111
for i in range(repeats - 1):
109112
seq.append(InvertedResidual(output_channels, output_channels, 1))
110113
setattr(self, name, nn.Sequential(*seq))
111114
input_channels = output_channels
112115

113-
output_channels = self.stage_out_channels[-1]
116+
output_channels = self._stage_out_channels[-1]
114117
self.conv5 = nn.Sequential(
115118
nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
116119
nn.BatchNorm2d(output_channels),
@@ -131,8 +134,8 @@ def forward(self, x):
131134
return x
132135

133136

134-
def _shufflenetv2(arch, pretrained, progress, stage_out_channels, **kwargs):
135-
model = ShuffleNetV2(stage_out_channels=stage_out_channels, **kwargs)
137+
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
138+
model = ShuffleNetV2(*args, **kwargs)
136139

137140
if pretrained:
138141
model_url = model_urls[arch]
@@ -146,16 +149,20 @@ def _shufflenetv2(arch, pretrained, progress, stage_out_channels, **kwargs):
146149

147150

148151
def shufflenetv2_x0_5(pretrained=False, progress=True, **kwargs):
149-
return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, [24, 48, 96, 192, 1024], **kwargs)
152+
return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
153+
[4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
150154

151155

152156
def shufflenetv2_x1_0(pretrained=False, progress=True, **kwargs):
153-
return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, [24, 116, 232, 464, 1024], **kwargs)
157+
return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
158+
[4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
154159

155160

156161
def shufflenetv2_x1_5(pretrained=False, progress=True, **kwargs):
157-
return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, [24, 176, 352, 704, 1024], **kwargs)
162+
return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
163+
[4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
158164

159165

160166
def shufflenetv2_x2_0(pretrained=False, progress=True, **kwargs):
161-
return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, [24, 244, 488, 976, 2048], **kwargs)
167+
return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
168+
[4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)

0 commit comments

Comments (0)