Skip to content

Commit 11c9839

Browse files
committed
Adding parameters and reusing objects in constructor.
1 parent 327e004 commit 11c9839

File tree

1 file changed

+33
-5
lines changed
  • torchvision/models/detection

1 file changed

+33
-5
lines changed

torchvision/models/detection/ssd.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from torch import nn, Tensor
55
from typing import Any, Dict, List, Optional, Tuple
66

7+
from . import _utils as det_utils
78
from .anchor_utils import DBoxGenerator
9+
from .transform import GeneralizedRCNNTransform
810
from .backbone_utils import _validate_trainable_layers
911
from .. import vgg
1012

@@ -64,8 +66,14 @@ def forward(self, x: List[Tensor]) -> Tensor:
6466

6567

6668
class SSD(nn.Module):
67-
def __init__(self, backbone: nn.Module, num_classes: int, size: int = 300,
68-
aspect_ratios: Optional[List[List[int]]] = None):
69+
def __init__(self, backbone: nn.Module, num_classes: int,
70+
size: int = 300, image_mean: Optional[List[float]] = None, image_std: Optional[List[float]] = None,
71+
aspect_ratios: Optional[List[List[int]]] = None,
72+
score_thresh: float = 0.01,
73+
nms_thresh: float = 0.45,
74+
detections_per_img: int = 200,
75+
iou_thresh: float = 0.5,
76+
topk_candidates: int = 400):
6977
super().__init__()
7078

7179
if aspect_ratios is None:
@@ -81,14 +89,34 @@ def __init__(self, backbone: nn.Module, num_classes: int, size: int = 300,
8189
assert len(feature_map_sizes) == len(aspect_ratios)
8290

8391
self.backbone = backbone
84-
self.num_classes = num_classes
85-
self.aspect_ratios = aspect_ratios
8692

8793
# Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feature map.
8894
num_anchors = [2 + 2 * len(r) for r in aspect_ratios]
8995
self.head = SSDHead(out_channels, num_anchors, num_classes)
9096

91-
self.dbox_generator = DBoxGenerator(size, feature_map_sizes, aspect_ratios)
97+
self.anchor_generator = DBoxGenerator(size, feature_map_sizes, aspect_ratios)
98+
99+
self.proposal_matcher = det_utils.Matcher(
100+
iou_thresh,
101+
iou_thresh,
102+
allow_low_quality_matches=True,
103+
)
104+
105+
self.box_coder = det_utils.BoxCoder(weights=(10., 10., 5., 5.))
106+
107+
if image_mean is None:
108+
image_mean = [0.485, 0.456, 0.406]
109+
if image_std is None:
110+
image_std = [0.229, 0.224, 0.225]
111+
self.transform = GeneralizedRCNNTransform(size, size, image_mean, image_std)
112+
113+
self.score_thresh = score_thresh
114+
self.nms_thresh = nms_thresh
115+
self.detections_per_img = detections_per_img
116+
self.topk_candidates = topk_candidates
117+
118+
# used only on torchscript mode
119+
self._has_warned = False
92120

93121
@torch.jit.unused
94122
def eager_outputs(self, losses, detections):

0 commit comments

Comments
 (0)