Skip to content

Commit 60318d9

Browse files
committed
Introduce small score threshold on rpn
1 parent 3d60f49 commit 60318d9

File tree

4 files changed

+20
-5
lines changed

4 files changed

+20
-5
lines changed

test/test_models_detection_negative_samples.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_targets_to_anchors(self):
4444
rpn_anchor_generator, rpn_head,
4545
0.5, 0.3,
4646
256, 0.5,
47-
2000, 2000, 0.7)
47+
2000, 2000, 0.7, 0.05)
4848

4949
labels, matched_gt_boxes = head.assign_targets_to_anchors(anchors, targets)
5050

test/test_onnx.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,14 @@ def _init_test_rpn(self):
197197
rpn_pre_nms_top_n = dict(training=2000, testing=1000)
198198
rpn_post_nms_top_n = dict(training=2000, testing=1000)
199199
rpn_nms_thresh = 0.7
200+
rpn_score_thresh = 0.05
200201

201202
rpn = RegionProposalNetwork(
202203
rpn_anchor_generator, rpn_head,
203204
rpn_fg_iou_thresh, rpn_bg_iou_thresh,
204205
rpn_batch_size_per_image, rpn_positive_fraction,
205-
rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)
206+
rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh,
207+
score_thresh=rpn_score_thresh)
206208
return rpn
207209

208210
def _init_test_roi_heads_faster_rcnn(self):

torchvision/models/detection/faster_rcnn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def __init__(self, backbone, num_classes=None,
153153
rpn_nms_thresh=0.7,
154154
rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
155155
rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
156+
rpn_score_thresh=0.05,
156157
# Box parameters
157158
box_roi_pool=None, box_head=None, box_predictor=None,
158159
box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
@@ -197,7 +198,7 @@ def __init__(self, backbone, num_classes=None,
197198
rpn_anchor_generator, rpn_head,
198199
rpn_fg_iou_thresh, rpn_bg_iou_thresh,
199200
rpn_batch_size_per_image, rpn_positive_fraction,
200-
rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)
201+
rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh, rpn_score_thresh)
201202

202203
if box_roi_pool is None:
203204
box_roi_pool = MultiScaleRoIAlign(

torchvision/models/detection/rpn.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def __init__(self,
141141
fg_iou_thresh, bg_iou_thresh,
142142
batch_size_per_image, positive_fraction,
143143
#
144-
pre_nms_top_n, post_nms_top_n, nms_thresh):
144+
pre_nms_top_n, post_nms_top_n, nms_thresh, score_thresh=0.0):
145145
super(RegionProposalNetwork, self).__init__()
146146
self.anchor_generator = anchor_generator
147147
self.head = head
@@ -163,6 +163,7 @@ def __init__(self,
163163
self._pre_nms_top_n = pre_nms_top_n
164164
self._post_nms_top_n = post_nms_top_n
165165
self.nms_thresh = nms_thresh
166+
self.score_thresh = score_thresh
166167
self.min_size = 1e-3
167168

168169
def pre_nms_top_n(self):
@@ -251,17 +252,28 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_
251252
levels = levels[batch_idx, top_n_idx]
252253
proposals = proposals[batch_idx, top_n_idx]
253254

255+
objectness_prob = F.sigmoid(objectness)
256+
254257
final_boxes = []
255258
final_scores = []
256-
for boxes, scores, lvl, img_shape in zip(proposals, objectness, levels, image_shapes):
259+
for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
257260
boxes = box_ops.clip_boxes_to_image(boxes, img_shape)
261+
262+
# remove small boxes
258263
keep = box_ops.remove_small_boxes(boxes, self.min_size)
259264
boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
265+
266+
# remove low scoring boxes
267+
inds = torch.where(scores > self.score_thresh)[0]
268+
boxes, scores, lvl = boxes[inds], scores[inds], lvl[inds]
269+
260270
# non-maximum suppression, independently done per level
261271
keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
272+
262273
# keep only topk scoring predictions
263274
keep = keep[:self.post_nms_top_n()]
264275
boxes, scores = boxes[keep], scores[keep]
276+
265277
final_boxes.append(boxes)
266278
final_scores.append(scores)
267279
return final_boxes, final_scores

0 commit comments

Comments
 (0)