Closed
Description
A question about the SSD data augmentation in https://github.com/pytorch/vision/blob/main/references/detection/transforms.py:
class RandomZoomOut(nn.Module):
    """Randomly paste the image onto a larger, filled canvas ("zoom out"), as used for SSD training.

    With probability ``p`` the image is placed at a random offset on a canvas whose sides are
    ``r`` times the original sides, with ``r`` drawn uniformly from ``side_range``. The area
    outside the original image is filled with ``fill``. If a target is given, its
    ``target["boxes"]`` are shifted by the paste offset.

    Args:
        fill: Per-channel fill value for the new canvas area. Defaults to ``[0.0, 0.0, 0.0]``.
        side_range: ``(min, max)`` range of the canvas-to-image side ratio. ``min`` must be
            at least 1.0 and must not exceed ``max``.
        p: Probability of applying the transform.

    Raises:
        ValueError: If ``side_range`` is invalid.
    """

    def __init__(
        self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
    ):
        super().__init__()
        if fill is None:
            fill = [0.0, 0.0, 0.0]
        self.fill = fill
        self.side_range = side_range
        if side_range[0] < 1.0 or side_range[0] > side_range[1]:
            raise ValueError(f"Invalid canvas side range provided {side_range}.")
        self.p = p

    @torch.jit.unused
    def _get_fill_value(self, is_pil):
        # type: (bool) -> int
        # We fake the type to make it work on JIT.
        # PIL images get a per-channel tuple of ints. Tensors get the scalar 0; their real
        # per-channel fill is written explicitly in ``forward`` after padding.
        return tuple(int(x) for x in self.fill) if is_pil else 0

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """Apply the zoom-out with probability ``self.p``.

        Args:
            image: Image tensor with 2 or 3 dimensions (a 2-D tensor gets a leading channel
                dimension added), or a PIL image (detected via ``F._is_pil_image``).
            target: Optional target dict. When given, ``target["boxes"]`` is shifted **in place**.
                Columns ``0::2`` are offset by ``left`` and ``1::2`` by ``top``, presumably
                ``(x1, y1, x2, y2)`` boxes — confirm against callers.

        Returns:
            The (possibly) zoomed-out image and the target.

        Raises:
            ValueError: If a tensor image is not 2- or 3-dimensional.
        """
        if isinstance(image, torch.Tensor):
            if image.ndimension() not in {2, 3}:
                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
            elif image.ndimension() == 2:
                image = image.unsqueeze(0)

        # Apply the transform with probability p, i.e. skip it when the draw is >= p.
        # (A `< self.p` check here would apply it with probability 1 - p instead.)
        if torch.rand(1) >= self.p:
            return image, target

        orig_w, orig_h = F.get_image_size(image)

        # Canvas side ratio r ~ U(side_range[0], side_range[1]); r >= 1 so the canvas is never smaller.
        r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
        canvas_width = int(orig_w * r)
        canvas_height = int(orig_h * r)

        # Random placement of the original image inside the canvas.
        r = torch.rand(2)
        left = int((canvas_width - orig_w) * r[0])
        top = int((canvas_height - orig_h) * r[1])
        right = canvas_width - (left + orig_w)
        bottom = canvas_height - (top + orig_h)

        if torch.jit.is_scripting():
            fill = 0
        else:
            fill = self._get_fill_value(F._is_pil_image(image))
        image = F.pad(image, [left, top, right, bottom], fill=fill)

        # Not redundant: for tensors the pad above used the scalar 0, because F.pad only
        # accepts a scalar fill for tensors (and scripting forces 0). The per-channel
        # ``self.fill`` values are therefore written into the four border strips here.
        # NOTE(review): for 1-channel tensors (e.g. 2-D input) a 3-value fill such as the
        # default will not broadcast into these slices — confirm intended behaviour.
        if isinstance(image, torch.Tensor):
            v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1)
            image[..., :top, :] = v
            image[..., :, :left] = v
            image[..., (top + orig_h) :, :] = v
            image[..., :, (left + orig_w) :] = v

        if target is not None:
            target["boxes"][:, 0::2] += left
            target["boxes"][:, 1::2] += top

        return image, target
Since `F.pad` has already padded the image, why is there an additional fill operation for `torch.Tensor` inputs?
cc @datumbox
Metadata
Metadata
Assignees
Labels
No labels