Skip to content

Commit 0564df4

Browse files
ekagra-ranjanfmassa
authored andcommitted
Refactoring of ShuffleNetV2 (#889)
* Minor refactoring of ShuffleNetV2 Added progress flag following #875. Further the following refactoring was also done: 1) added `version` argument in shufflenetv2 method and removed the operations for converting the `width_mult` arg to float and string. 2) removed `num_classes` argument and **kwargs from functions except `ShuffleNetV2` * removed `version` arg * Update shufflenetv2.py * Removed the try except block * Update shufflenetv2.py * Changed version from float to str * Replace `width_mult` with `stages_out_channels` Removes the need of `_getStages` function.
1 parent 78ed423 commit 0564df4

File tree

1 file changed

+20
-39
lines changed

1 file changed

+20
-39
lines changed

torchvision/models/shufflenetv2.py

Lines changed: 20 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
import torch
44
import torch.nn as nn
5+
from .utils import load_state_dict_from_url
56

6-
__all__ = ['ShuffleNetV2', 'shufflenetv2',
7-
'shufflenetv2_x0_5', 'shufflenetv2_x1_0',
8-
'shufflenetv2_x1_5', 'shufflenetv2_x2_0']
7+
__all__ = ['ShuffleNetV2', 'shufflenetv2_x0_5', 'shufflenetv2_x1_0', 'shufflenetv2_x1_5', 'shufflenetv2_x2_0']
98

109
model_urls = {
1110
'shufflenetv2_x0.5':
@@ -85,16 +84,13 @@ def forward(self, x):
8584

8685

8786
class ShuffleNetV2(nn.Module):
88-
def __init__(self, num_classes=1000, width_mult=1):
87+
def __init__(self, stage_out_channels, num_classes=1000):
8988
super(ShuffleNetV2, self).__init__()
9089

91-
try:
92-
self.stage_out_channels = self._getStages(float(width_mult))
93-
except KeyError:
94-
raise ValueError('width_mult {} is not supported'.format(width_mult))
95-
90+
self.stage_out_channels = stage_out_channels
9691
input_channels = 3
9792
output_channels = self.stage_out_channels[0]
93+
9894
self.conv1 = nn.Sequential(
9995
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
10096
nn.BatchNorm2d(output_channels),
@@ -134,47 +130,32 @@ def forward(self, x):
134130
x = self.fc(x)
135131
return x
136132

137-
@staticmethod
138-
def _getStages(mult):
139-
stages = {
140-
'0.5': [24, 48, 96, 192, 1024],
141-
'1.0': [24, 116, 232, 464, 1024],
142-
'1.5': [24, 176, 352, 704, 1024],
143-
'2.0': [24, 244, 488, 976, 2048],
144-
}
145-
return stages[str(mult)]
146133

147-
148-
def shufflenetv2(pretrained=False, num_classes=1000, width_mult=1, **kwargs):
149-
model = ShuffleNetV2(num_classes=num_classes, width_mult=width_mult)
134+
def _shufflenetv2(arch, pretrained, progress, stage_out_channels, **kwargs):
135+
model = ShuffleNetV2(stage_out_channels=stage_out_channels, **kwargs)
150136

151137
if pretrained:
152-
# change width_mult to float
153-
if isinstance(width_mult, int):
154-
width_mult = float(width_mult)
155-
model_type = ('_'.join([ShuffleNetV2.__name__, 'x' + str(width_mult)]))
156-
try:
157-
model_url = model_urls[model_type.lower()]
158-
except KeyError:
159-
raise ValueError('model {} is not support'.format(model_type))
138+
model_url = model_urls[arch]
160139
if model_url is None:
161-
raise NotImplementedError('pretrained {} is not supported'.format(model_type))
162-
model.load_state_dict(torch.utils.model_zoo.load_url(model_url))
140+
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
141+
else:
142+
state_dict = load_state_dict_from_url(model_urls, progress=progress)
143+
model.load_state_dict(state_dict)
163144

164145
return model
165146

166147

167-
def shufflenetv2_x0_5(pretrained=False, num_classes=1000, **kwargs):
168-
return shufflenetv2(pretrained, num_classes, 0.5)
148+
def shufflenetv2_x0_5(pretrained=False, progress=True, **kwargs):
149+
return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, [24, 48, 96, 192, 1024], **kwargs)
169150

170151

171-
def shufflenetv2_x1_0(pretrained=False, num_classes=1000, **kwargs):
172-
return shufflenetv2(pretrained, num_classes, 1)
152+
def shufflenetv2_x1_0(pretrained=False, progress=True, **kwargs):
153+
return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, [24, 116, 232, 464, 1024], **kwargs)
173154

174155

175-
def shufflenetv2_x1_5(pretrained=False, num_classes=1000, **kwargs):
176-
return shufflenetv2(pretrained, num_classes, 1.5)
156+
def shufflenetv2_x1_5(pretrained=False, progress=True, **kwargs):
157+
return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, [24, 176, 352, 704, 1024], **kwargs)
177158

178159

179-
def shufflenetv2_x2_0(pretrained=False, num_classes=1000, **kwargs):
180-
return shufflenetv2(pretrained, num_classes, 2)
160+
def shufflenetv2_x2_0(pretrained=False, progress=True, **kwargs):
161+
return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, [24, 244, 488, 976, 2048], **kwargs)

0 commit comments

Comments
 (0)