1
1
import torch
2
2
from torch import nn , Tensor
3
3
4
- from typing import List , Optional
4
+ from typing import List
5
5
from .image_list import ImageList
6
6
7
7
@@ -27,7 +27,7 @@ class AnchorGenerator(nn.Module):
27
27
"""
28
28
29
29
__annotations__ = {
30
- "cell_anchors" : Optional [ List [torch .Tensor ] ],
30
+ "cell_anchors" : List [torch .Tensor ],
31
31
}
32
32
33
33
def __init__ (
@@ -47,7 +47,8 @@ def __init__(
47
47
48
48
self .sizes = sizes
49
49
self .aspect_ratios = aspect_ratios
50
- self .cell_anchors = None
50
+ self .cell_anchors = [self .generate_anchors (size , aspect_ratio )
51
+ for size , aspect_ratio in zip (sizes , aspect_ratios )]
51
52
52
53
# TODO: https://github.com/pytorch/pytorch/issues/26792
53
54
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
@@ -67,24 +68,8 @@ def generate_anchors(self, scales: List[int], aspect_ratios: List[float], dtype:
67
68
return base_anchors .round ()
68
69
69
70
def set_cell_anchors (self , dtype : torch .dtype , device : torch .device ):
70
- if self .cell_anchors is not None :
71
- cell_anchors = self .cell_anchors
72
- assert cell_anchors is not None
73
- # suppose that all anchors have the same device
74
- # which is a valid assumption in the current state of the codebase
75
- if cell_anchors [0 ].device == device :
76
- return
77
-
78
- cell_anchors = [
79
- self .generate_anchors (
80
- sizes ,
81
- aspect_ratios ,
82
- dtype ,
83
- device
84
- )
85
- for sizes , aspect_ratios in zip (self .sizes , self .aspect_ratios )
86
- ]
87
- self .cell_anchors = cell_anchors
71
+ self .cell_anchors = [cell_anchor .to (dtype = dtype , device = device )
72
+ for cell_anchor in self .cell_anchors ]
88
73
89
74
def num_anchors_per_location (self ):
90
75
return [len (s ) * len (a ) for s , a in zip (self .sizes , self .aspect_ratios )]
@@ -130,15 +115,15 @@ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]])
130
115
return anchors
131
116
132
117
def forward (self , image_list : ImageList , feature_maps : List [Tensor ]) -> List [Tensor ]:
133
- grid_sizes = list ( [feature_map .shape [- 2 :] for feature_map in feature_maps ])
118
+ grid_sizes = [feature_map .shape [- 2 :] for feature_map in feature_maps ]
134
119
image_size = image_list .tensors .shape [- 2 :]
135
120
dtype , device = feature_maps [0 ].dtype , feature_maps [0 ].device
136
121
strides = [[torch .tensor (image_size [0 ] // g [0 ], dtype = torch .int64 , device = device ),
137
122
torch .tensor (image_size [1 ] // g [1 ], dtype = torch .int64 , device = device )] for g in grid_sizes ]
138
123
self .set_cell_anchors (dtype , device )
139
124
anchors_over_all_feature_maps = self .grid_anchors (grid_sizes , strides )
140
125
anchors : List [List [torch .Tensor ]] = []
141
- for i in range (len (image_list .image_sizes )):
126
+ for _ in range (len (image_list .image_sizes )):
142
127
anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps ]
143
128
anchors .append (anchors_in_image )
144
129
anchors = [torch .cat (anchors_per_image ) for anchors_per_image in anchors ]
0 commit comments