Skip to content

Commit 23f4e0f

Browse files
cpuhrschfacebook-github-bot
authored andcommitted
[fbsync] Removed caching from AnchorGenerator (#3745)
Reviewed By: NicolasHug Differential Revision: D28169141 fbshipit-source-id: ad444fd3ce613e6eee398301a99e330b18055ada
1 parent 7f2ae3f commit 23f4e0f

File tree

1 file changed

+2
-14
lines changed

1 file changed

+2
-14
lines changed

torchvision/models/detection/anchor_utils.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
from torch import nn, Tensor
33

4-
from typing import List, Optional, Dict
4+
from typing import List, Optional
55
from .image_list import ImageList
66

77

@@ -28,7 +28,6 @@ class AnchorGenerator(nn.Module):
2828

2929
__annotations__ = {
3030
"cell_anchors": Optional[List[torch.Tensor]],
31-
"_cache": Dict[str, List[torch.Tensor]]
3231
}
3332

3433
def __init__(
@@ -49,7 +48,6 @@ def __init__(
4948
self.sizes = sizes
5049
self.aspect_ratios = aspect_ratios
5150
self.cell_anchors = None
52-
self._cache = {}
5351

5452
# TODO: https://github.com/pytorch/pytorch/issues/26792
5553
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
@@ -131,27 +129,17 @@ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]])
131129

132130
return anchors
133131

134-
def cached_grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
135-
key = str(grid_sizes) + str(strides)
136-
if key in self._cache:
137-
return self._cache[key]
138-
anchors = self.grid_anchors(grid_sizes, strides)
139-
self._cache[key] = anchors
140-
return anchors
141-
142132
def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
143133
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
144134
image_size = image_list.tensors.shape[-2:]
145135
dtype, device = feature_maps[0].dtype, feature_maps[0].device
146136
strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
147137
torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
148138
self.set_cell_anchors(dtype, device)
149-
anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
139+
anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
150140
anchors: List[List[torch.Tensor]] = []
151141
for i in range(len(image_list.image_sizes)):
152142
anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
153143
anchors.append(anchors_in_image)
154144
anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
155-
# Clear the cache in case that memory leaks.
156-
self._cache.clear()
157145
return anchors

0 commit comments

Comments
 (0)