1
1
import torch
2
2
import torch .nn .functional as F
3
3
4
+ from collections import OrderedDict
4
5
from torch import nn , Tensor
5
6
from typing import Any , Dict , List , Optional , Tuple
6
7
7
8
from . import _utils as det_utils
8
9
from .anchor_utils import DBoxGenerator
9
- from .transform import GeneralizedRCNNTransform
10
10
from .backbone_utils import _validate_trainable_layers
11
+ from .transform import GeneralizedRCNNTransform
11
12
from .. import vgg
12
13
14
+ from .retinanet import RetinaNet , RetinaNetHead # TODO: Refactor both to inherit properly
15
+
13
16
14
17
__all__ = ['SSD' ]
15
18
16
19
17
class SSDHead(RetinaNetHead):
    """Detection head for SSD.

    Subclasses RetinaNetHead so that its ``compute_loss``/``forward`` plumbing
    is reused (this diff deletes SSDHead's own copies of both), while swapping
    in SSD-specific classification and regression sub-heads.
    NOTE(review): the upstream TODO says both heads should eventually be
    refactored to inherit properly rather than share code this way.
    """

    def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int):
        # Deliberately call nn.Module.__init__ directly instead of
        # super().__init__(): RetinaNetHead.__init__ would construct
        # RetinaNet's own sub-heads (and takes different arguments), which
        # this class replaces with the SSD-specific ones below.
        nn.Module.__init__(self)
        self.classification_head = SSDClassificationHead(in_channels, num_anchors, num_classes)
        self.regression_head = SSDRegressionHead(in_channels, num_anchors)
37
26
38
27
class SSDClassificationHead (nn .Module ):
39
28
def __init__ (self , in_channels : List [int ], num_anchors : List [int ], num_classes : int ):
@@ -65,7 +54,7 @@ def forward(self, x: List[Tensor]) -> Tensor:
65
54
pass
66
55
67
56
68
- class SSD (nn . Module ):
57
+ class SSD (RetinaNet ):
69
58
def __init__ (self , backbone : nn .Module , num_classes : int ,
70
59
size : int = 300 , image_mean : Optional [List [float ]] = None , image_std : Optional [List [float ]] = None ,
71
60
aspect_ratios : Optional [List [List [int ]]] = None ,
@@ -74,15 +63,15 @@ def __init__(self, backbone: nn.Module, num_classes: int,
74
63
detections_per_img : int = 200 ,
75
64
iou_thresh : float = 0.5 ,
76
65
topk_candidates : int = 400 ):
77
- super (). __init__ ()
66
+ nn . Module . __init__ (self )
78
67
79
68
if aspect_ratios is None :
80
69
aspect_ratios = [[2 ], [2 , 3 ], [2 , 3 ], [2 , 3 ], [2 ], [2 ]]
81
70
82
71
# Use dummy data to retrieve the feature map sizes to avoid hard-coding their values
83
72
device = next (backbone .parameters ()).device
84
73
tmp_img = torch .empty ((1 , 3 , size , size ), device = device )
85
- tmp_sizes = [x .size () for x in backbone (tmp_img )]
74
+ tmp_sizes = [x .size () for x in backbone (tmp_img ). values () ]
86
75
out_channels = [x [1 ] for x in tmp_sizes ]
87
76
feature_map_sizes = [x [2 ] for x in tmp_sizes ]
88
77
@@ -118,26 +107,10 @@ def __init__(self, backbone: nn.Module, num_classes: int,
118
107
# used only on torchscript mode
119
108
self ._has_warned = False
120
109
121
- @torch .jit .unused
122
- def eager_outputs (self , losses , detections ):
123
- # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
124
- if self .training :
125
- return losses
126
-
127
- return detections
128
-
129
110
def compute_loss (self , targets : List [Dict [str , Tensor ]], head_outputs : Dict [str , Tensor ],
130
111
anchors : List [Tensor ]) -> Dict [str , Tensor ]:
131
112
pass
132
113
133
- def postprocess_detections (self , head_outputs : Dict [str , List [Tensor ]], anchors : List [List [Tensor ]],
134
- image_shapes : List [Tuple [int , int ]]) -> List [Dict [str , Tensor ]]:
135
- pass
136
-
137
- def forward (self , images : List [Tensor ],
138
- targets : Optional [List [Dict [str , Tensor ]]] = None ) -> Tuple [Dict [str , Tensor ], List [Dict [str , Tensor ]]]:
139
- pass
140
-
141
114
142
115
class SSDFeatureExtractorVGG (nn .Module ):
143
116
# TODO: That's the SSD300 extractor. handle the SDD500 case as well. See page 11, footernote 5.
@@ -188,7 +161,7 @@ def __init__(self, backbone: nn.Module):
188
161
nn .ReLU (inplace = True ),
189
162
)
190
163
191
- def forward (self , x : Tensor ) -> List [ Tensor ]:
164
+ def forward (self , x : Tensor ) -> Dict [ str , Tensor ]:
192
165
# L2 regularization + Rescaling of 1st block's feature map
193
166
x = self .block1 (x )
194
167
rescaled = self .scale_weight .view (1 , - 1 , 1 , 1 ) * F .normalize (x )
@@ -199,7 +172,7 @@ def forward(self, x: Tensor) -> List[Tensor]:
199
172
x = block (x )
200
173
output .append (x )
201
174
202
- return output
175
+ return OrderedDict ((( str ( i ), v ) for i , v in enumerate ( output )))
203
176
204
177
205
178
def _vgg_backbone (backbone_name : str , pretrained : bool , trainable_layers : int = 3 ):
0 commit comments