Skip to content

Commit 8b9ca53

Browse files
committed
Change ReLU6, [-1,1] rescaling, backbone init & no pretraining.
1 parent 7cce538 commit 8b9ca53

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

torchvision/models/detection/ssdlite.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ def _prediction_block(in_channels: int, out_channels: int, kernel_size: int,
2626
return nn.Sequential(
2727
# 3x3 depthwise with stride 1 and padding 1
2828
ConvBNActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels,
29-
norm_layer=norm_layer, activation_layer=nn.ReLU),
29+
norm_layer=norm_layer, activation_layer=nn.ReLU6),
3030

3131
# 1x1 projetion to output channels
3232
nn.Conv2d(in_channels, out_channels, 1)
3333
)
3434

3535

3636
def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
37-
activation = nn.ReLU
37+
activation = nn.ReLU6
3838
intermediate_channels = out_channels // 2
3939
return nn.Sequential(
4040
# 1x1 projection to half output channels
@@ -93,7 +93,8 @@ def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: C
9393

9494

9595
class SSDLiteFeatureExtractorMobileNet(nn.Module):
96-
def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], **kwargs: Any):
96+
def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], rescaling: bool,
97+
**kwargs: Any):
9798
super().__init__()
9899
# non-public config parameters
99100
min_depth = kwargs.pop('_min_depth', 16)
@@ -115,8 +116,13 @@ def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., n
115116
_normal_init(extra)
116117

117118
self.extra = extra
119+
self.rescaling = rescaling
118120

119121
def forward(self, x: Tensor) -> Dict[str, Tensor]:
122+
# Rescale from [0, 1] to [-1, -1]
123+
if self.rescaling:
124+
x = 2.0 * x - 1.0
125+
120126
# Get feature maps from backbone and extra. Can't be refactored due to JIT limitations.
121127
output = []
122128
for block in self.features:
@@ -131,9 +137,12 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
131137

132138

133139
def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, trainable_layers: int,
134-
norm_layer: Callable[..., nn.Module], **kwargs: Any):
140+
norm_layer: Callable[..., nn.Module], rescaling: bool, **kwargs: Any):
135141
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, progress=progress,
136142
norm_layer=norm_layer, **kwargs).features
143+
if not pretrained:
144+
# Change the default initialization scheme if not pretrained
145+
_normal_init(backbone)
137146

138147
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
139148
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
@@ -148,11 +157,11 @@ def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, t
148157
for parameter in b.parameters():
149158
parameter.requires_grad_(False)
150159

151-
return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, **kwargs)
160+
return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, rescaling, **kwargs)
152161

153162

154163
def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
155-
pretrained_backbone: bool = True, trainable_backbone_layers: Optional[int] = None,
164+
pretrained_backbone: bool = False, trainable_backbone_layers: Optional[int] = None,
156165
norm_layer: Optional[Callable[..., nn.Module]] = None,
157166
**kwargs: Any):
158167
trainable_backbone_layers = _validate_trainable_layers(
@@ -161,11 +170,13 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru
161170
if pretrained:
162171
pretrained_backbone = False
163172

173+
rescaling = not pretrained_backbone
174+
164175
if norm_layer is None:
165176
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
166177

167178
backbone = _mobilenet_extractor("mobilenet_v3_large", progress, pretrained_backbone, trainable_backbone_layers,
168-
norm_layer, _width_mult=1.0)
179+
norm_layer, rescaling, _width_mult=1.0)
169180

170181
size = (320, 320)
171182
anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
@@ -181,7 +192,8 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru
181192
}
182193
kwargs = {**defaults, **kwargs}
183194
model = SSD(backbone, anchor_generator, size, num_classes,
184-
head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer), **kwargs)
195+
head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer),
196+
image_mean=[0., 0., 0.], image_std=[1., 1., 1.], **kwargs)
185197

186198
if pretrained:
187199
weights_name = 'ssdlite320_mobilenet_v3_large_coco'

0 commit comments

Comments
 (0)