
Commit 86a14cb

Add topk min function for trace and onnx (#5310)
* Add topk minimizer function to _utils
* Apply ufmt formatting
* Apply min function for tracing and scripting
* Add type ignore to avoid cast
* fix flake
* Fix python_type_check

Co-authored-by: Prabhat Roy <[email protected]>
1 parent 8097370 commit 86a14cb

5 files changed: +30 −21 lines

torchvision/models/detection/_utils.py

Lines changed: 24 additions & 0 deletions
@@ -468,3 +468,27 @@ def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
     model.train()
 
     return out_channels
+
+
+def _topk_min(input: Tensor, orig_kval: int, axis: int) -> Tensor:
+    """
+    The ONNX spec requires the k-value to be less than or equal to the number of inputs along the
+    provided dim. Certain models use the number of elements along a particular axis instead of k
+    if k exceeds the number of elements along that axis. Previously, Python's min() function was
+    used to decide whether to use the provided k-value or the size of the given axis.
+
+    However, when the model is exported in tracing mode, Python's min() is evaluated statically,
+    causing the model to be traced incorrectly and eventually fail at the topk node.
+    To avoid this, torch.min() is used instead in tracing mode.
+
+    Args:
+        input (Tensor): The original input tensor.
+        orig_kval (int): The provided k-value.
+        axis (int): Axis along which we retrieve the input size.
+
+    Returns:
+        min_kval (Tensor): Appropriately selected k-value.
+    """
+    axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
+    min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
+    return min_kval  # type: ignore[arg-type]
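For context, a minimal usage sketch of the new helper (assuming a torchvision build that already contains this commit; _topk_min is an internal helper, not public API). It clamps the requested k to the size of the chosen axis and returns it as a 0-dim tensor, which Tensor.topk accepts directly in eager mode:

import torch
from torchvision.models.detection import _utils as det_utils  # assumes this commit is included

scores = torch.rand(3)                   # only 3 candidates along dim 0
k = det_utils._topk_min(scores, 5, 0)    # requested k=5 exceeds the axis size
print(k)                                 # tensor(3): clamped to the axis size
values, indices = scores.topk(k)         # topk accepts the 0-dim tensor k
print(values.shape)                      # torch.Size([3])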

torchvision/models/detection/fcos.py

Lines changed: 1 addition & 1 deletion
@@ -501,7 +501,7 @@ def postprocess_detections(
                 topk_idxs = torch.where(keep_idxs)[0]
 
                 # keep only topk scoring predictions
-                num_topk = min(self.topk_candidates, topk_idxs.size(0))
+                num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
                 scores_per_level, idxs = scores_per_level.topk(num_topk)
                 topk_idxs = topk_idxs[idxs]
 

torchvision/models/detection/retinanet.py

Lines changed: 1 addition & 1 deletion
@@ -436,7 +436,7 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes):
                 topk_idxs = torch.where(keep_idxs)[0]
 
                 # keep only topk scoring predictions
-                num_topk = min(self.topk_candidates, topk_idxs.size(0))
+                num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
                 scores_per_level, idxs = scores_per_level.topk(num_topk)
                 topk_idxs = topk_idxs[idxs]
 

torchvision/models/detection/rpn.py

Lines changed: 3 additions & 18 deletions
@@ -1,7 +1,6 @@
-from typing import List, Optional, Dict, Tuple, cast
+from typing import List, Optional, Dict, Tuple
 
 import torch
-import torchvision
 from torch import nn, Tensor
 from torch.nn import functional as F
 from torchvision.ops import boxes as box_ops
@@ -13,17 +12,6 @@
 from .image_list import ImageList
 
 
-@torch.jit.unused
-def _onnx_get_num_anchors_and_pre_nms_top_n(ob: Tensor, orig_pre_nms_top_n: int) -> Tuple[int, int]:
-    from torch.onnx import operators
-
-    num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
-    pre_nms_top_n = torch.min(torch.cat((torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), num_anchors), 0))
-
-    # for mypy we cast at runtime
-    return cast(int, num_anchors), cast(int, pre_nms_top_n)
-
-
 class RPNHead(nn.Module):
     """
     Adds a simple RPN Head with classification and regression heads
@@ -206,11 +194,8 @@ def _get_top_n_idx(self, objectness: Tensor, num_anchors_per_level: List[int]) -
         r = []
         offset = 0
         for ob in objectness.split(num_anchors_per_level, 1):
-            if torchvision._is_tracing():
-                num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n())
-            else:
-                num_anchors = ob.shape[1]
-                pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors)
+            num_anchors = ob.shape[1]
+            pre_nms_top_n = det_utils._topk_min(ob, self.pre_nms_top_n(), 1)
             _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
             r.append(top_n_idx + offset)
             offset += num_anchors

torchvision/models/detection/ssd.py

Lines changed: 1 addition & 1 deletion
@@ -407,7 +407,7 @@ def postprocess_detections(
                 box = boxes[keep_idxs]
 
                 # keep only topk scoring predictions
-                num_topk = min(self.topk_candidates, score.size(0))
+                num_topk = det_utils._topk_min(score, self.topk_candidates, 0)
                 score, idxs = score.topk(num_topk)
                 box = box[idxs]
 
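To make the motivation from the docstring concrete, here is a minimal, self-contained sketch (not part of the commit) of the failure mode that Python's min() causes under tracing: the clamp is evaluated once at trace time and frozen into the graph, so the traced module breaks on inputs with fewer candidates.

import torch

class NaiveTopK(torch.nn.Module):
    # Toy module for illustration: keeps at most k highest scores,
    # clamping k with Python's min(), which yields a plain int at trace time.
    def __init__(self, k: int = 100):
        super().__init__()
        self.k = k

    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        num_topk = min(self.k, scores.size(0))  # frozen to min(100, 50) == 50 during tracing
        values, _ = scores.topk(num_topk)
        return values

traced = torch.jit.trace(NaiveTopK(k=100), torch.rand(50))
try:
    traced(torch.rand(20))  # only 20 candidates, but the trace still asks topk for k=50
except RuntimeError as err:
    print("traced model failed:", err)

Computing the same clamp with _topk_min keeps it as tensor ops (torch._shape_as_tensor, torch.min), which is what lets the ONNX exporter emit a TopK whose k follows the runtime input size instead of a constant baked in at trace time.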
