@@ -93,9 +93,11 @@ def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: C
93
93
94
94
95
95
class SSDLiteFeatureExtractorMobileNet (nn .Module ):
96
- def __init__ (self , backbone : nn .Module , c4_pos : int , norm_layer : Callable [..., nn .Module ],
97
- width_mult : float = 1.0 , min_depth : int = 16 ):
96
+ def __init__ (self , backbone : nn .Module , c4_pos : int , norm_layer : Callable [..., nn .Module ], ** kwargs : Any ):
98
97
super ().__init__ ()
98
+ # non-public config parameters
99
+ min_depth = kwargs .pop ('_min_depth' , 16 )
100
+ width_mult = kwargs .pop ('_width_mult' , 1.0 )
99
101
100
102
assert not backbone [c4_pos ].use_res_connect
101
103
self .features = nn .Sequential (
@@ -129,10 +131,9 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
129
131
130
132
131
133
def _mobilenet_extractor (backbone_name : str , progress : bool , pretrained : bool , trainable_layers : int ,
132
- norm_layer : Callable [..., nn .Module ]):
133
- # TODO: support width_mult
134
+ norm_layer : Callable [..., nn .Module ], ** kwargs : Any ):
134
135
backbone = mobilenet .__dict__ [backbone_name ](pretrained = pretrained , progress = progress ,
135
- norm_layer = norm_layer ).features
136
+ norm_layer = norm_layer , ** kwargs ).features
136
137
137
138
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
138
139
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
@@ -147,7 +148,7 @@ def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, t
147
148
for parameter in b .parameters ():
148
149
parameter .requires_grad_ (False )
149
150
150
- return SSDLiteFeatureExtractorMobileNet (backbone , stage_indices [- 2 ], norm_layer )
151
+ return SSDLiteFeatureExtractorMobileNet (backbone , stage_indices [- 2 ], norm_layer , ** kwargs )
151
152
152
153
153
154
def ssdlite320_mobilenet_v3_large (pretrained : bool = False , progress : bool = True , num_classes : int = 91 ,
@@ -164,7 +165,7 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru
164
165
norm_layer = partial (nn .BatchNorm2d , eps = 0.001 , momentum = 0.03 )
165
166
166
167
backbone = _mobilenet_extractor ("mobilenet_v3_large" , progress , pretrained_backbone , trainable_backbone_layers ,
167
- norm_layer )
168
+ norm_layer , _width_mult = 1.0 )
168
169
169
170
size = (320 , 320 )
170
171
anchor_generator = DefaultBoxGenerator ([[2 , 3 ] for _ in range (6 )], min_ratio = 0.2 , max_ratio = 0.95 )
0 commit comments