Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions torchvision/models/detection/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,3 +468,27 @@ def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
model.train()

return out_channels


def _topk_min(input: Tensor, orig_kval: int, axis: int) -> Tensor:
"""
ONNX spec requires the k-value to be less than or equal to the number of inputs along
provided dim. Certain models use the number of elements along a particular axis instead of K
if K exceeds the number of elements along that axis. Previously, python's min() function was
used to determine whether to use the provided k-value or the specified dim axis value.

However in cases where the model is being exported in tracing mode, python min() is
static causing the model to be traced incorrectly and eventually fail at the topk node.
In order to avoid this situation, in tracing mode, torch.min() is used instead.

Args:
input (Tensor): The orignal input tensor.
orig_kval (int): The provided k-value.
axis(int): Axis along which we retreive the input size.

Returns:
min_kval (Tensor): Appropriately selected k-value.
"""
axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
return min_kval # type: ignore[arg-type]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@prabhat00155 @shubhambhokare1 There might be performance implications here that we need to benchmark thoroughly. If we have such benchmarks we should post them in this PR to make them easily accessible.

If we don't, I propose reverting the PR until we have measured it. An alternative is to use the ONNX-friendly approach only when we are tracing and to keep the previous, simpler estimation in all other cases.

2 changes: 1 addition & 1 deletion torchvision/models/detection/fcos.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def postprocess_detections(
topk_idxs = torch.where(keep_idxs)[0]

# keep only topk scoring predictions
num_topk = min(self.topk_candidates, topk_idxs.size(0))
num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
scores_per_level, idxs = scores_per_level.topk(num_topk)
topk_idxs = topk_idxs[idxs]

Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/detection/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes):
topk_idxs = torch.where(keep_idxs)[0]

# keep only topk scoring predictions
num_topk = min(self.topk_candidates, topk_idxs.size(0))
num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
scores_per_level, idxs = scores_per_level.topk(num_topk)
topk_idxs = topk_idxs[idxs]

Expand Down
21 changes: 3 additions & 18 deletions torchvision/models/detection/rpn.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import List, Optional, Dict, Tuple, cast
from typing import List, Optional, Dict, Tuple

import torch
import torchvision
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.ops import boxes as box_ops
Expand All @@ -13,17 +12,6 @@
from .image_list import ImageList


@torch.jit.unused
def _onnx_get_num_anchors_and_pre_nms_top_n(ob: Tensor, orig_pre_nms_top_n: int) -> Tuple[int, int]:
from torch.onnx import operators

num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
pre_nms_top_n = torch.min(torch.cat((torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), num_anchors), 0))

# for mypy we cast at runtime
return cast(int, num_anchors), cast(int, pre_nms_top_n)


class RPNHead(nn.Module):
"""
Adds a simple RPN Head with classification and regression heads
Expand Down Expand Up @@ -206,11 +194,8 @@ def _get_top_n_idx(self, objectness: Tensor, num_anchors_per_level: List[int]) -
r = []
offset = 0
for ob in objectness.split(num_anchors_per_level, 1):
if torchvision._is_tracing():
num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n())
else:
num_anchors = ob.shape[1]
pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors)
num_anchors = ob.shape[1]
pre_nms_top_n = det_utils._topk_min(ob, self.pre_nms_top_n(), 1)
_, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
r.append(top_n_idx + offset)
offset += num_anchors
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def postprocess_detections(
box = boxes[keep_idxs]

# keep only topk scoring predictions
num_topk = min(self.topk_candidates, score.size(0))
num_topk = det_utils._topk_min(score, self.topk_candidates, 0)
score, idxs = score.topk(num_topk)
box = box[idxs]

Expand Down