4
4
from torch import nn , Tensor
5
5
from typing import Any , Dict , List , Optional , Tuple
6
6
7
+ from . import _utils as det_utils
7
8
from .anchor_utils import DBoxGenerator
9
+ from .transform import GeneralizedRCNNTransform
8
10
from .backbone_utils import _validate_trainable_layers
9
11
from .. import vgg
10
12
@@ -64,8 +66,14 @@ def forward(self, x: List[Tensor]) -> Tensor:
64
66
65
67
66
68
class SSD (nn .Module ):
67
- def __init__ (self , backbone : nn .Module , num_classes : int , size : int = 300 ,
68
- aspect_ratios : Optional [List [List [int ]]] = None ):
69
+ def __init__ (self , backbone : nn .Module , num_classes : int ,
70
+ size : int = 300 , image_mean : Optional [List [float ]] = None , image_std : Optional [List [float ]] = None ,
71
+ aspect_ratios : Optional [List [List [int ]]] = None ,
72
+ score_thresh : float = 0.01 ,
73
+ nms_thresh : float = 0.45 ,
74
+ detections_per_img : int = 200 ,
75
+ iou_thresh : float = 0.5 ,
76
+ topk_candidates : int = 400 ):
69
77
super ().__init__ ()
70
78
71
79
if aspect_ratios is None :
@@ -81,14 +89,34 @@ def __init__(self, backbone: nn.Module, num_classes: int, size: int = 300,
81
89
assert len (feature_map_sizes ) == len (aspect_ratios )
82
90
83
91
self .backbone = backbone
84
- self .num_classes = num_classes
85
- self .aspect_ratios = aspect_ratios
86
92
87
93
# Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feature map.
88
94
num_anchors = [2 + 2 * len (r ) for r in aspect_ratios ]
89
95
self .head = SSDHead (out_channels , num_anchors , num_classes )
90
96
91
- self .dbox_generator = DBoxGenerator (size , feature_map_sizes , aspect_ratios )
97
+ self .anchor_generator = DBoxGenerator (size , feature_map_sizes , aspect_ratios )
98
+
99
+ self .proposal_matcher = det_utils .Matcher (
100
+ iou_thresh ,
101
+ iou_thresh ,
102
+ allow_low_quality_matches = True ,
103
+ )
104
+
105
+ self .box_coder = det_utils .BoxCoder (weights = (10. , 10. , 5. , 5. ))
106
+
107
+ if image_mean is None :
108
+ image_mean = [0.485 , 0.456 , 0.406 ]
109
+ if image_std is None :
110
+ image_std = [0.229 , 0.224 , 0.225 ]
111
+ self .transform = GeneralizedRCNNTransform (size , size , image_mean , image_std )
112
+
113
+ self .score_thresh = score_thresh
114
+ self .nms_thresh = nms_thresh
115
+ self .detections_per_img = detections_per_img
116
+ self .topk_candidates = topk_candidates
117
+
118
+ # used only in TorchScript mode
119
+ self ._has_warned = False
92
120
93
121
@torch .jit .unused
94
122
def eager_outputs (self , losses , detections ):
0 commit comments