diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index 4e721674537..e62e34f3d54 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -698,7 +698,8 @@ def postprocess_detections( labels = labels.view(1, -1).expand_as(scores) # remove predictions with the background label - boxes = boxes[:, 1:] + boxes = boxes.reshape(-1, num_classes, 4) + boxes = boxes[..., 1:, :] scores = scores[:, 1:] labels = labels[:, 1:]