# https://raw.githubusercontent.com/pytorch/vision/ae30df455405fb56946425bf3f3c318280b0a7ae/torchvision/models/detection/anchor_utils.py

import math
from typing import List, Optional

import torch
from torch import nn, Tensor

from .image_list import ImageList


class AnchorGenerator(nn.Module):
    """
    Module that generates anchors for a set of feature maps and
    image sizes.

    The module supports computing anchors at multiple sizes and aspect ratios
    per feature map. This module assumes aspect ratio = height / width for
    each anchor.

    sizes and aspect_ratios should have the same number of elements, and that
    number should correspond to the number of feature maps.

    sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
    and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
    per spatial location for feature map i.

    Args:
        sizes (Tuple[Tuple[int]]):
        aspect_ratios (Tuple[Tuple[float]]):
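
    Example (a minimal usage sketch; the sizes and aspect ratios below are arbitrary
    and only illustrate a single-feature-map configuration):

        >>> anchor_generator = AnchorGenerator(
        ...     sizes=((32, 64, 128),),
        ...     aspect_ratios=((0.5, 1.0, 2.0),),
        ... )
        >>> anchor_generator.num_anchors_per_location()
        [9]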
    """

    __annotations__ = {
        "cell_anchors": List[torch.Tensor],
    }

    def __init__(
        self,
        sizes=((128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),),
    ):
        super().__init__()

        if not isinstance(sizes[0], (list, tuple)):
            # TODO change this
            sizes = tuple((s,) for s in sizes)
        if not isinstance(aspect_ratios[0], (list, tuple)):
            aspect_ratios = (aspect_ratios,) * len(sizes)

        self.sizes = sizes
        self.aspect_ratios = aspect_ratios
        self.cell_anchors = [
            self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(sizes, aspect_ratios)
        ]

    # TODO: https://github.com/pytorch/pytorch/issues/26792
    # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
    # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
    # This method assumes aspect ratio = height / width for an anchor.
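    # For example, with scales=(128, 256, 512) and aspect_ratios=(0.5, 1.0, 2.0) it returns a [9, 4] tensor
    # of zero-centered (x1, y1, x2, y2) boxes, one row per (aspect_ratio, scale) pair.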
    def generate_anchors(
        self,
        scales: List[int],
        aspect_ratios: List[float],
        dtype: torch.dtype = torch.float32,
        device: torch.device = torch.device("cpu"),
    ) -> Tensor:
        scales = torch.as_tensor(scales, dtype=dtype, device=device)
        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
        h_ratios = torch.sqrt(aspect_ratios)
        w_ratios = 1 / h_ratios

        ws = (w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h_ratios[:, None] * scales[None, :]).view(-1)

        base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
        return base_anchors.round()

    def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
        self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]

    def num_anchors_per_location(self) -> List[int]:
        return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

    # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
    # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
    def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
        anchors = []
        cell_anchors = self.cell_anchors
        torch._assert(cell_anchors is not None, "cell_anchors should not be None")
        torch._assert(
            len(grid_sizes) == len(strides) == len(cell_anchors),
            "Anchors should be Tuple[Tuple[int]] because each feature "
            "map could potentially have different sizes and aspect ratios. "
            "There needs to be a match between the number of "
            "feature maps passed and the number of sizes / aspect ratios specified.",
        )

        for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
            grid_height, grid_width = size
            stride_height, stride_width = stride
            device = base_anchors.device

            # For every output anchor, compute [x_center, y_center, x_center, y_center]
            shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width
            shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            # For every (base anchor, output anchor) pair,
            # offset each zero-centered base anchor by the center of the output anchor.
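            # The per-level result has shape [grid_height * grid_width * num_base_anchors, 4].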
            anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))

        return anchors

    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
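        # The stride of each level is the integer ratio between the input image size and the feature map size.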
        strides = [
            [
                torch.empty((), dtype=torch.int64, device=device).fill_(image_size[0] // g[0]),
                torch.empty((), dtype=torch.int64, device=device).fill_(image_size[1] // g[1]),
            ]
            for g in grid_sizes
        ]
        self.set_cell_anchors(dtype, device)
        anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
        anchors: List[List[torch.Tensor]] = []
        for _ in range(len(image_list.image_sizes)):
            anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
            anchors.append(anchors_in_image)
        anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
        return anchors


class DefaultBoxGenerator(nn.Module):
    r"""
    This module generates the default boxes of SSD for a set of feature maps and image sizes.

    Args:
        aspect_ratios (List[List[int]]): A list with all the aspect ratios used in each feature map.
        min_ratio (float): The minimum scale :math:`\text{s}_{\text{min}}` of the default boxes used in the estimation
            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
        max_ratio (float): The maximum scale :math:`\text{s}_{\text{max}}` of the default boxes used in the estimation
            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
        scales (List[float], optional): The scales of the default boxes. If not provided it will be estimated using
            the ``min_ratio`` and ``max_ratio`` parameters.
        steps (List[int], optional): It's a hyper-parameter that affects the tiling of default boxes. If not provided
            it will be estimated from the data.
        clip (bool): Whether the standardized values of default boxes should be clipped between 0 and 1. The clipping
            is applied while the boxes are encoded in format ``(cx, cy, w, h)``.
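
    Example (a minimal usage sketch; the aspect ratios below are arbitrary and only
    illustrate a two-feature-map configuration):

        >>> dbox_generator = DefaultBoxGenerator(aspect_ratios=[[2], [2, 3]])
        >>> dbox_generator.num_anchors_per_location()
        [4, 6]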
    """

    def __init__(
        self,
        aspect_ratios: List[List[int]],
        min_ratio: float = 0.15,
        max_ratio: float = 0.9,
        scales: Optional[List[float]] = None,
        steps: Optional[List[int]] = None,
        clip: bool = True,
    ):
        super().__init__()
        if steps is not None and len(aspect_ratios) != len(steps):
            raise ValueError("aspect_ratios and steps should have the same length")
        self.aspect_ratios = aspect_ratios
        self.steps = steps
        self.clip = clip
        num_outputs = len(aspect_ratios)

        # Estimation of default boxes scales
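        # The scales are linearly interpolated between min_ratio and max_ratio, following the SSD paper;
        # the trailing 1.0 is only used to compute s'_k for the last feature map.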
        if scales is None:
            if num_outputs > 1:
                range_ratio = max_ratio - min_ratio
                self.scales = [min_ratio + range_ratio * k / (num_outputs - 1.0) for k in range(num_outputs)]
                self.scales.append(1.0)
            else:
                self.scales = [min_ratio, max_ratio]
        else:
            self.scales = scales

        self._wh_pairs = self._generate_wh_pairs(num_outputs)

    def _generate_wh_pairs(
        self, num_outputs: int, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu")
    ) -> List[Tensor]:
        _wh_pairs: List[Tensor] = []
        for k in range(num_outputs):
            # Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k
            s_k = self.scales[k]
            s_prime_k = math.sqrt(self.scales[k] * self.scales[k + 1])
            wh_pairs = [[s_k, s_k], [s_prime_k, s_prime_k]]

            # Adding 2 pairs for each aspect ratio of the feature map k
            for ar in self.aspect_ratios[k]:
                sq_ar = math.sqrt(ar)
                w = self.scales[k] * sq_ar
                h = self.scales[k] / sq_ar
                wh_pairs.extend([[w, h], [h, w]])

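            # Each element has shape [2 + 2 * len(self.aspect_ratios[k]), 2], with (w, h) relative to the image size.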
            _wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
        return _wh_pairs

    def num_anchors_per_location(self) -> List[int]:
        # Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feature map.
        return [2 + 2 * len(r) for r in self.aspect_ratios]

    # Default Boxes calculation based on page 6 of SSD paper
    def _grid_default_boxes(
        self, grid_sizes: List[List[int]], image_size: List[int], dtype: torch.dtype = torch.float32
    ) -> Tensor:
        default_boxes = []
        for k, f_k in enumerate(grid_sizes):
            # Now add the default boxes for each width-height pair
            if self.steps is not None:
                x_f_k = image_size[1] / self.steps[k]
                y_f_k = image_size[0] / self.steps[k]
            else:
                y_f_k, x_f_k = f_k

            shifts_x = ((torch.arange(0, f_k[1]) + 0.5) / x_f_k).to(dtype=dtype)
            shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype)
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)

            shifts = torch.stack((shift_x, shift_y) * len(self._wh_pairs[k]), dim=-1).reshape(-1, 2)
            # Clipping the default boxes while the boxes are encoded in format (cx, cy, w, h)
            _wh_pair = self._wh_pairs[k].clamp(min=0, max=1) if self.clip else self._wh_pairs[k]
            wh_pairs = _wh_pair.repeat((f_k[0] * f_k[1]), 1)

            default_box = torch.cat((shifts, wh_pairs), dim=1)

            default_boxes.append(default_box)

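        # The concatenated result has one (cx, cy, w, h) row per grid location and
        # width-height pair, summed over all feature maps.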
        return torch.cat(default_boxes, dim=0)

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"aspect_ratios={self.aspect_ratios}"
            f", clip={self.clip}"
            f", scales={self.scales}"
            f", steps={self.steps}"
            ")"
        )
        return s

    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        default_boxes = self._grid_default_boxes(grid_sizes, image_size, dtype=dtype)
        default_boxes = default_boxes.to(device)

        dboxes = []
        x_y_size = torch.tensor([image_size[1], image_size[0]], device=default_boxes.device)
        for _ in image_list.image_sizes:
            dboxes_in_image = default_boxes
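            # Convert the normalized (cx, cy, w, h) boxes to (x1, y1, x2, y2) in input-image pixel coordinates.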
            dboxes_in_image = torch.cat(
                [
                    (dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:]) * x_y_size,
                    (dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]) * x_y_size,
                ],
                -1,
            )
            dboxes.append(dboxes_in_image)
        return dboxes


def grid_offsets(grid_size: Tensor) -> Tensor:
    """Given a grid size, returns a tensor containing offsets to the grid cells.

    Args:
        grid_size: The width and height of the grid in a tensor.

    Returns:
        A ``[height, width, 2]`` tensor containing the grid cell ``(x, y)`` offsets.
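
    Example (an illustrative sketch; the grid size below is arbitrary):

        >>> grid_offsets(torch.tensor([3, 2])).shape
        torch.Size([2, 3, 2])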
    """
    x_range = torch.arange(grid_size[0].item(), device=grid_size.device)
    y_range = torch.arange(grid_size[1].item(), device=grid_size.device)
    grid_y, grid_x = torch.meshgrid([y_range, x_range], indexing="ij")
    return torch.stack((grid_x, grid_y), -1)


def grid_centers(grid_size: Tensor) -> Tensor:
    """Given a grid size, returns a tensor containing coordinates to the centers of the grid cells.

    Returns:
        A ``[height, width, 2]`` tensor containing coordinates to the centers of the grid cells.
    """
    return grid_offsets(grid_size) + 0.5


@torch.jit.script
def global_xy(xy: Tensor, image_size: Tensor) -> Tensor:
    """Adds offsets to the predicted box center coordinates to obtain global coordinates in the image.

    The predicted coordinates are interpreted as coordinates inside a grid cell whose width and height is 1. Adding
    the cell offset, dividing by the grid size, and multiplying by the image size, we get global coordinates in the
    image scale.

    The function needs the ``@torch.jit.script`` decorator in order for ONNX generation to work. The tracing-based
    generator will lose track of e.g. ``xy.shape[1]`` and treat it as a Python variable and not a tensor. This will
    cause the dimension to be treated as a constant in the model, which prevents dynamic input sizes.

    Args:
        xy: The predicted center coordinates before scaling. Values from zero to one in a tensor sized
            ``[batch_size, height, width, boxes_per_cell, 2]``.
        image_size: Width and height in a vector that will be used to scale the coordinates.

    Returns:
        Global coordinates scaled to the size of the network input image, in a tensor with the same shape as the input
        tensor.
    """
    height = xy.shape[1]
    width = xy.shape[2]
    grid_size = torch.tensor([width, height], device=xy.device)
    # Scripting requires explicit conversion to a floating point type.
    offset = grid_offsets(grid_size).to(xy.dtype).unsqueeze(2)  # [height, width, 1, 2]
    scale = torch.true_divide(image_size, grid_size)
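    # Broadcasting: xy is [batch_size, height, width, boxes_per_cell, 2], offset is [height, width, 1, 2],
    # and scale is a 2-vector, so the result keeps the shape of xy.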
    return (xy + offset) * scale