2
2
import torch
3
3
from torch import nn , Tensor
4
4
5
- from typing import List , Optional
5
+ from typing import List
6
6
from .image_list import ImageList
7
7
8
8
@@ -28,7 +28,7 @@ class AnchorGenerator(nn.Module):
28
28
"""
29
29
30
30
__annotations__ = {
31
- "cell_anchors" : Optional [ List [torch .Tensor ] ],
31
+ "cell_anchors" : List [torch .Tensor ],
32
32
}
33
33
34
34
def __init__ (
@@ -48,7 +48,8 @@ def __init__(
48
48
49
49
self .sizes = sizes
50
50
self .aspect_ratios = aspect_ratios
51
- self .cell_anchors = None
51
+ self .cell_anchors = [self .generate_anchors (size , aspect_ratio )
52
+ for size , aspect_ratio in zip (sizes , aspect_ratios )]
52
53
53
54
# TODO: https://github.com/pytorch/pytorch/issues/26792
54
55
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
@@ -68,24 +69,8 @@ def generate_anchors(self, scales: List[int], aspect_ratios: List[float], dtype:
68
69
return base_anchors .round ()
69
70
70
71
def set_cell_anchors (self , dtype : torch .dtype , device : torch .device ):
71
- if self .cell_anchors is not None :
72
- cell_anchors = self .cell_anchors
73
- assert cell_anchors is not None
74
- # suppose that all anchors have the same device
75
- # which is a valid assumption in the current state of the codebase
76
- if cell_anchors [0 ].device == device :
77
- return
78
-
79
- cell_anchors = [
80
- self .generate_anchors (
81
- sizes ,
82
- aspect_ratios ,
83
- dtype ,
84
- device
85
- )
86
- for sizes , aspect_ratios in zip (self .sizes , self .aspect_ratios )
87
- ]
88
- self .cell_anchors = cell_anchors
72
+ self .cell_anchors = [cell_anchor .to (dtype = dtype , device = device )
73
+ for cell_anchor in self .cell_anchors ]
89
74
90
75
def num_anchors_per_location (self ):
91
76
return [len (s ) * len (a ) for s , a in zip (self .sizes , self .aspect_ratios )]
@@ -131,15 +116,15 @@ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]])
131
116
return anchors
132
117
133
118
def forward (self , image_list : ImageList , feature_maps : List [Tensor ]) -> List [Tensor ]:
134
- grid_sizes = list ( [feature_map .shape [- 2 :] for feature_map in feature_maps ])
119
+ grid_sizes = [feature_map .shape [- 2 :] for feature_map in feature_maps ]
135
120
image_size = image_list .tensors .shape [- 2 :]
136
121
dtype , device = feature_maps [0 ].dtype , feature_maps [0 ].device
137
122
strides = [[torch .tensor (image_size [0 ] // g [0 ], dtype = torch .int64 , device = device ),
138
123
torch .tensor (image_size [1 ] // g [1 ], dtype = torch .int64 , device = device )] for g in grid_sizes ]
139
124
self .set_cell_anchors (dtype , device )
140
125
anchors_over_all_feature_maps = self .grid_anchors (grid_sizes , strides )
141
126
anchors : List [List [torch .Tensor ]] = []
142
- for i in range (len (image_list .image_sizes )):
127
+ for _ in range (len (image_list .image_sizes )):
143
128
anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps ]
144
129
anchors .append (anchors_in_image )
145
130
anchors = [torch .cat (anchors_per_image ) for anchors_per_image in anchors ]
0 commit comments