|
10 | 10 | from .._api import WeightsEnum, Weights
|
11 | 11 | from .._meta import _IMAGENET_CATEGORIES
|
12 | 12 | from .._utils import handle_legacy_interface, _ovewrite_named_param
|
13 |
| -from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights |
| 13 | +from ..shufflenetv2 import ( |
| 14 | + ShuffleNet_V2_X0_5_Weights, |
| 15 | + ShuffleNet_V2_X1_0_Weights, |
| 16 | + ShuffleNet_V2_X1_5_Weights, |
| 17 | + ShuffleNet_V2_X2_0_Weights, |
| 18 | +) |
14 | 19 | from .utils import _fuse_modules, _replace_relu, quantize_model
|
15 | 20 |
|
16 | 21 |
|
17 | 22 | __all__ = [
|
18 | 23 | "QuantizableShuffleNetV2",
|
19 | 24 | "ShuffleNet_V2_X0_5_QuantizedWeights",
|
20 | 25 | "ShuffleNet_V2_X1_0_QuantizedWeights",
|
| 26 | + "ShuffleNet_V2_X1_5_QuantizedWeights", |
| 27 | + "ShuffleNet_V2_X2_0_QuantizedWeights", |
21 | 28 | "shufflenet_v2_x0_5",
|
22 | 29 | "shufflenet_v2_x1_0",
|
| 30 | + "shufflenet_v2_x1_5", |
| 31 | + "shufflenet_v2_x2_0", |
23 | 32 | ]
|
24 | 33 |
|
25 | 34 |
|
@@ -143,6 +152,42 @@ class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
|
143 | 152 | DEFAULT = IMAGENET1K_FBGEMM_V1
|
144 | 153 |
|
145 | 154 |
|
| 155 | +class ShuffleNet_V2_X1_5_QuantizedWeights(WeightsEnum): |
| 156 | + IMAGENET1K_FBGEMM_V1 = Weights( |
| 157 | + url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_5_fbgemm-d7401f05.pth", |
| 158 | + transforms=partial(ImageClassification, crop_size=224, resize_size=232), |
| 159 | + meta={ |
| 160 | + **_COMMON_META, |
| 161 | + "recipe": "https://github.com/pytorch/vision/pull/5906", |
| 162 | + "num_params": 3503624, |
| 163 | + "unquantized": ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1, |
| 164 | + "metrics": { |
| 165 | + "acc@1": 72.052, |
| 166 | + "acc@5": 90.700, |
| 167 | + }, |
| 168 | + }, |
| 169 | + ) |
| 170 | + DEFAULT = IMAGENET1K_FBGEMM_V1 |
| 171 | + |
| 172 | + |
| 173 | +class ShuffleNet_V2_X2_0_QuantizedWeights(WeightsEnum): |
| 174 | + IMAGENET1K_FBGEMM_V1 = Weights( |
| 175 | + url="https://download.pytorch.org/models/quantized/shufflenetv2_x2_0_fbgemm-5cac526c.pth", |
| 176 | + transforms=partial(ImageClassification, crop_size=224, resize_size=232), |
| 177 | + meta={ |
| 178 | + **_COMMON_META, |
| 179 | + "recipe": "https://github.com/pytorch/vision/pull/5906", |
| 180 | + "num_params": 7393996, |
| 181 | + "unquantized": ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1, |
| 182 | + "metrics": { |
| 183 | + "acc@1": 75.354, |
| 184 | + "acc@5": 92.488, |
| 185 | + }, |
| 186 | + }, |
| 187 | + ) |
| 188 | + DEFAULT = IMAGENET1K_FBGEMM_V1 |
| 189 | + |
| 190 | + |
146 | 191 | @handle_legacy_interface(
|
147 | 192 | weights=(
|
148 | 193 | "pretrained",
|
@@ -205,3 +250,51 @@ def shufflenet_v2_x1_0(
|
205 | 250 | return _shufflenetv2(
|
206 | 251 | [4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
|
207 | 252 | )
|
| 253 | + |
| 254 | + |
| 255 | +def shufflenet_v2_x1_5( |
| 256 | + *, |
| 257 | + weights: Optional[Union[ShuffleNet_V2_X1_5_QuantizedWeights, ShuffleNet_V2_X1_5_Weights]] = None, |
| 258 | + progress: bool = True, |
| 259 | + quantize: bool = False, |
| 260 | + **kwargs: Any, |
| 261 | +) -> QuantizableShuffleNetV2: |
| 262 | + """ |
| 263 | + Constructs a ShuffleNetV2 with 1.5x output channels, as described in |
| 264 | + `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" |
| 265 | + <https://arxiv.org/abs/1807.11164>`_. |
| 266 | +
|
| 267 | + Args: |
| 268 | + weights (ShuffleNet_V2_X1_5_QuantizedWeights or ShuffleNet_V2_X1_5_Weights, optional): The pretrained |
| 269 | + weights for the model |
| 270 | + progress (bool): If True, displays a progress bar of the download to stderr |
| 271 | + quantize (bool): If True, return a quantized version of the model |
| 272 | + """ |
| 273 | + weights = (ShuffleNet_V2_X1_5_QuantizedWeights if quantize else ShuffleNet_V2_X1_5_Weights).verify(weights) |
| 274 | + return _shufflenetv2( |
| 275 | + [4, 8, 4], [24, 176, 352, 704, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs |
| 276 | + ) |
| 277 | + |
| 278 | + |
| 279 | +def shufflenet_v2_x2_0( |
| 280 | + *, |
| 281 | + weights: Optional[Union[ShuffleNet_V2_X2_0_QuantizedWeights, ShuffleNet_V2_X2_0_Weights]] = None, |
| 282 | + progress: bool = True, |
| 283 | + quantize: bool = False, |
| 284 | + **kwargs: Any, |
| 285 | +) -> QuantizableShuffleNetV2: |
| 286 | + """ |
| 287 | + Constructs a ShuffleNetV2 with 2.0x output channels, as described in |
| 288 | + `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" |
| 289 | + <https://arxiv.org/abs/1807.11164>`_. |
| 290 | +
|
| 291 | + Args: |
| 292 | + weights (ShuffleNet_V2_X2_0_QuantizedWeights or ShuffleNet_V2_X2_0_Weights, optional): The pretrained |
| 293 | + weights for the model |
| 294 | + progress (bool): If True, displays a progress bar of the download to stderr |
| 295 | + quantize (bool): If True, return a quantized version of the model |
| 296 | + """ |
| 297 | + weights = (ShuffleNet_V2_X2_0_QuantizedWeights if quantize else ShuffleNet_V2_X2_0_Weights).verify(weights) |
| 298 | + return _shufflenetv2( |
| 299 | + [4, 8, 4], [24, 244, 488, 976, 2048], weights=weights, progress=progress, quantize=quantize, **kwargs |
| 300 | + ) |
0 commit comments