From 22b9b12b7cb1c8fee7b6b07c8cf1f6af78e36666 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 15 Dec 2020 22:39:34 +0000 Subject: [PATCH 01/23] partial implementation network architecture --- torchvision/models/mobilenetv3.py | 117 ++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 torchvision/models/mobilenetv3.py diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py new file mode 100644 index 00000000000..a9b1a97637d --- /dev/null +++ b/torchvision/models/mobilenetv3.py @@ -0,0 +1,117 @@ +from torch import nn, Tensor +from torch.nn import functional as F +from typing import Optional + + +def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class _InplaceActivation(nn.Module): + + def __init__(self, inplace: bool = False): + super().__init__() + self.inplace = inplace + + def extra_repr(self) -> str: + return 'inplace=True' if self.inplace else '' + + +def hard_sigmoid(x: Tensor, inplace: bool = False) -> Tensor: + return F.relu6(x + 3.0, inplace=inplace) / 6.0 + + +class HardSigmoid(_InplaceActivation): + + def forward(self, input: Tensor) -> Tensor: + return hard_sigmoid(input, inplace=self.inplace) + + +def hard_swish(x: Tensor, inplace: bool = False) -> Tensor: + return x * hard_swish(x, inplace=inplace) + + +class HardSwish(_InplaceActivation): + + def forward(self, input: Tensor) -> Tensor: + return hard_swish(input, inplace=self.inplace) + + +class SqueezeExcitation(nn.Module): + + def __init__(self, input_channels: int, squeeze_factor: int): + super().__init__() + squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) + self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1) + self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1) + + def forward(self, input: Tensor) -> Tensor: + scale = F.adaptive_avg_pool2d(input, 1) + scale = self.fc1(scale) + scale = F.relu(scale, inplace=True) + scale = self.fc2(scale) + scale = hard_sigmoid(scale, inplace=True) + return scale * input + + +class InvertedResidualBottleneck(nn.Module): + + def __init__(self, input_channels: int, kernel: int, expanded_channels: int, output_channels: int, + se_block: bool, activation: str, stride: int): + super().__init__() + self.shortcut = stride == 1 and input_channels == output_channels + + self.block = nn.Sequential() + if expanded_channels != input_channels: + self.block.add_module("expand_conv", nn.Conv2d(input_channels, expanded_channels, 1, bias=False)) + self._add_bn_act("expand", expanded_channels, activation) + + self.block.add_module("depthwise_conv", nn.Conv2d(expanded_channels, expanded_channels, kernel, stride=stride, + padding=(kernel - 1) // 2, groups=expanded_channels, + bias=False)) + self._add_bn_act("depthwise", expanded_channels, activation) + + if se_block: + self.block.add_module("squeeze_excitation", SqueezeExcitation(expanded_channels, 4)) + + self.block.add_module("project_conv", 
nn.Conv2d(expanded_channels, output_channels, 1, bias=False)) + self._add_bn_act("project", expanded_channels, None) + + def _add_bn_act(self, block_name: str, channels: int, activation: Optional[str]): + self.block.add_module("{}_bn".format(block_name), nn.BatchNorm2d(channels, momentum=0.01, eps=0.001)) + if activation == "RE": + self.block.add_module("{}_act".format(block_name), nn.ReLU(inplace=True)) + elif activation == "HS": + self.block.add_module("{}_act".format(block_name), HardSwish(inplace=True)) + + def forward(self, input: Tensor) -> Tensor: + result = self.block(input) + if self.shortcut: + result += input + return result + + +class MobileNetV3(nn.Module): + + def __init__(self, + + num_classes: int = 1000): + + + #TODO: initialize weights From 0cff6fb8c76fef100a7c6a1e1e6773b22a51ae13 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Dec 2020 01:24:16 +0000 Subject: [PATCH 02/23] Simplify implementation and adding blocks. --- torchvision/models/mobilenetv3.py | 141 +++++++++++++++++++++++------- 1 file changed, 107 insertions(+), 34 deletions(-) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index a9b1a97637d..9ad882496d7 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -1,9 +1,10 @@ +from functools import partial from torch import nn, Tensor from torch.nn import functional as F -from typing import Optional +from typing import Callable, List, Optional -def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: +def _make_divisible(v: float, divisor: int = 8, min_value: Optional[int] = None) -> int: """ This function is taken from the original tf repo. It ensures that all layers have a channel number that is divisible by 8 @@ -55,9 +56,9 @@ def forward(self, input: Tensor) -> Tensor: class SqueezeExcitation(nn.Module): - def __init__(self, input_channels: int, squeeze_factor: int): + def __init__(self, input_channels: int, squeeze_factor: int = 4): super().__init__() - squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) + squeeze_channels = _make_divisible(input_channels // squeeze_factor) self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1) self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1) @@ -70,48 +71,120 @@ def forward(self, input: Tensor) -> Tensor: return scale * input -class InvertedResidualBottleneck(nn.Module): +class InvertedResidual(nn.Module): def __init__(self, input_channels: int, kernel: int, expanded_channels: int, output_channels: int, - se_block: bool, activation: str, stride: int): + use_se: bool, use_hs: bool, stride: int, norm_layer: Callable[..., nn.Module]): super().__init__() - self.shortcut = stride == 1 and input_channels == output_channels + assert stride in [1, 2] - self.block = nn.Sequential() - if expanded_channels != input_channels: - self.block.add_module("expand_conv", nn.Conv2d(input_channels, expanded_channels, 1, bias=False)) - self._add_bn_act("expand", expanded_channels, activation) - - self.block.add_module("depthwise_conv", nn.Conv2d(expanded_channels, expanded_channels, kernel, stride=stride, - padding=(kernel - 1) // 2, groups=expanded_channels, - bias=False)) - self._add_bn_act("depthwise", expanded_channels, activation) - - if se_block: - self.block.add_module("squeeze_excitation", SqueezeExcitation(expanded_channels, 4)) - - self.block.add_module("project_conv", nn.Conv2d(expanded_channels, output_channels, 1, bias=False)) - self._add_bn_act("project", expanded_channels, None) + 
self.use_res_connect = stride == 1 and input_channels == output_channels - def _add_bn_act(self, block_name: str, channels: int, activation: Optional[str]): - self.block.add_module("{}_bn".format(block_name), nn.BatchNorm2d(channels, momentum=0.01, eps=0.001)) - if activation == "RE": - self.block.add_module("{}_act".format(block_name), nn.ReLU(inplace=True)) - elif activation == "HS": - self.block.add_module("{}_act".format(block_name), HardSwish(inplace=True)) + layers: List[nn.Module] = [] + # expand + if expanded_channels != input_channels: + layers.extend([ + nn.Conv2d(input_channels, expanded_channels, 1, bias=False), + norm_layer(expanded_channels), + HardSwish(inplace=True) if use_hs else nn.ReLU(inplace=True), + ]) + + # depthwise + layers.extend([ + nn.Conv2d(expanded_channels, expanded_channels, kernel, stride=stride, padding=(kernel - 1) // 2, + groups=expanded_channels, bias=False), + norm_layer(expanded_channels), + HardSwish(inplace=True) if use_hs else nn.ReLU(inplace=True), + ]) + if use_se: + layers.append(SqueezeExcitation(expanded_channels)) + + # project + layers.extend([ + nn.Conv2d(expanded_channels, output_channels, 1, bias=False), + norm_layer(expanded_channels), + ]) + + self.block = nn.Sequential(*layers) def forward(self, input: Tensor) -> Tensor: result = self.block(input) - if self.shortcut: + if self.use_res_connect: result += input return result class MobileNetV3(nn.Module): - def __init__(self, - - num_classes: int = 1000): - + def __init__( + self, + inverted_residual_setting: List[List[int]], + last_channel: int, + num_classes: int = 1000, + blocks: Optional[List[Callable[..., nn.Module]]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super().__init__() - #TODO: initialize weights + if blocks is None: + blocks = [SqueezeExcitation, InvertedResidual] + se_layer, bottleneck_layer = blocks + + if norm_layer is None: + norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) + + layers: List[nn.Module] = [ + + ] + + + + pass + # TODO: initialize weights + + +def mobilenetv3(mode: str = "large", width_mult: float = 1.0): + if mode == "large": + inverted_residual_setting = [ + # in, kernel, exp, out, use_se, use_hs, stride + [16, 3, 16, 16, 0, 0, 1], + [16, 3, 64, 24, 0, 0, 2], + [24, 3, 72, 24, 0, 0, 1], + [24, 5, 72, 40, 1, 0, 2], + [40, 5, 120, 40, 1, 0, 1], + [40, 5, 120, 40, 1, 0, 1], + [40, 3, 240, 80, 0, 1, 2], + [80, 3, 200, 80, 0, 1, 1], + [80, 3, 184, 80, 0, 1, 1], + [80, 3, 184, 80, 0, 1, 1], + [80, 3, 480, 112, 1, 1, 1], + [112, 3, 672, 112, 1, 1, 1], + [112, 5, 672, 160, 1, 1, 2], + [160, 5, 960, 160, 1, 1, 1], + [160, 5, 960, 160, 1, 1, 1], + ] + last_channel = 1280 + else: + inverted_residual_setting = [ + # in, kernel, exp, out, use_se, use_hs, stride + [16, 3, 16, 16, 1, 0, 2], + [16, 3, 72, 24, 0, 0, 2], + [24, 3, 88, 24, 0, 0, 1], + [24, 5, 96, 40, 1, 1, 2], + [40, 5, 240, 40, 1, 1, 1], + [40, 5, 240, 40, 1, 1, 1], + [40, 5, 120, 48, 1, 1, 1], + [48, 5, 144, 48, 1, 1, 1], + [48, 5, 288, 96, 1, 1, 2], + [96, 5, 576, 96, 1, 1, 1], + [96, 5, 576, 96, 1, 1, 1], + ] + last_channel = 1024 + + # apply multipler on: in, exp, out columns + for row in inverted_residual_setting: + for id in (0, 2, 3): + row[id] = _make_divisible(row[id] * width_mult) + last_channel = _make_divisible(last_channel * width_mult) + + return MobileNetV3(inverted_residual_setting, last_channel) From fd54fdfda49d6627401a79f9f6c32c62a5a6f9fe Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Dec 2020 12:22:26 +0000 Subject: 
[PATCH 03/23] Refactoring the code to make it more readable. --- torchvision/models/mobilenetv3.py | 107 +++++++++++++++--------------- 1 file changed, 55 insertions(+), 52 deletions(-) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 9ad882496d7..3382a057d36 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -71,38 +71,49 @@ def forward(self, input: Tensor) -> Tensor: return scale * input +class InvertedResidualConfig: + def __init__(self, input_channels: int, kernel: int, expanded_channels: int, output_channels: int, use_se: bool, + activation: str, stride: int, width_mult: float): + self.input_channels = _make_divisible(input_channels * width_mult) + self.kernel = kernel + self.expanded_channels = _make_divisible(expanded_channels * width_mult) + self.output_channels = _make_divisible(output_channels * width_mult) + self.use_se = use_se + self.use_hs = activation == "HS" + self.stride = stride + + class InvertedResidual(nn.Module): - def __init__(self, input_channels: int, kernel: int, expanded_channels: int, output_channels: int, - use_se: bool, use_hs: bool, stride: int, norm_layer: Callable[..., nn.Module]): + def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module]): super().__init__() - assert stride in [1, 2] + assert cnf.stride in [1, 2] - self.use_res_connect = stride == 1 and input_channels == output_channels + self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.output_channels layers: List[nn.Module] = [] # expand - if expanded_channels != input_channels: + if cnf.expanded_channels != cnf.input_channels: layers.extend([ - nn.Conv2d(input_channels, expanded_channels, 1, bias=False), - norm_layer(expanded_channels), - HardSwish(inplace=True) if use_hs else nn.ReLU(inplace=True), + nn.Conv2d(cnf.input_channels, cnf.expanded_channels, 1, bias=False), + norm_layer(cnf.expanded_channels), + HardSwish(inplace=True) if cnf.use_hs else nn.ReLU(inplace=True), ]) # depthwise layers.extend([ - nn.Conv2d(expanded_channels, expanded_channels, kernel, stride=stride, padding=(kernel - 1) // 2, - groups=expanded_channels, bias=False), - norm_layer(expanded_channels), - HardSwish(inplace=True) if use_hs else nn.ReLU(inplace=True), + nn.Conv2d(cnf.expanded_channels, cnf.expanded_channels, cnf.kernel, stride=cnf.stride, + padding=(cnf.kernel - 1) // 2, groups=cnf.expanded_channels, bias=False), + norm_layer(cnf.expanded_channels), + HardSwish(inplace=True) if cnf.use_hs else nn.ReLU(inplace=True), ]) - if use_se: - layers.append(SqueezeExcitation(expanded_channels)) + if cnf.use_se: + layers.append(SqueezeExcitation(cnf.expanded_channels)) # project layers.extend([ - nn.Conv2d(expanded_channels, output_channels, 1, bias=False), - norm_layer(expanded_channels), + nn.Conv2d(cnf.expanded_channels, cnf.output_channels, 1, bias=False), + norm_layer(cnf.expanded_channels), ]) self.block = nn.Sequential(*layers) @@ -118,7 +129,7 @@ class MobileNetV3(nn.Module): def __init__( self, - inverted_residual_setting: List[List[int]], + inverted_residual_setting: List[InvertedResidualConfig], last_channel: int, num_classes: int = 1000, blocks: Optional[List[Callable[..., nn.Module]]] = None, @@ -137,54 +148,46 @@ def __init__( ] - - pass # TODO: initialize weights def mobilenetv3(mode: str = "large", width_mult: float = 1.0): + bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) if mode == "large": inverted_residual_setting = [ - # in, kernel, exp, out, use_se, use_hs, 
stride - [16, 3, 16, 16, 0, 0, 1], - [16, 3, 64, 24, 0, 0, 2], - [24, 3, 72, 24, 0, 0, 1], - [24, 5, 72, 40, 1, 0, 2], - [40, 5, 120, 40, 1, 0, 1], - [40, 5, 120, 40, 1, 0, 1], - [40, 3, 240, 80, 0, 1, 2], - [80, 3, 200, 80, 0, 1, 1], - [80, 3, 184, 80, 0, 1, 1], - [80, 3, 184, 80, 0, 1, 1], - [80, 3, 480, 112, 1, 1, 1], - [112, 3, 672, 112, 1, 1, 1], - [112, 5, 672, 160, 1, 1, 2], - [160, 5, 960, 160, 1, 1, 1], - [160, 5, 960, 160, 1, 1, 1], + bneck_conf(16, 3, 16, 16, False, "RE", 1), + bneck_conf(16, 3, 64, 24, False, "RE", 2), + bneck_conf(24, 3, 72, 24, False, "RE", 1), + bneck_conf(24, 5, 72, 40, True, "RE", 2), + bneck_conf(40, 5, 120, 40, True, "RE", 1), + bneck_conf(40, 5, 120, 40, True, "RE", 1), + bneck_conf(40, 3, 240, 80, False, "HS", 2), + bneck_conf(80, 3, 200, 80, False, "HS", 1), + bneck_conf(80, 3, 184, 80, False, "HS", 1), + bneck_conf(80, 3, 184, 80, False, "HS", 1), + bneck_conf(80, 3, 480, 112, True, "HS", 1), + bneck_conf(112, 3, 672, 112, True, "HS", 1), + bneck_conf(112, 5, 672, 160, True, "HS", 2), + bneck_conf(160, 5, 960, 160, True, "HS", 1), + bneck_conf(160, 5, 960, 160, True, "HS", 1), ] last_channel = 1280 else: inverted_residual_setting = [ - # in, kernel, exp, out, use_se, use_hs, stride - [16, 3, 16, 16, 1, 0, 2], - [16, 3, 72, 24, 0, 0, 2], - [24, 3, 88, 24, 0, 0, 1], - [24, 5, 96, 40, 1, 1, 2], - [40, 5, 240, 40, 1, 1, 1], - [40, 5, 240, 40, 1, 1, 1], - [40, 5, 120, 48, 1, 1, 1], - [48, 5, 144, 48, 1, 1, 1], - [48, 5, 288, 96, 1, 1, 2], - [96, 5, 576, 96, 1, 1, 1], - [96, 5, 576, 96, 1, 1, 1], + bneck_conf(16, 3, 16, 16, True, "RE", 2), + bneck_conf(16, 3, 72, 24, False, "RE", 2), + bneck_conf(24, 3, 88, 24, False, "RE", 1), + bneck_conf(24, 5, 96, 40, True, "HS", 2), + bneck_conf(40, 5, 240, 40, True, "HS", 1), + bneck_conf(40, 5, 240, 40, True, "HS", 1), + bneck_conf(40, 5, 120, 48, True, "HS", 1), + bneck_conf(48, 5, 144, 48, True, "HS", 1), + bneck_conf(48, 5, 288, 96, True, "HS", 2), + bneck_conf(96, 5, 576, 96, True, "HS", 1), + bneck_conf(96, 5, 576, 96, True, "HS", 1), ] last_channel = 1024 - - # apply multipler on: in, exp, out columns - for row in inverted_residual_setting: - for id in (0, 2, 3): - row[id] = _make_divisible(row[id] * width_mult) last_channel = _make_divisible(last_channel * width_mult) return MobileNetV3(inverted_residual_setting, last_channel) From 834b185be63fd8241d5827148bffa204a84642d8 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Dec 2020 12:40:40 +0000 Subject: [PATCH 04/23] Adding first conv layers. --- torchvision/models/mobilenetv3.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 3382a057d36..c3a6e83f31e 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -4,7 +4,7 @@ from typing import Callable, List, Optional -def _make_divisible(v: float, divisor: int = 8, min_value: Optional[int] = None) -> int: +def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: """ This function is taken from the original tf repo. 
It ensures that all layers have a channel number that is divisible by 8 @@ -58,7 +58,7 @@ class SqueezeExcitation(nn.Module): def __init__(self, input_channels: int, squeeze_factor: int = 4): super().__init__() - squeeze_channels = _make_divisible(input_channels // squeeze_factor) + squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1) self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1) @@ -74,10 +74,10 @@ def forward(self, input: Tensor) -> Tensor: class InvertedResidualConfig: def __init__(self, input_channels: int, kernel: int, expanded_channels: int, output_channels: int, use_se: bool, activation: str, stride: int, width_mult: float): - self.input_channels = _make_divisible(input_channels * width_mult) + self.input_channels = _make_divisible(input_channels * width_mult, 8) self.kernel = kernel - self.expanded_channels = _make_divisible(expanded_channels * width_mult) - self.output_channels = _make_divisible(output_channels * width_mult) + self.expanded_channels = _make_divisible(expanded_channels * width_mult, 8) + self.output_channels = _make_divisible(output_channels * width_mult, 8) self.use_se = use_se self.use_hs = activation == "HS" self.stride = stride @@ -91,7 +91,7 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.output_channels - layers: List[nn.Module] = [] + layers = [] # expand if cnf.expanded_channels != cnf.input_channels: layers.extend([ @@ -132,20 +132,22 @@ def __init__( inverted_residual_setting: List[InvertedResidualConfig], last_channel: int, num_classes: int = 1000, - blocks: Optional[List[Callable[..., nn.Module]]] = None, + block: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None ) -> None: super().__init__() - if blocks is None: - blocks = [SqueezeExcitation, InvertedResidual] - se_layer, bottleneck_layer = blocks + if block is None: + block = InvertedResidual if norm_layer is None: norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) - layers: List[nn.Module] = [ - + firstconv_output_channels = inverted_residual_setting[0].input_channels + layers = [ + nn.Conv2d(3, firstconv_output_channels, 3, stride=2, padding=1, bias=False), + norm_layer(firstconv_output_channels), + HardSwish(inplace=True), ] pass @@ -188,6 +190,7 @@ def mobilenetv3(mode: str = "large", width_mult: float = 1.0): bneck_conf(96, 5, 576, 96, True, "HS", 1), ] last_channel = 1024 - last_channel = _make_divisible(last_channel * width_mult) + + last_channel = _make_divisible(last_channel * width_mult, 8) return MobileNetV3(inverted_residual_setting, last_channel) From 1edd16b0f5178cd0e63d62126912e915e38990d6 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Dec 2020 12:49:58 +0000 Subject: [PATCH 05/23] Moving mobilenet.py to mobilenetv2.py --- torchvision/models/{mobilenet.py => mobilenetv2.py} | 0 torchvision/models/quantization/{mobilenet.py => mobilenetv2.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename torchvision/models/{mobilenet.py => mobilenetv2.py} (100%) rename torchvision/models/quantization/{mobilenet.py => mobilenetv2.py} (100%) diff --git a/torchvision/models/mobilenet.py b/torchvision/models/mobilenetv2.py similarity index 100% rename from torchvision/models/mobilenet.py rename to torchvision/models/mobilenetv2.py diff --git a/torchvision/models/quantization/mobilenet.py 
b/torchvision/models/quantization/mobilenetv2.py similarity index 100% rename from torchvision/models/quantization/mobilenet.py rename to torchvision/models/quantization/mobilenetv2.py From 2f52f0dd460a77a2336dc20c1d02083a9069aa96 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Dec 2020 13:00:59 +0000 Subject: [PATCH 06/23] Adding mobilenet.py for BC. --- torchvision/models/mobilenet.py | 1 + torchvision/models/quantization/mobilenet.py | 1 + torchvision/models/quantization/mobilenetv2.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 torchvision/models/mobilenet.py create mode 100644 torchvision/models/quantization/mobilenet.py diff --git a/torchvision/models/mobilenet.py b/torchvision/models/mobilenet.py new file mode 100644 index 00000000000..75e2a9a24bc --- /dev/null +++ b/torchvision/models/mobilenet.py @@ -0,0 +1 @@ +from .mobilenetv2 import * diff --git a/torchvision/models/quantization/mobilenet.py b/torchvision/models/quantization/mobilenet.py new file mode 100644 index 00000000000..75e2a9a24bc --- /dev/null +++ b/torchvision/models/quantization/mobilenet.py @@ -0,0 +1 @@ +from .mobilenetv2 import * diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 1d14410f376..72c522a2e46 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -1,6 +1,6 @@ from torch import nn from torchvision.models.utils import load_state_dict_from_url -from torchvision.models.mobilenet import InvertedResidual, ConvBNReLU, MobileNetV2, model_urls +from torchvision.models.mobilenetv2 import InvertedResidual, ConvBNReLU, MobileNetV2, model_urls from torch.quantization import QuantStub, DeQuantStub, fuse_modules from .utils import _replace_relu, quantize_model From bb2ec9e21408390e36515961fbfcfab9e332ffdb Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Dec 2020 13:04:29 +0000 Subject: [PATCH 07/23] Extending ConvBNReLU for reuse. 
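
ConvBNReLU is generalized into ConvBNActivation so that MobileNetV3 can reuse the same Conv2d + norm + activation stack with a different activation instead of the hard-coded ReLU6; the old name is kept as an alias for backwards compatibility. A rough sketch of the intended usage once this patch is applied (nn.Hardswish is assumed to be available in the installed PyTorch build):

    import torch
    from torch import nn
    from torchvision.models.mobilenetv2 import ConvBNActivation  # ConvBNReLU remains an alias

    # 1x1 conv + BatchNorm + Hardswish, instead of the previously hard-coded ReLU6
    block = ConvBNActivation(16, 32, kernel_size=1,
                             norm_layer=nn.BatchNorm2d, activation_layer=nn.Hardswish)
    out = block(torch.randn(1, 16, 56, 56))
    print(out.shape)  # torch.Size([1, 32, 56, 56])
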
--- torchvision/models/mobilenetv2.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index d90b3f8ef14..990429bacf9 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -32,7 +32,7 @@ def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> return new_v -class ConvBNReLU(nn.Sequential): +class ConvBNActivation(nn.Sequential): def __init__( self, in_planes: int, @@ -40,18 +40,25 @@ def __init__( kernel_size: int = 3, stride: int = 1, groups: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None + norm_layer: Optional[Callable[..., nn.Module]] = None, + activation_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: padding = (kernel_size - 1) // 2 if norm_layer is None: norm_layer = nn.BatchNorm2d + if activation_layer is None: + activation_layer = nn.ReLU6 super(ConvBNReLU, self).__init__( nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), norm_layer(out_planes), - nn.ReLU6(inplace=True) + activation_layer(inplace=True) ) +# necessary for backwards compatibility +ConvBNReLU = ConvBNActivation + + class InvertedResidual(nn.Module): def __init__( self, From e95ee5c9d2896d8814fe9f8c3daf6d166114235a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Dec 2020 12:49:58 +0000 Subject: [PATCH 08/23] Moving mobilenet.py to mobilenetv2.py --- torchvision/models/{mobilenet.py => mobilenetv2.py} | 0 torchvision/models/quantization/{mobilenet.py => mobilenetv2.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename torchvision/models/{mobilenet.py => mobilenetv2.py} (100%) rename torchvision/models/quantization/{mobilenet.py => mobilenetv2.py} (100%) diff --git a/torchvision/models/mobilenet.py b/torchvision/models/mobilenetv2.py similarity index 100% rename from torchvision/models/mobilenet.py rename to torchvision/models/mobilenetv2.py diff --git a/torchvision/models/quantization/mobilenet.py b/torchvision/models/quantization/mobilenetv2.py similarity index 100% rename from torchvision/models/quantization/mobilenet.py rename to torchvision/models/quantization/mobilenetv2.py From 2ebe8baf701d292d0763a3f236ea4c46b8ddf114 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Dec 2020 13:00:59 +0000 Subject: [PATCH 09/23] Adding mobilenet.py for BC. 
--- torchvision/models/mobilenet.py | 1 + torchvision/models/quantization/mobilenet.py | 1 + torchvision/models/quantization/mobilenetv2.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 torchvision/models/mobilenet.py create mode 100644 torchvision/models/quantization/mobilenet.py diff --git a/torchvision/models/mobilenet.py b/torchvision/models/mobilenet.py new file mode 100644 index 00000000000..75e2a9a24bc --- /dev/null +++ b/torchvision/models/mobilenet.py @@ -0,0 +1 @@ +from .mobilenetv2 import * diff --git a/torchvision/models/quantization/mobilenet.py b/torchvision/models/quantization/mobilenet.py new file mode 100644 index 00000000000..75e2a9a24bc --- /dev/null +++ b/torchvision/models/quantization/mobilenet.py @@ -0,0 +1 @@ +from .mobilenetv2 import * diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 1d14410f376..72c522a2e46 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -1,6 +1,6 @@ from torch import nn from torchvision.models.utils import load_state_dict_from_url -from torchvision.models.mobilenet import InvertedResidual, ConvBNReLU, MobileNetV2, model_urls +from torchvision.models.mobilenetv2 import InvertedResidual, ConvBNReLU, MobileNetV2, model_urls from torch.quantization import QuantStub, DeQuantStub, fuse_modules from .utils import _replace_relu, quantize_model From 0c31a3341c23f0c7ef077e8af052f67932398bff Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Dec 2020 13:04:29 +0000 Subject: [PATCH 10/23] Extending ConvBNReLU for reuse. --- torchvision/models/mobilenetv2.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index d90b3f8ef14..990429bacf9 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -32,7 +32,7 @@ def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> return new_v -class ConvBNReLU(nn.Sequential): +class ConvBNActivation(nn.Sequential): def __init__( self, in_planes: int, @@ -40,18 +40,25 @@ def __init__( kernel_size: int = 3, stride: int = 1, groups: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None + norm_layer: Optional[Callable[..., nn.Module]] = None, + activation_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: padding = (kernel_size - 1) // 2 if norm_layer is None: norm_layer = nn.BatchNorm2d + if activation_layer is None: + activation_layer = nn.ReLU6 super(ConvBNReLU, self).__init__( nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), norm_layer(out_planes), - nn.ReLU6(inplace=True) + activation_layer(inplace=True) ) +# necessary for backwards compatibility +ConvBNReLU = ConvBNActivation + + class InvertedResidual(nn.Module): def __init__( self, From db7522b1e3adb5824462a0e01e1a65ae6c7537e9 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Dec 2020 13:59:43 +0000 Subject: [PATCH 11/23] Reduce import scope on mobilenet to only the public and versioned classes and methods. 
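
The legacy torchvision.models.mobilenet path keeps working but now only re-exports the documented, versioned names; internal helpers have to be imported from the new mobilenetv2 module directly. In other words, assuming this patch is applied:

    # still supported through the backwards-compatibility shim
    from torchvision.models.mobilenet import MobileNetV2, mobilenet_v2

    # building blocks now live only under the versioned module
    from torchvision.models.mobilenetv2 import ConvBNReLU, InvertedResidual
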
--- torchvision/models/mobilenet.py | 2 +- torchvision/models/quantization/mobilenet.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/mobilenet.py b/torchvision/models/mobilenet.py index 75e2a9a24bc..8be33d71a4e 100644 --- a/torchvision/models/mobilenet.py +++ b/torchvision/models/mobilenet.py @@ -1 +1 @@ -from .mobilenetv2 import * +from .mobilenetv2 import MobileNetV2, mobilenet_v2 diff --git a/torchvision/models/quantization/mobilenet.py b/torchvision/models/quantization/mobilenet.py index 75e2a9a24bc..8be33d71a4e 100644 --- a/torchvision/models/quantization/mobilenet.py +++ b/torchvision/models/quantization/mobilenet.py @@ -1 +1 @@ -from .mobilenetv2 import * +from .mobilenetv2 import MobileNetV2, mobilenet_v2 From 16f55f5524b75ff9a4e85fe615891e1f1c61fb11 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Dec 2020 15:09:21 +0000 Subject: [PATCH 12/23] Further simplify by reusing MobileNetv2 methods. --- torchvision/models/mobilenetv3.py | 66 +++++++++++-------------------- 1 file changed, 24 insertions(+), 42 deletions(-) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index c3a6e83f31e..f11258d1ec2 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -1,29 +1,11 @@ +from .mobilenetv2 import _make_divisible, ConvBNActivation + from functools import partial from torch import nn, Tensor from torch.nn import functional as F from typing import Callable, List, Optional -def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: - """ - This function is taken from the original tf repo. - It ensures that all layers have a channel number that is divisible by 8 - It can be seen here: - https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py - :param v: - :param divisor: - :param min_value: - :return: - """ - if min_value is None: - min_value = divisor - new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) - # Make sure that round down does not go down by more than 10%. 
- if new_v < 0.9 * v: - new_v += divisor - return new_v - - class _InplaceActivation(nn.Module): def __init__(self, inplace: bool = False): @@ -34,6 +16,12 @@ def extra_repr(self) -> str: return 'inplace=True' if self.inplace else '' +class Identity(_InplaceActivation): + + def forward(self, input: Tensor) -> Tensor: + return input + + def hard_sigmoid(x: Tensor, inplace: bool = False) -> Tensor: return F.relu6(x + 3.0, inplace=inplace) / 6.0 @@ -92,29 +80,23 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.output_channels layers = [] + activation_layer = HardSwish if cnf.use_hs else nn.ReLU + # expand if cnf.expanded_channels != cnf.input_channels: - layers.extend([ - nn.Conv2d(cnf.input_channels, cnf.expanded_channels, 1, bias=False), - norm_layer(cnf.expanded_channels), - HardSwish(inplace=True) if cnf.use_hs else nn.ReLU(inplace=True), - ]) + layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1, + norm_layer=norm_layer, activation_layer=activation_layer)) # depthwise - layers.extend([ - nn.Conv2d(cnf.expanded_channels, cnf.expanded_channels, cnf.kernel, stride=cnf.stride, - padding=(cnf.kernel - 1) // 2, groups=cnf.expanded_channels, bias=False), - norm_layer(cnf.expanded_channels), - HardSwish(inplace=True) if cnf.use_hs else nn.ReLU(inplace=True), - ]) + layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, + stride=cnf.stride, groups=cnf.expanded_channels, norm_layer=norm_layer, + activation_layer=activation_layer)) if cnf.use_se: layers.append(SqueezeExcitation(cnf.expanded_channels)) # project - layers.extend([ - nn.Conv2d(cnf.expanded_channels, cnf.output_channels, 1, bias=False), - norm_layer(cnf.expanded_channels), - ]) + layers.append(ConvBNActivation(cnf.expanded_channels, cnf.output_channels, kernel_size=1, norm_layer=norm_layer, + activation_layer=Identity)) self.block = nn.Sequential(*layers) @@ -144,17 +126,17 @@ def __init__( norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) firstconv_output_channels = inverted_residual_setting[0].input_channels - layers = [ - nn.Conv2d(3, firstconv_output_channels, 3, stride=2, padding=1, bias=False), - norm_layer(firstconv_output_channels), - HardSwish(inplace=True), - ] + layers = [ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, + activation_layer=HardSwish)] - pass # TODO: initialize weights -def mobilenetv3(mode: str = "large", width_mult: float = 1.0): +# TODO: add doc strings and add it in document files +# TODO: tests +# TODO: add it in hubconf.py +# TODO: pretrained +def mobilenet_v3(mode: str = "large", width_mult: float = 1.0): bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) if mode == "large": inverted_residual_setting = [ From 8162fa41af2dccea932b9653ffe9df29c7cb851a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Dec 2020 16:05:31 +0000 Subject: [PATCH 13/23] Adding the remaining implementation of mobilenetv3. 
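
This wires up the full head from the paper: a final 1x1 ConvBNActivation that expands the last bottleneck output six-fold, global average pooling, and a two-layer classifier with hard-swish and dropout in between, plus the usual weight initialisation. A rough shape walk-through of that head, assuming the large variant's channel counts (nn.Hardswish is used below only for brevity; this commit still defines its own HardSwish module):

    import torch
    from torch import nn

    x = torch.randn(2, 960, 7, 7)        # output of the final 1x1 ConvBNActivation for a 224x224 input (6 * 160 channels)
    x = nn.AdaptiveAvgPool2d(1)(x)       # -> (2, 960, 1, 1)
    x = torch.flatten(x, 1)              # -> (2, 960)
    head = nn.Sequential(
        nn.Linear(960, 1280),            # lastconv_output_channels -> last_channel
        nn.Hardswish(),
        nn.Dropout(p=0.2),
        nn.Linear(1280, 1000),           # last_channel -> num_classes
    )
    print(head(x).shape)                 # torch.Size([2, 1000])
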
--- torchvision/models/mobilenetv3.py | 53 +++++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index f11258d1ec2..63ec13cee97 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -1,10 +1,12 @@ -from .mobilenetv2 import _make_divisible, ConvBNActivation +import torch from functools import partial from torch import nn, Tensor from torch.nn import functional as F from typing import Callable, List, Optional +from .mobilenetv2 import _make_divisible, ConvBNActivation + class _InplaceActivation(nn.Module): @@ -60,6 +62,7 @@ def forward(self, input: Tensor) -> Tensor: class InvertedResidualConfig: + def __init__(self, input_channels: int, kernel: int, expanded_channels: int, output_channels: int, use_se: bool, activation: str, stride: int, width_mult: float): self.input_channels = _make_divisible(input_channels * width_mult, 8) @@ -119,17 +122,63 @@ def __init__( ) -> None: super().__init__() + if not inverted_residual_setting: + raise ValueError("The inverted_residual_setting should not be empty") + if block is None: block = InvertedResidual if norm_layer is None: norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) + # building first layer firstconv_output_channels = inverted_residual_setting[0].input_channels layers = [ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=HardSwish)] - # TODO: initialize weights + # building inverted residual blocks + for cnf in inverted_residual_setting: + layers.append(block(cnf, norm_layer)) + + # building last several layers + lastconv_input_channels = inverted_residual_setting[-1].output_channels + lastconv_output_channels = 6 * lastconv_input_channels + layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1, + norm_layer=norm_layer, activation_layer=HardSwish)) + + self.features = nn.Sequential(*layers) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Sequential( + nn.Linear(lastconv_output_channels, last_channel), + HardSwish(inplace=True), + nn.Dropout(p=0.2), + nn.Linear(last_channel, num_classes), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + def _forward_impl(self, x: Tensor) -> Tensor: + x = self.features(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + + x = self.classifier(x) + + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) # TODO: add doc strings and add it in document files From 8615585b137fd14ba8c4dec5513c3d6c03fccd47 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Dec 2020 17:47:48 +0000 Subject: [PATCH 14/23] Adding tests, docs and init methods. 
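
Besides the expect files and the shared norm_layer test, this registers both builders in hubconf and exposes them from torchvision.models.mobilenet; the pretrained URLs stay None for now, so only random-weight construction is exercised. A minimal smoke test, roughly what the new unit tests cover (not the literal test code):

    import torch
    from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small

    for builder in (mobilenet_v3_large, mobilenet_v3_small):
        model = builder()            # pretrained=False: no weights are published yet
        model.eval()
        with torch.no_grad():
            out = model(torch.randn(1, 3, 224, 224))
        assert out.shape == (1, 1000)
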
--- hubconf.py | 3 +- ...lTester.test_mobilenet_v3_large_expect.pkl | Bin 0 -> 953 bytes ...lTester.test_mobilenet_v3_small_expect.pkl | Bin 0 -> 953 bytes test/test_models.py | 17 +- torchvision/models/mobilenet.py | 1 + torchvision/models/mobilenetv3.py | 150 ++++++++++++------ 6 files changed, 116 insertions(+), 55 deletions(-) create mode 100644 test/expect/ModelTester.test_mobilenet_v3_large_expect.pkl create mode 100644 test/expect/ModelTester.test_mobilenet_v3_small_expect.pkl diff --git a/hubconf.py b/hubconf.py index 79c22bd938b..dec4a7fb196 100644 --- a/hubconf.py +++ b/hubconf.py @@ -11,7 +11,8 @@ from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn from torchvision.models.googlenet import googlenet from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0 -from torchvision.models.mobilenet import mobilenet_v2 +from torchvision.models.mobilenetv2 import mobilenet_v2 +from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \ mnasnet1_3 diff --git a/test/expect/ModelTester.test_mobilenet_v3_large_expect.pkl b/test/expect/ModelTester.test_mobilenet_v3_large_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..9691daf18c7c691a970ac8204305a28061a4b394 GIT binary patch literal 953 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf*+{H(8R#h(7+4?&CM;13YjCAfuhL;rG+fsMkR%;f!++>jNV3Vg>2qN-i&R9>>bI; z`8heM$t9WjdBt4*MJcI8sVOd*$t7Hc9GXFeoF#=^L519U0p9E!7Z**lSphT+gadH; zTZVxT#ozj9PI7bNLJETmd>U(d(^K@U%C6{f|C3)=?4PmDCVIiT?EhYRGyE*{td8DT zSF(Kj+RNWn*83bWT^BZY=DI%a-)kFE^tB8cGS{xlY+ak9`9N>?TeEe)zDTU!*LG@M zmf_!Z8CzZTc0Cr>^JqA;POYj^Z(pvH-u~$ox=IS?^-etWS^I@^u};7o9=$H3rgf9O z4(R3Fc3!`BDeL-^wwHA$-CnwGu3EUhqvnNm8C=bJ@lO}6i)i_vWAIFNosB-9-XV<@ z>!xhJqx;FSa$SfJtFGGbW#I6+bM3*i6Tq+mVcg-v&tMG?pR&}VVqmztIhi8`2}=4P z#4+Xq$IQI+P$r;ieH6lFm<0AR$n$KVoWc+eRlo>j2Y9ox Nfy9`B5TqWW767`o`& zf*+{H(8R#h(7?jn%+$=nz|5qOIf5A|np{v?$l`5OQpg(U&EU=GZPZrC=56H7*jC8i zk(`{LlarcUl9``Z%;jH{l3J9S;*yzM!d1wj8C1wwQpgol$gLOP&CZeF9)5cT&@>PZ z!0B%p20j#j>!Ufz&4~*s4Dz={uT#GJYkk|{b>EJ)vW8wHa%JQ_FIGg zp_+;7nHROMe;0d9-|M;CdM`Wo4SD)T}KtvjysOMlr9(e)jga_cLew(2`JMCf<_ zU%xKw+-m)xtvu^jz2B;TApEzEQRF4PGn@GgI6kaffB0+RI}yLBzv~;#UZB5hQ`Gw9-k0^|<_oQ#yX5w|uP0XN&-i&+cbcQ@`g^C1 z*01GStv`qR&bp$42kVUUOZ7Qlz66KQ^fIO+CxBrC!nnhSpTQa)K4qyz#lUcPb23K? 
z5|s2oh-1tJikXY^(nFbmwt{egHzSCGr%B{k697q|0Q3}!t{d4;GAKHK0C~u|(c>B2 zBxE;&QYr#S0No2S2^t~+-fV0-P!)2_x^T6i#0&y3`Y43UFbV8skmuPzIfWq{s(=y5 R4)A7W1Bo#MAxJ$$EdbN+|G@wN literal 0 HcmV?d00001 diff --git a/test/test_models.py b/test/test_models.py index 893ed8ab38e..09d2410aa68 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -275,16 +275,17 @@ def test_mobilenetv2_residual_setting(self): out = model(x) self.assertEqual(out.shape[-1], 1000) - def test_mobilenetv2_norm_layer(self): - model = models.__dict__["mobilenet_v2"]() - self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) + def test_mobilenet_norm_layer(self): + for name in ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"]: + model = models.__dict__[name]() + self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) - def get_gn(num_channels): - return nn.GroupNorm(32, num_channels) + def get_gn(num_channels): + return nn.GroupNorm(32, num_channels) - model = models.__dict__["mobilenet_v2"](norm_layer=get_gn) - self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) - self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules())) + model = models.__dict__[name](norm_layer=get_gn) + self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) + self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules())) def test_inceptionv3_eval(self): # replacement for models.inception_v3(pretrained=True) that does not download weights diff --git a/torchvision/models/mobilenet.py b/torchvision/models/mobilenet.py index 8be33d71a4e..0b3fd886b6f 100644 --- a/torchvision/models/mobilenet.py +++ b/torchvision/models/mobilenet.py @@ -1 +1,2 @@ from .mobilenetv2 import MobileNetV2, mobilenet_v2 +from .mobilenetv3 import MobileNetV3, mobilenet_v3_large, mobilenet_v3_small diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 63ec13cee97..3685e0950b2 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -3,9 +3,20 @@ from functools import partial from torch import nn, Tensor from torch.nn import functional as F -from typing import Callable, List, Optional +from typing import Any, Callable, List, Optional -from .mobilenetv2 import _make_divisible, ConvBNActivation +from torchvision.models.utils import load_state_dict_from_url +from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation + + +__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"] + + +# TODO: add pretrained +model_urls = { + "mobilenet_v3_large_1_0": None, + "mobilenet_v3_small_1_0": None, +} class _InplaceActivation(nn.Module): @@ -35,7 +46,7 @@ def forward(self, input: Tensor) -> Tensor: def hard_swish(x: Tensor, inplace: bool = False) -> Tensor: - return x * hard_swish(x, inplace=inplace) + return x * hard_sigmoid(x, inplace=inplace) class HardSwish(_InplaceActivation): @@ -120,6 +131,16 @@ def __init__( block: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None ) -> None: + """ + MobileNet V3 main class + + Args: + inverted_residual_setting (List[InvertedResidualConfig]): Network structure + last_channel (int): The number of channels on the penultimate layer + num_classes (int): Number of classes + block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet + norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use + """ 
super().__init__() if not inverted_residual_setting: @@ -181,47 +202,84 @@ def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -# TODO: add doc strings and add it in document files -# TODO: tests -# TODO: add it in hubconf.py -# TODO: pretrained -def mobilenet_v3(mode: str = "large", width_mult: float = 1.0): +def _mobilenet_v3( + arch: str, + inverted_residual_setting: List[InvertedResidualConfig], + last_channel: int, + pretrained: bool, + progress: bool, + **kwargs: Any +): + model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model + + +def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: + """ + Constructs a large MobileNetV3 architecture from + `"Searching for MobileNetV3" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + width_mult = 1.0 + bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) + lastchannel_conf = lambda c: _make_divisible(c * width_mult, 8) + + inverted_residual_setting = [ + bneck_conf(16, 3, 16, 16, False, "RE", 1), + bneck_conf(16, 3, 64, 24, False, "RE", 2), + bneck_conf(24, 3, 72, 24, False, "RE", 1), + bneck_conf(24, 5, 72, 40, True, "RE", 2), + bneck_conf(40, 5, 120, 40, True, "RE", 1), + bneck_conf(40, 5, 120, 40, True, "RE", 1), + bneck_conf(40, 3, 240, 80, False, "HS", 2), + bneck_conf(80, 3, 200, 80, False, "HS", 1), + bneck_conf(80, 3, 184, 80, False, "HS", 1), + bneck_conf(80, 3, 184, 80, False, "HS", 1), + bneck_conf(80, 3, 480, 112, True, "HS", 1), + bneck_conf(112, 3, 672, 112, True, "HS", 1), + bneck_conf(112, 5, 672, 160, True, "HS", 2), + bneck_conf(160, 5, 960, 160, True, "HS", 1), + bneck_conf(160, 5, 960, 160, True, "HS", 1), + ] + last_channel = lastchannel_conf(1280) + + return _mobilenet_v3("mobilenet_v3_large_1_0", inverted_residual_setting, last_channel, pretrained, progress, + **kwargs) + + +def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: + """ + Constructs a small MobileNetV3 architecture from + `"Searching for MobileNetV3" `_. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + width_mult = 1.0 bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) - if mode == "large": - inverted_residual_setting = [ - bneck_conf(16, 3, 16, 16, False, "RE", 1), - bneck_conf(16, 3, 64, 24, False, "RE", 2), - bneck_conf(24, 3, 72, 24, False, "RE", 1), - bneck_conf(24, 5, 72, 40, True, "RE", 2), - bneck_conf(40, 5, 120, 40, True, "RE", 1), - bneck_conf(40, 5, 120, 40, True, "RE", 1), - bneck_conf(40, 3, 240, 80, False, "HS", 2), - bneck_conf(80, 3, 200, 80, False, "HS", 1), - bneck_conf(80, 3, 184, 80, False, "HS", 1), - bneck_conf(80, 3, 184, 80, False, "HS", 1), - bneck_conf(80, 3, 480, 112, True, "HS", 1), - bneck_conf(112, 3, 672, 112, True, "HS", 1), - bneck_conf(112, 5, 672, 160, True, "HS", 2), - bneck_conf(160, 5, 960, 160, True, "HS", 1), - bneck_conf(160, 5, 960, 160, True, "HS", 1), - ] - last_channel = 1280 - else: - inverted_residual_setting = [ - bneck_conf(16, 3, 16, 16, True, "RE", 2), - bneck_conf(16, 3, 72, 24, False, "RE", 2), - bneck_conf(24, 3, 88, 24, False, "RE", 1), - bneck_conf(24, 5, 96, 40, True, "HS", 2), - bneck_conf(40, 5, 240, 40, True, "HS", 1), - bneck_conf(40, 5, 240, 40, True, "HS", 1), - bneck_conf(40, 5, 120, 48, True, "HS", 1), - bneck_conf(48, 5, 144, 48, True, "HS", 1), - bneck_conf(48, 5, 288, 96, True, "HS", 2), - bneck_conf(96, 5, 576, 96, True, "HS", 1), - bneck_conf(96, 5, 576, 96, True, "HS", 1), - ] - last_channel = 1024 - - last_channel = _make_divisible(last_channel * width_mult, 8) - - return MobileNetV3(inverted_residual_setting, last_channel) + lastchannel_conf = lambda c: _make_divisible(c * width_mult, 8) + + inverted_residual_setting = [ + bneck_conf(16, 3, 16, 16, True, "RE", 2), + bneck_conf(16, 3, 72, 24, False, "RE", 2), + bneck_conf(24, 3, 88, 24, False, "RE", 1), + bneck_conf(24, 5, 96, 40, True, "HS", 2), + bneck_conf(40, 5, 240, 40, True, "HS", 1), + bneck_conf(40, 5, 240, 40, True, "HS", 1), + bneck_conf(40, 5, 120, 48, True, "HS", 1), + bneck_conf(48, 5, 144, 48, True, "HS", 1), + bneck_conf(48, 5, 288, 96, True, "HS", 2), + bneck_conf(96, 5, 576, 96, True, "HS", 1), + bneck_conf(96, 5, 576, 96, True, "HS", 1), + ] + last_channel = lastchannel_conf(1024) + + return _mobilenet_v3("mobilenet_v3_small_1_0", inverted_residual_setting, last_channel, pretrained, progress, + **kwargs) From 8664fdecef6648536f8d2a54b0280d3377a78ef9 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Dec 2020 17:59:33 +0000 Subject: [PATCH 15/23] Refactoring and fixing formatter. 
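
adjust_channels is just _make_divisible(channels * width_mult, 8) lifted onto InvertedResidualConfig as a staticmethod, so the builders and the per-block config share one rounding rule instead of the ad-hoc lambda. For reference, the rounding it applies (a standalone restatement of _make_divisible, not new behaviour):

    def make_divisible(v, divisor=8, min_value=None):
        # same rule as torchvision.models.mobilenetv2._make_divisible
        if min_value is None:
            min_value = divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        if new_v < 0.9 * v:          # never round down by more than 10%
            new_v += divisor
        return new_v

    for width_mult in (0.75, 1.0):
        print([make_divisible(c * width_mult) for c in (16, 24, 40, 112, 160, 960)])
    # 0.75 -> [16, 24, 32, 88, 120, 720]
    # 1.0  -> [16, 24, 40, 112, 160, 960]
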
--- torchvision/models/mobilenetv3.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 3685e0950b2..3abd639c7a5 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -76,14 +76,18 @@ class InvertedResidualConfig: def __init__(self, input_channels: int, kernel: int, expanded_channels: int, output_channels: int, use_se: bool, activation: str, stride: int, width_mult: float): - self.input_channels = _make_divisible(input_channels * width_mult, 8) + self.input_channels = self.adjust_channels(input_channels, width_mult) self.kernel = kernel - self.expanded_channels = _make_divisible(expanded_channels * width_mult, 8) - self.output_channels = _make_divisible(output_channels * width_mult, 8) + self.expanded_channels = self.adjust_channels(expanded_channels, width_mult) + self.output_channels = self.adjust_channels(output_channels, width_mult) self.use_se = use_se self.use_hs = activation == "HS" self.stride = stride + @staticmethod + def adjust_channels(channels: int, width_mult: float): + return _make_divisible(channels * width_mult, 8) + class InvertedResidual(nn.Module): @@ -228,7 +232,7 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs """ width_mult = 1.0 bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) - lastchannel_conf = lambda c: _make_divisible(c * width_mult, 8) + adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult) inverted_residual_setting = [ bneck_conf(16, 3, 16, 16, False, "RE", 1), @@ -247,7 +251,7 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs bneck_conf(160, 5, 960, 160, True, "HS", 1), bneck_conf(160, 5, 960, 160, True, "HS", 1), ] - last_channel = lastchannel_conf(1280) + last_channel = adjust_channels(1280) return _mobilenet_v3("mobilenet_v3_large_1_0", inverted_residual_setting, last_channel, pretrained, progress, **kwargs) @@ -264,7 +268,7 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs """ width_mult = 1.0 bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) - lastchannel_conf = lambda c: _make_divisible(c * width_mult, 8) + adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult) inverted_residual_setting = [ bneck_conf(16, 3, 16, 16, True, "RE", 2), @@ -279,7 +283,7 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs bneck_conf(96, 5, 576, 96, True, "HS", 1), bneck_conf(96, 5, 576, 96, True, "HS", 1), ] - last_channel = lastchannel_conf(1024) + last_channel = adjust_channels(1024) return _mobilenet_v3("mobilenet_v3_small_1_0", inverted_residual_setting, last_channel, pretrained, progress, **kwargs) From cfa20b73dd3a2c52b15596471b62ffbb83f9ca40 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Dec 2020 18:02:32 +0000 Subject: [PATCH 16/23] Fixing type issues. 
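
The annotations matter because mypy otherwise infers the element type of the layers list from its first assignment and then rejects appending the other module types; declaring it as List[nn.Module] up front keeps the heterogeneous container type-correct. A tiny illustration of the pattern (not code from this patch):

    from typing import List
    from torch import nn

    layers: List[nn.Module] = []        # annotate before mixing module types
    layers.append(nn.Conv2d(3, 16, 3))
    layers.append(nn.Hardswish())       # fine: both entries are nn.Module
    model = nn.Sequential(*layers)
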
--- torchvision/models/mobilenetv3.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 3abd639c7a5..cec899a636e 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -97,7 +97,7 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.output_channels - layers = [] + layers: List[nn.Module] = [] activation_layer = HardSwish if cnf.use_hs else nn.ReLU # expand @@ -156,10 +156,12 @@ def __init__( if norm_layer is None: norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) + layers: List[nn.Module] = [] + # building first layer firstconv_output_channels = inverted_residual_setting[0].input_channels - layers = [ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, - activation_layer=HardSwish)] + layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, + activation_layer=HardSwish)) # building inverted residual blocks for cnf in inverted_residual_setting: From c189ae15c36ea6ddeb100c10661e42888734cf32 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Dec 2020 20:03:02 +0000 Subject: [PATCH 17/23] Using build-in Hardsigmoid and Hardswish. --- torchvision/models/mobilenetv3.py | 38 +++++-------------------------- 1 file changed, 6 insertions(+), 32 deletions(-) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index cec899a636e..c4669b2ef8d 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -19,42 +19,16 @@ } -class _InplaceActivation(nn.Module): +class Identity(nn.Module): def __init__(self, inplace: bool = False): super().__init__() self.inplace = inplace - def extra_repr(self) -> str: - return 'inplace=True' if self.inplace else '' - - -class Identity(_InplaceActivation): - def forward(self, input: Tensor) -> Tensor: return input -def hard_sigmoid(x: Tensor, inplace: bool = False) -> Tensor: - return F.relu6(x + 3.0, inplace=inplace) / 6.0 - - -class HardSigmoid(_InplaceActivation): - - def forward(self, input: Tensor) -> Tensor: - return hard_sigmoid(input, inplace=self.inplace) - - -def hard_swish(x: Tensor, inplace: bool = False) -> Tensor: - return x * hard_sigmoid(x, inplace=inplace) - - -class HardSwish(_InplaceActivation): - - def forward(self, input: Tensor) -> Tensor: - return hard_swish(input, inplace=self.inplace) - - class SqueezeExcitation(nn.Module): def __init__(self, input_channels: int, squeeze_factor: int = 4): @@ -68,7 +42,7 @@ def forward(self, input: Tensor) -> Tensor: scale = self.fc1(scale) scale = F.relu(scale, inplace=True) scale = self.fc2(scale) - scale = hard_sigmoid(scale, inplace=True) + scale = F.hardsigmoid(scale, inplace=True) return scale * input @@ -98,7 +72,7 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.output_channels layers: List[nn.Module] = [] - activation_layer = HardSwish if cnf.use_hs else nn.ReLU + activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU # expand if cnf.expanded_channels != cnf.input_channels: @@ -161,7 +135,7 @@ def __init__( # building first layer firstconv_output_channels = inverted_residual_setting[0].input_channels layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, 
norm_layer=norm_layer, - activation_layer=HardSwish)) + activation_layer=nn.Hardswish)) # building inverted residual blocks for cnf in inverted_residual_setting: @@ -171,13 +145,13 @@ def __init__( lastconv_input_channels = inverted_residual_setting[-1].output_channels lastconv_output_channels = 6 * lastconv_input_channels layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1, - norm_layer=norm_layer, activation_layer=HardSwish)) + norm_layer=norm_layer, activation_layer=nn.Hardswish)) self.features = nn.Sequential(*layers) self.avgpool = nn.AdaptiveAvgPool2d(1) self.classifier = nn.Sequential( nn.Linear(lastconv_output_channels, last_channel), - HardSwish(inplace=True), + nn.Hardswish(inplace=True), nn.Dropout(p=0.2), nn.Linear(last_channel, num_classes), ) From 9a758a83630702920c8856a1aa8f69a7b80c3711 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 17 Dec 2020 10:50:14 +0000 Subject: [PATCH 18/23] Code review nits. --- torchvision/models/mobilenetv3.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index c4669b2ef8d..ee2fdd5ab17 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -3,7 +3,7 @@ from functools import partial from torch import nn, Tensor from torch.nn import functional as F -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Sequence from torchvision.models.utils import load_state_dict_from_url from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation @@ -14,8 +14,8 @@ # TODO: add pretrained model_urls = { - "mobilenet_v3_large_1_0": None, - "mobilenet_v3_small_1_0": None, + "mobilenet_v3_large": None, + "mobilenet_v3_small": None, } @@ -67,7 +67,8 @@ class InvertedResidual(nn.Module): def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module]): super().__init__() - assert cnf.stride in [1, 2] + if not (1 <= cnf.stride <= 2): + raise ValueError('illegal stride value') self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.output_channels @@ -123,6 +124,9 @@ def __init__( if not inverted_residual_setting: raise ValueError("The inverted_residual_setting should not be empty") + elif not (isinstance(inverted_residual_setting, Sequence) and + all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])): + raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]") if block is None: block = InvertedResidual @@ -229,8 +233,7 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs ] last_channel = adjust_channels(1280) - return _mobilenet_v3("mobilenet_v3_large_1_0", inverted_residual_setting, last_channel, pretrained, progress, - **kwargs) + return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs) def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: @@ -261,5 +264,4 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs ] last_channel = adjust_channels(1024) - return _mobilenet_v3("mobilenet_v3_small_1_0", inverted_residual_setting, last_channel, pretrained, progress, - **kwargs) + return _mobilenet_v3("mobilenet_v3_small", inverted_residual_setting, last_channel, pretrained, progress, **kwargs) From 25f8b26087a88287a0e71c4555ce16a3466dfc8d Mon Sep 17 00:00:00 2001 
From: Vasilis Vryniotis Date: Thu, 17 Dec 2020 18:20:36 +0000 Subject: [PATCH 19/23] Putting inplace on Dropout. --- torchvision/models/mobilenetv3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index ee2fdd5ab17..ccda83ac9db 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -156,7 +156,7 @@ def __init__( self.classifier = nn.Sequential( nn.Linear(lastconv_output_channels, last_channel), nn.Hardswish(inplace=True), - nn.Dropout(p=0.2), + nn.Dropout(p=0.2, inplace=True), nn.Linear(last_channel, num_classes), ) From 5198385f6f9ecef2b1f532df94aaa7d0c6de1610 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 1 Jan 2021 12:08:22 +0000 Subject: [PATCH 20/23] Adding rmsprop support on the train.py --- references/classification/train.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 789bb8134ff..77f3127782e 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -173,8 +173,15 @@ def main(args): criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.SGD( - model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + opt_name = args.opt.lower() + if opt_name == 'sgd': + optimizer = torch.optim.SGD( + model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + elif opt_name == 'rmsprop': + optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum, + weight_decay=args.weight_decay, eps=0.0316, alpha=0.9) + else: + raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt)) if args.apex: model, optimizer = amp.initialize(model, optimizer, @@ -238,6 +245,7 @@ def parse_args(): help='number of total epochs to run') parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', help='number of data loading workers (default: 16)') + parser.add_argument('--opt', default='sgd', type=str, help='optimizer') parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') From e4d130f72c069a4b54c6258fb7b45d7548f5afb1 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 1 Jan 2021 15:08:07 +0000 Subject: [PATCH 21/23] Adding auto-augment and random-erase in the training scripts. --- references/classification/train.py | 37 +++++++++++++++++++----------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 77f3127782e..47a7e5955e6 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -79,7 +79,7 @@ def _get_cache_path(filepath): return cache_path -def load_data(traindir, valdir, cache_dataset, distributed): +def load_data(traindir, valdir, args): # Data loading code print("Loading data") normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], @@ -88,20 +88,28 @@ def load_data(traindir, valdir, cache_dataset, distributed): print("Loading training data") st = time.time() cache_path = _get_cache_path(traindir) - if cache_dataset and os.path.exists(cache_path): + if args.cache_dataset and os.path.exists(cache_path): # Attention, as the transforms are also cached! 
print("Loading dataset_train from {}".format(cache_path)) dataset, _ = torch.load(cache_path) else: + trans = [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + ] + if args.auto_augment is not None: + aa_policy = transforms.AutoAugmentPolicy(args.auto_augment) + trans.append(transforms.AutoAugment(policy=aa_policy)) + trans.extend([ + transforms.ToTensor(), + normalize, + ]) + if args.random_erase > 0: + trans.append(transforms.RandomErasing(p=args.random_erase)) dataset = torchvision.datasets.ImageFolder( traindir, - transforms.Compose([ - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - normalize, - ])) - if cache_dataset: + transforms.Compose(trans)) + if args.cache_dataset: print("Saving dataset_train to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset, traindir), cache_path) @@ -109,7 +117,7 @@ def load_data(traindir, valdir, cache_dataset, distributed): print("Loading validation data") cache_path = _get_cache_path(valdir) - if cache_dataset and os.path.exists(cache_path): + if args.cache_dataset and os.path.exists(cache_path): # Attention, as the transforms are also cached! print("Loading dataset_test from {}".format(cache_path)) dataset_test, _ = torch.load(cache_path) @@ -122,13 +130,13 @@ def load_data(traindir, valdir, cache_dataset, distributed): transforms.ToTensor(), normalize, ])) - if cache_dataset: + if args.cache_dataset: print("Saving dataset_test to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset_test, valdir), cache_path) print("Creating data loaders") - if distributed: + if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) else: @@ -155,8 +163,7 @@ def main(args): train_dir = os.path.join(args.data_path, 'train') val_dir = os.path.join(args.data_path, 'val') - dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, - args.cache_dataset, args.distributed) + dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) data_loader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True) @@ -283,6 +290,8 @@ def parse_args(): help="Use pre-trained models from the modelzoo", action="store_true", ) + parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)') + parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)') # Mixed precision training parameters parser.add_argument('--apex', action='store_true', From c0a13a292ee4e67b5f8d3a10e1e24aac9ee3162f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 5 Jan 2021 14:57:40 +0000 Subject: [PATCH 22/23] Adding support for reduced tail on MobileNetV3. 
--- torchvision/models/mobilenetv3.py | 32 +++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index ccda83ac9db..d5949984c78 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -201,7 +201,8 @@ def _mobilenet_v3( return model -def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: +def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False, + **kwargs: Any) -> MobileNetV3: """ Constructs a large MobileNetV3 architecture from `"Searching for MobileNetV3" `_. @@ -209,11 +210,16 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr + reduced_tail (bool): If True, reduces the channel counts of all feature layers + between C4 and C5 by 2. It is used to reduce the channel redundancy in the + backbone for Detection and Segmentation. """ width_mult = 1.0 bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult) + reduce_divider = 2 if reduced_tail else 1 + inverted_residual_setting = [ bneck_conf(16, 3, 16, 16, False, "RE", 1), bneck_conf(16, 3, 64, 24, False, "RE", 2), @@ -227,16 +233,17 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs bneck_conf(80, 3, 184, 80, False, "HS", 1), bneck_conf(80, 3, 480, 112, True, "HS", 1), bneck_conf(112, 3, 672, 112, True, "HS", 1), - bneck_conf(112, 5, 672, 160, True, "HS", 2), - bneck_conf(160, 5, 960, 160, True, "HS", 1), - bneck_conf(160, 5, 960, 160, True, "HS", 1), + bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2), # C4 + bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1), + bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1), ] - last_channel = adjust_channels(1280) + last_channel = adjust_channels(1280 // reduce_divider) # C5 return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs) -def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: +def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False, + **kwargs: Any) -> MobileNetV3: """ Constructs a small MobileNetV3 architecture from `"Searching for MobileNetV3" `_. @@ -244,11 +251,16 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr + reduced_tail (bool): If True, reduces the channel counts of all feature layers + between C4 and C5 by 2. It is used to reduce the channel redundancy in the + backbone for Detection and Segmentation. 
""" width_mult = 1.0 bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult) + reduce_divider = 2 if reduced_tail else 1 + inverted_residual_setting = [ bneck_conf(16, 3, 16, 16, True, "RE", 2), bneck_conf(16, 3, 72, 24, False, "RE", 2), @@ -258,10 +270,10 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs bneck_conf(40, 5, 240, 40, True, "HS", 1), bneck_conf(40, 5, 120, 48, True, "HS", 1), bneck_conf(48, 5, 144, 48, True, "HS", 1), - bneck_conf(48, 5, 288, 96, True, "HS", 2), - bneck_conf(96, 5, 576, 96, True, "HS", 1), - bneck_conf(96, 5, 576, 96, True, "HS", 1), + bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2), # C4 + bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1), + bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1), ] - last_channel = adjust_channels(1024) + last_channel = adjust_channels(1024 // reduce_divider) # C5 return _mobilenet_v3("mobilenet_v3_small", inverted_residual_setting, last_channel, pretrained, progress, **kwargs) From 2414d2de8fd1050f9e31af406ef5662e2a1a6c08 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 5 Jan 2021 18:36:13 +0000 Subject: [PATCH 23/23] Tagging blocks with comments. --- torchvision/models/mobilenetv3.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index d5949984c78..6282cd45434 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -222,12 +222,12 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, reduced_ inverted_residual_setting = [ bneck_conf(16, 3, 16, 16, False, "RE", 1), - bneck_conf(16, 3, 64, 24, False, "RE", 2), + bneck_conf(16, 3, 64, 24, False, "RE", 2), # C1 bneck_conf(24, 3, 72, 24, False, "RE", 1), - bneck_conf(24, 5, 72, 40, True, "RE", 2), + bneck_conf(24, 5, 72, 40, True, "RE", 2), # C2 bneck_conf(40, 5, 120, 40, True, "RE", 1), bneck_conf(40, 5, 120, 40, True, "RE", 1), - bneck_conf(40, 3, 240, 80, False, "HS", 2), + bneck_conf(40, 3, 240, 80, False, "HS", 2), # C3 bneck_conf(80, 3, 200, 80, False, "HS", 1), bneck_conf(80, 3, 184, 80, False, "HS", 1), bneck_conf(80, 3, 184, 80, False, "HS", 1), @@ -262,10 +262,10 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, reduced_ reduce_divider = 2 if reduced_tail else 1 inverted_residual_setting = [ - bneck_conf(16, 3, 16, 16, True, "RE", 2), - bneck_conf(16, 3, 72, 24, False, "RE", 2), + bneck_conf(16, 3, 16, 16, True, "RE", 2), # C1 + bneck_conf(16, 3, 72, 24, False, "RE", 2), # C2 bneck_conf(24, 3, 88, 24, False, "RE", 1), - bneck_conf(24, 5, 96, 40, True, "HS", 2), + bneck_conf(24, 5, 96, 40, True, "HS", 2), # C3 bneck_conf(40, 5, 240, 40, True, "HS", 1), bneck_conf(40, 5, 240, 40, True, "HS", 1), bneck_conf(40, 5, 120, 48, True, "HS", 1),