Commit 34237e4

Temporarily inherit from Retina to avoid dup code.
1 parent 11c9839 commit 34237e4


torchvision/models/detection/ssd.py

Lines changed: 11 additions & 38 deletions
@@ -1,39 +1,28 @@
 import torch
 import torch.nn.functional as F
 
+from collections import OrderedDict
 from torch import nn, Tensor
 from typing import Any, Dict, List, Optional, Tuple
 
 from . import _utils as det_utils
 from .anchor_utils import DBoxGenerator
-from .transform import GeneralizedRCNNTransform
 from .backbone_utils import _validate_trainable_layers
+from .transform import GeneralizedRCNNTransform
 from .. import vgg
 
+from .retinanet import RetinaNet, RetinaNetHead  # TODO: Refactor both to inherit properly
+
 
 __all__ = ['SSD']
 
 
-class SSDHead(nn.Module):
-    # TODO: Similar to RetinaNetHead. Perhaps abstract and reuse for one-shot detectors.
+class SSDHead(RetinaNetHead):
     def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int):
-        super().__init__()
+        nn.Module.__init__(self)
         self.classification_head = SSDClassificationHead(in_channels, num_anchors, num_classes)
         self.regression_head = SSDRegressionHead(in_channels, num_anchors)
 
-    def compute_loss(self, targets: List[Dict[str, Tensor]], head_outputs: Dict[str, Tensor], anchors: List[Tensor],
-                     matched_idxs: List[Tensor]) -> Dict[str, Tensor]:
-        return {
-            'classification': self.classification_head.compute_loss(targets, head_outputs, matched_idxs),
-            'bbox_regression': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
-        }
-
-    def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
-        return {
-            'cls_logits': self.classification_head(x),
-            'bbox_regression': self.regression_head(x)
-        }
-
 
 class SSDClassificationHead(nn.Module):
     def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int):
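The change above reuses RetinaNetHead's compute_loss and forward, which the deleted SSDHead methods duplicated almost verbatim; calling nn.Module.__init__(self) instead of super().__init__() skips the parent constructor so the RetinaNet-specific layers are never built. A minimal sketch of that constructor-bypass idiom, with illustrative class names that are not from this file:

# Sketch of the constructor-bypass idiom (ParentHead/ChildHead are
# illustrative names, not from torchvision): reuse the parent's methods
# without building its layers.
import torch
from torch import nn


class ParentHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 10)  # parent-specific layer the child wants to skip

    def forward(self, x):
        return {'out': self.layer(x)}  # shared logic the child inherits unchanged


class ChildHead(ParentHead):
    def __init__(self):
        # Run nn.Module.__init__ for parameter/submodule bookkeeping,
        # but skip ParentHead.__init__ so nn.Linear(10, 10) is never created.
        nn.Module.__init__(self)
        self.layer = nn.Linear(4, 2)  # the child supplies its own submodule


head = ChildHead()
print(head(torch.rand(1, 4))['out'].shape)  # torch.Size([1, 2]), via the inherited forward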
@@ -65,7 +54,7 @@ def forward(self, x: List[Tensor]) -> Tensor:
         pass
 
 
-class SSD(nn.Module):
+class SSD(RetinaNet):
     def __init__(self, backbone: nn.Module, num_classes: int,
                  size: int = 300, image_mean: Optional[List[float]] = None, image_std: Optional[List[float]] = None,
                  aspect_ratios: Optional[List[List[int]]] = None,
@@ -74,15 +63,15 @@ def __init__(self, backbone: nn.Module, num_classes: int,
                  detections_per_img: int = 200,
                  iou_thresh: float = 0.5,
                  topk_candidates: int = 400):
-        super().__init__()
+        nn.Module.__init__(self)
 
         if aspect_ratios is None:
             aspect_ratios = [[2], [2, 3], [2, 3], [2, 3], [2], [2]]
 
         # Use dummy data to retrieve the feature map sizes to avoid hard-coding their values
         device = next(backbone.parameters()).device
         tmp_img = torch.empty((1, 3, size, size), device=device)
-        tmp_sizes = [x.size() for x in backbone(tmp_img)]
+        tmp_sizes = [x.size() for x in backbone(tmp_img).values()]
         out_channels = [x[1] for x in tmp_sizes]
         feature_map_sizes = [x[2] for x in tmp_sizes]
 
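The .values() call in this hunk anticipates the extractor change in the last hunk below: the backbone now returns a dict of feature maps, and the dummy forward reads channel counts and spatial sizes off it rather than hard-coding them. A self-contained sketch of the same trick, with a toy backbone standing in for SSDFeatureExtractorVGG:

# Sketch of the dummy-forward trick: infer feature map shapes from one
# fake batch instead of hard-coding them. The toy backbone is illustrative.
from collections import OrderedDict

import torch
from torch import nn


class ToyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
        self.block2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        a = self.block1(x)
        b = self.block2(a)
        return OrderedDict([('0', a), ('1', b)])  # dict output, like the new extractor


backbone = ToyBackbone()
size = 300
device = next(backbone.parameters()).device
tmp_img = torch.empty((1, 3, size, size), device=device)
tmp_sizes = [x.size() for x in backbone(tmp_img).values()]
out_channels = [s[1] for s in tmp_sizes]       # [64, 128]
feature_map_sizes = [s[2] for s in tmp_sizes]  # [150, 75]
print(out_channels, feature_map_sizes)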
@@ -118,26 +107,10 @@ def __init__(self, backbone: nn.Module, num_classes: int,
         # used only on torchscript mode
         self._has_warned = False
 
-    @torch.jit.unused
-    def eager_outputs(self, losses, detections):
-        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
-        if self.training:
-            return losses
-
-        return detections
-
     def compute_loss(self, targets: List[Dict[str, Tensor]], head_outputs: Dict[str, Tensor],
                      anchors: List[Tensor]) -> Dict[str, Tensor]:
         pass
 
-    def postprocess_detections(self, head_outputs: Dict[str, List[Tensor]], anchors: List[List[Tensor]],
-                               image_shapes: List[Tuple[int, int]]) -> List[Dict[str, Tensor]]:
-        pass
-
-    def forward(self, images: List[Tensor],
-                targets: Optional[List[Dict[str, Tensor]]] = None) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
-        pass
-
 
 class SSDFeatureExtractorVGG(nn.Module):
     # TODO: That's the SSD300 extractor. handle the SDD500 case as well. See page 11, footernote 5.
@@ -188,7 +161,7 @@ def __init__(self, backbone: nn.Module):
             nn.ReLU(inplace=True),
         )
 
-    def forward(self, x: Tensor) -> List[Tensor]:
+    def forward(self, x: Tensor) -> Dict[str, Tensor]:
         # L2 regularization + Rescaling of 1st block's feature map
         x = self.block1(x)
         rescaled = self.scale_weight.view(1, -1, 1, 1) * F.normalize(x)
@@ -199,7 +172,7 @@ def forward(self, x: List[Tensor]) -> Tensor:
             x = block(x)
             output.append(x)
 
-        return output
+        return OrderedDict(((str(i), v) for i, v in enumerate(output)))
 
 
 def _vgg_backbone(backbone_name: str, pretrained: bool, trainable_layers: int = 3):
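Returning an OrderedDict rather than a List[Tensor] matches the Dict[str, Tensor] shape that the inherited RetinaNet code paths and the dummy forward above expect, and follows the string-keyed convention torchvision backbones use elsewhere. A small sketch of the conversion:

# Sketch of the list-to-dict conversion above: per-level feature maps
# keyed by their index as a string. The example tensors are illustrative.
from collections import OrderedDict

import torch

output = [torch.rand(1, 64, 38, 38), torch.rand(1, 128, 19, 19)]  # per-level feature maps
features = OrderedDict(((str(i), v) for i, v in enumerate(output)))
print(list(features.keys()))                    # ['0', '1']
print([v.shape[1] for v in features.values()])  # [64, 128]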
