|
2 | 2 |
|
3 | 3 | import torch
|
4 | 4 | import torch.nn as nn
|
| 5 | +from .utils import load_state_dict_from_url |
5 | 6 |
|
6 |
| -__all__ = ['ShuffleNetV2', 'shufflenetv2', |
7 |
| - 'shufflenetv2_x0_5', 'shufflenetv2_x1_0', |
8 |
| - 'shufflenetv2_x1_5', 'shufflenetv2_x2_0'] |
| 7 | +__all__ = ['ShuffleNetV2', 'shufflenetv2_x0_5', 'shufflenetv2_x1_0', 'shufflenetv2_x1_5', 'shufflenetv2_x2_0'] |
9 | 8 |
|
10 | 9 | model_urls = {
|
11 | 10 | 'shufflenetv2_x0.5':
|
@@ -85,16 +84,13 @@ def forward(self, x):
|
85 | 84 |
|
86 | 85 |
|
87 | 86 | class ShuffleNetV2(nn.Module):
|
88 |
| - def __init__(self, num_classes=1000, width_mult=1): |
| 87 | + def __init__(self, stage_out_channels, num_classes=1000): |
89 | 88 | super(ShuffleNetV2, self).__init__()
|
90 | 89 |
|
91 |
| - try: |
92 |
| - self.stage_out_channels = self._getStages(float(width_mult)) |
93 |
| - except KeyError: |
94 |
| - raise ValueError('width_mult {} is not supported'.format(width_mult)) |
95 |
| - |
| 90 | + self.stage_out_channels = stage_out_channels |
96 | 91 | input_channels = 3
|
97 | 92 | output_channels = self.stage_out_channels[0]
|
| 93 | + |
98 | 94 | self.conv1 = nn.Sequential(
|
99 | 95 | nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
|
100 | 96 | nn.BatchNorm2d(output_channels),
|
@@ -134,47 +130,32 @@ def forward(self, x):
|
134 | 130 | x = self.fc(x)
|
135 | 131 | return x
|
136 | 132 |
|
137 |
| - @staticmethod |
138 |
| - def _getStages(mult): |
139 |
| - stages = { |
140 |
| - '0.5': [24, 48, 96, 192, 1024], |
141 |
| - '1.0': [24, 116, 232, 464, 1024], |
142 |
| - '1.5': [24, 176, 352, 704, 1024], |
143 |
| - '2.0': [24, 244, 488, 976, 2048], |
144 |
| - } |
145 |
| - return stages[str(mult)] |
146 | 133 |
|
147 |
| - |
148 |
| -def shufflenetv2(pretrained=False, num_classes=1000, width_mult=1, **kwargs): |
149 |
| - model = ShuffleNetV2(num_classes=num_classes, width_mult=width_mult) |
def _shufflenetv2(arch, pretrained, progress, stage_out_channels, **kwargs):
    """Build a ShuffleNetV2 and optionally load its pretrained weights.

    Args:
        arch: Key into ``model_urls`` naming the variant
            (e.g. ``'shufflenetv2_x1.0'``).
        pretrained: If truthy, fetch the state dict registered for ``arch``
            and load it into the model.
        progress: Forwarded as ``progress=`` to ``load_state_dict_from_url``.
        stage_out_channels: Per-stage output channel counts handed to
            ``ShuffleNetV2``.
        **kwargs: Extra keyword arguments forwarded to ``ShuffleNetV2``
            (e.g. ``num_classes``).

    Returns:
        The constructed ``ShuffleNetV2`` instance.

    Raises:
        KeyError: If ``pretrained`` is set and ``arch`` is not a key of
            ``model_urls``.
        NotImplementedError: If ``pretrained`` is set and the entry for
            ``arch`` in ``model_urls`` is ``None``.
    """
    model = ShuffleNetV2(stage_out_channels=stage_out_channels, **kwargs)

    if pretrained:
        model_url = model_urls[arch]
        if model_url is None:
            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
        # Pass the single resolved URL, not the whole ``model_urls`` mapping:
        # the loader expects one URL string.
        state_dict = load_state_dict_from_url(model_url, progress=progress)
        model.load_state_dict(state_dict)

    return model
|
165 | 146 |
|
166 | 147 |
|
167 |
| -def shufflenetv2_x0_5(pretrained=False, num_classes=1000, **kwargs): |
168 |
| - return shufflenetv2(pretrained, num_classes, 0.5) |
def shufflenetv2_x0_5(pretrained=False, progress=True, **kwargs):
    """Construct the ShuffleNetV2 variant registered as 'shufflenetv2_x0.5'.

    Args:
        pretrained: Forwarded to ``_shufflenetv2``; when truthy the weights
            registered under this variant's key are loaded.
        progress: Forwarded to ``_shufflenetv2`` (presumably toggles a
            download progress display - confirm against the loader).
        **kwargs: Forwarded unchanged to ``_shufflenetv2``.

    Returns:
        Whatever ``_shufflenetv2`` returns for this configuration.
    """
    arch_key = 'shufflenetv2_x0.5'
    channels = [24, 48, 96, 192, 1024]
    return _shufflenetv2(arch_key, pretrained, progress, channels, **kwargs)
169 | 150 |
|
170 | 151 |
|
171 |
| -def shufflenetv2_x1_0(pretrained=False, num_classes=1000, **kwargs): |
172 |
| - return shufflenetv2(pretrained, num_classes, 1) |
def shufflenetv2_x1_0(pretrained=False, progress=True, **kwargs):
    """Construct the ShuffleNetV2 variant registered as 'shufflenetv2_x1.0'.

    Args:
        pretrained: Forwarded to ``_shufflenetv2``; when truthy the weights
            registered under this variant's key are loaded.
        progress: Forwarded to ``_shufflenetv2`` (presumably toggles a
            download progress display - confirm against the loader).
        **kwargs: Forwarded unchanged to ``_shufflenetv2``.

    Returns:
        Whatever ``_shufflenetv2`` returns for this configuration.
    """
    arch_key = 'shufflenetv2_x1.0'
    channels = [24, 116, 232, 464, 1024]
    return _shufflenetv2(arch_key, pretrained, progress, channels, **kwargs)
173 | 154 |
|
174 | 155 |
|
175 |
| -def shufflenetv2_x1_5(pretrained=False, num_classes=1000, **kwargs): |
176 |
| - return shufflenetv2(pretrained, num_classes, 1.5) |
def shufflenetv2_x1_5(pretrained=False, progress=True, **kwargs):
    """Construct the ShuffleNetV2 variant registered as 'shufflenetv2_x1.5'.

    Args:
        pretrained: Forwarded to ``_shufflenetv2``; when truthy the weights
            registered under this variant's key are loaded.
        progress: Forwarded to ``_shufflenetv2`` (presumably toggles a
            download progress display - confirm against the loader).
        **kwargs: Forwarded unchanged to ``_shufflenetv2``.

    Returns:
        Whatever ``_shufflenetv2`` returns for this configuration.
    """
    arch_key = 'shufflenetv2_x1.5'
    channels = [24, 176, 352, 704, 1024]
    return _shufflenetv2(arch_key, pretrained, progress, channels, **kwargs)
177 | 158 |
|
178 | 159 |
|
179 |
| -def shufflenetv2_x2_0(pretrained=False, num_classes=1000, **kwargs): |
180 |
| - return shufflenetv2(pretrained, num_classes, 2) |
def shufflenetv2_x2_0(pretrained=False, progress=True, **kwargs):
    """Construct the ShuffleNetV2 variant registered as 'shufflenetv2_x2.0'.

    Args:
        pretrained: Forwarded to ``_shufflenetv2``; when truthy the weights
            registered under this variant's key are loaded.
        progress: Forwarded to ``_shufflenetv2`` (presumably toggles a
            download progress display - confirm against the loader).
        **kwargs: Forwarded unchanged to ``_shufflenetv2``.

    Returns:
        Whatever ``_shufflenetv2`` returns for this configuration.
    """
    arch_key = 'shufflenetv2_x2.0'
    channels = [24, 244, 488, 976, 2048]
    return _shufflenetv2(arch_key, pretrained, progress, channels, **kwargs)
0 commit comments