Skip to content

Commit 06477d6

Browse files
authored
Merge branch 'master' into models/ssd
2 parents 365d1ef + 7c35e13 commit 06477d6

File tree

1 file changed

+8
-23
lines changed

1 file changed

+8
-23
lines changed

torchvision/models/detection/anchor_utils.py

Lines changed: 8 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import torch
33
from torch import nn, Tensor
44

5-
from typing import List, Optional
5+
from typing import List
66
from .image_list import ImageList
77

88

@@ -28,7 +28,7 @@ class AnchorGenerator(nn.Module):
2828
"""
2929

3030
__annotations__ = {
31-
"cell_anchors": Optional[List[torch.Tensor]],
31+
"cell_anchors": List[torch.Tensor],
3232
}
3333

3434
def __init__(
@@ -48,7 +48,8 @@ def __init__(
4848

4949
self.sizes = sizes
5050
self.aspect_ratios = aspect_ratios
51-
self.cell_anchors = None
51+
self.cell_anchors = [self.generate_anchors(size, aspect_ratio)
52+
for size, aspect_ratio in zip(sizes, aspect_ratios)]
5253

5354
# TODO: https://github.com/pytorch/pytorch/issues/26792
5455
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
@@ -68,24 +69,8 @@ def generate_anchors(self, scales: List[int], aspect_ratios: List[float], dtype:
6869
return base_anchors.round()
6970

7071
def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
71-
if self.cell_anchors is not None:
72-
cell_anchors = self.cell_anchors
73-
assert cell_anchors is not None
74-
# suppose that all anchors have the same device
75-
# which is a valid assumption in the current state of the codebase
76-
if cell_anchors[0].device == device:
77-
return
78-
79-
cell_anchors = [
80-
self.generate_anchors(
81-
sizes,
82-
aspect_ratios,
83-
dtype,
84-
device
85-
)
86-
for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
87-
]
88-
self.cell_anchors = cell_anchors
72+
self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device)
73+
for cell_anchor in self.cell_anchors]
8974

9075
def num_anchors_per_location(self):
9176
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
@@ -131,15 +116,15 @@ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]])
131116
return anchors
132117

133118
def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
134-
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
119+
grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
135120
image_size = image_list.tensors.shape[-2:]
136121
dtype, device = feature_maps[0].dtype, feature_maps[0].device
137122
strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
138123
torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
139124
self.set_cell_anchors(dtype, device)
140125
anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
141126
anchors: List[List[torch.Tensor]] = []
142-
for i in range(len(image_list.image_sizes)):
127+
for _ in range(len(image_list.image_sizes)):
143128
anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
144129
anchors.append(anchors_in_image)
145130
anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]

0 commit comments

Comments (0)