1
1
import torch
2
2
from torch import nn , Tensor
3
3
4
- from typing import List , Optional , Dict
4
+ from typing import List , Optional
5
5
from .image_list import ImageList
6
6
7
7
@@ -28,7 +28,6 @@ class AnchorGenerator(nn.Module):
28
28
29
29
__annotations__ = {
30
30
"cell_anchors" : Optional [List [torch .Tensor ]],
31
- "_cache" : Dict [str , List [torch .Tensor ]]
32
31
}
33
32
34
33
def __init__ (
@@ -49,7 +48,6 @@ def __init__(
49
48
self .sizes = sizes
50
49
self .aspect_ratios = aspect_ratios
51
50
self .cell_anchors = None
52
- self ._cache = {}
53
51
54
52
# TODO: https://github.com/pytorch/pytorch/issues/26792
55
53
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
@@ -131,27 +129,17 @@ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]])
131
129
132
130
return anchors
133
131
134
- def cached_grid_anchors (self , grid_sizes : List [List [int ]], strides : List [List [Tensor ]]) -> List [Tensor ]:
135
- key = str (grid_sizes ) + str (strides )
136
- if key in self ._cache :
137
- return self ._cache [key ]
138
- anchors = self .grid_anchors (grid_sizes , strides )
139
- self ._cache [key ] = anchors
140
- return anchors
141
-
142
132
def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
    """Generate anchors for every image in the batch.

    Args:
        image_list (ImageList): batch wrapper; only ``.tensors.shape[-2:]``
            (padded image size) and ``len(.image_sizes)`` (batch size) are read.
        feature_maps (List[Tensor]): per-level feature maps; only their spatial
            shapes (``shape[-2:]``), dtype and device are used here.

    Returns:
        List[Tensor]: one tensor per image in the batch, containing the anchors
        of all feature-map levels concatenated along dim 0.
    """
    # Fixed: dropped redundant list(...) around the comprehension.
    grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
    image_size = image_list.tensors.shape[-2:]
    dtype, device = feature_maps[0].dtype, feature_maps[0].device
    # Integer stride of each feature map relative to the (padded) input image,
    # kept as int64 tensors for scripting compatibility.
    strides = [
        [
            torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
            torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device),
        ]
        for g in grid_sizes
    ]
    self.set_cell_anchors(dtype, device)
    anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
    anchors: List[List[torch.Tensor]] = []
    # The anchors depend only on the padded input size, so each image in the
    # batch receives the same set; a fresh list copy is appended per image.
    # Fixed: loop index was unused, renamed to `_`.
    for _ in range(len(image_list.image_sizes)):
        anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
        anchors.append(anchors_in_image)
    anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
    return anchors
0 commit comments