Skip to content

Commit ea46bfc

Browse files
committed
Add support of support width_mult
1 parent 34c2769 commit ea46bfc

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

torchvision/models/detection/ssdlite.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,11 @@ def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: C
9393

9494

9595
class SSDLiteFeatureExtractorMobileNet(nn.Module):
96-
def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module],
97-
width_mult: float = 1.0, min_depth: int = 16):
96+
def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], **kwargs: Any):
9897
super().__init__()
98+
# non-public config parameters
99+
min_depth = kwargs.pop('_min_depth', 16)
100+
width_mult = kwargs.pop('_width_mult', 1.0)
99101

100102
assert not backbone[c4_pos].use_res_connect
101103
self.features = nn.Sequential(
@@ -129,10 +131,9 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
129131

130132

131133
def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, trainable_layers: int,
132-
norm_layer: Callable[..., nn.Module]):
133-
# TODO: support width_mult
134+
norm_layer: Callable[..., nn.Module], **kwargs: Any):
134135
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, progress=progress,
135-
norm_layer=norm_layer).features
136+
norm_layer=norm_layer, **kwargs).features
136137

137138
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
138139
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
@@ -147,7 +148,7 @@ def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, t
147148
for parameter in b.parameters():
148149
parameter.requires_grad_(False)
149150

150-
return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer)
151+
return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, **kwargs)
151152

152153

153154
def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
@@ -164,7 +165,7 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru
164165
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
165166

166167
backbone = _mobilenet_extractor("mobilenet_v3_large", progress, pretrained_backbone, trainable_backbone_layers,
167-
norm_layer)
168+
norm_layer, _width_mult=1.0)
168169

169170
size = (320, 320)
170171
anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)

0 commit comments

Comments
 (0)