Skip to content

Commit f766afa

Browse files
cpuhrsch authored and facebook-github-bot committed
[fbsync] Refactored set_cell_anchors() in AnchorGenerator (#3755)
Summary:
* Refactored set_cell_anchors() in AnchorGenerator
* Addressed review comment
* Fixed test failure

Reviewed By: NicolasHug
Differential Revision: D28169121
fbshipit-source-id: 1b6d9bcec0b69eedc67761796bb110befdc377c9
1 parent 23f4e0f commit f766afa

File tree

1 file changed

+8
-23
lines changed

1 file changed

+8
-23
lines changed

torchvision/models/detection/anchor_utils.py

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
from torch import nn, Tensor
33

4-
from typing import List, Optional
4+
from typing import List
55
from .image_list import ImageList
66

77

@@ -27,7 +27,7 @@ class AnchorGenerator(nn.Module):
2727
"""
2828

2929
__annotations__ = {
30-
"cell_anchors": Optional[List[torch.Tensor]],
30+
"cell_anchors": List[torch.Tensor],
3131
}
3232

3333
def __init__(
@@ -47,7 +47,8 @@ def __init__(
4747

4848
self.sizes = sizes
4949
self.aspect_ratios = aspect_ratios
50-
self.cell_anchors = None
50+
self.cell_anchors = [self.generate_anchors(size, aspect_ratio)
51+
for size, aspect_ratio in zip(sizes, aspect_ratios)]
5152

5253
# TODO: https://github.com/pytorch/pytorch/issues/26792
5354
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
@@ -67,24 +68,8 @@ def generate_anchors(self, scales: List[int], aspect_ratios: List[float], dtype:
6768
return base_anchors.round()
6869

6970
def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
70-
if self.cell_anchors is not None:
71-
cell_anchors = self.cell_anchors
72-
assert cell_anchors is not None
73-
# suppose that all anchors have the same device
74-
# which is a valid assumption in the current state of the codebase
75-
if cell_anchors[0].device == device:
76-
return
77-
78-
cell_anchors = [
79-
self.generate_anchors(
80-
sizes,
81-
aspect_ratios,
82-
dtype,
83-
device
84-
)
85-
for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
86-
]
87-
self.cell_anchors = cell_anchors
71+
self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device)
72+
for cell_anchor in self.cell_anchors]
8873

8974
def num_anchors_per_location(self):
9075
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
@@ -130,15 +115,15 @@ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]])
130115
return anchors
131116

132117
def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
133-
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
118+
grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
134119
image_size = image_list.tensors.shape[-2:]
135120
dtype, device = feature_maps[0].dtype, feature_maps[0].device
136121
strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
137122
torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
138123
self.set_cell_anchors(dtype, device)
139124
anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
140125
anchors: List[List[torch.Tensor]] = []
141-
for i in range(len(image_list.image_sizes)):
126+
for _ in range(len(image_list.image_sizes)):
142127
anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
143128
anchors.append(anchors_in_image)
144129
anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]

0 commit comments

Comments (0)