26 changes: 9 additions & 17 deletions test/prototype_transforms_kernel_infos.py
@@ -459,9 +459,7 @@ def transform(bbox):
],
dtype=bbox.dtype,
)
return F.convert_format_bounding_box(
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
)
return F.convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format)

if bounding_box.ndim < 2:
bounding_box = [bounding_box]
@@ -554,26 +552,20 @@ def sample_inputs_affine_video():


def sample_inputs_convert_format_bounding_box():
formats = set(features.BoundingBoxFormat)
for bounding_box_loader in make_bounding_box_loaders(formats=formats):
old_format = bounding_box_loader.format
for params in combinations_grid(new_format=formats - {old_format}, copy=(True, False)):
yield ArgsKwargs(bounding_box_loader, old_format=old_format, **params)

formats = list(features.BoundingBoxFormat)
for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format)

def reference_convert_format_bounding_box(bounding_box, old_format, new_format, copy):
if not copy:
raise pytest.UsageError("Reference for `convert_format_bounding_box` only supports `copy=True`")

def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
return torchvision.ops.box_convert(
bounding_box, in_fmt=old_format.kernel_name.lower(), out_fmt=new_format.kernel_name.lower()
)


def reference_inputs_convert_format_bounding_box():
for args_kwargs in sample_inputs_convert_color_space_image_tensor():
(image_loader, *other_args), kwargs = args_kwargs
if len(image_loader.shape) == 2 and kwargs.setdefault("copy", True):
if len(args_kwargs.args[0].shape) == 2:
yield args_kwargs


@@ -598,19 +590,19 @@ def sample_inputs_convert_color_space_image_tensor():
for image_loader in make_image_loaders(
sizes=["random"], color_spaces=[color_space], dtypes=[torch.float32], constant_alpha=True
):
yield ArgsKwargs(image_loader, old_color_space=color_space, new_color_space=color_space, copy=False)
yield ArgsKwargs(image_loader, old_color_space=color_space, new_color_space=color_space)


@pil_reference_wrapper
def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space, copy=True):
def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space):
color_space_pil = features.ColorSpace.from_pil_mode(image_pil.mode)
if color_space_pil != old_color_space:
raise pytest.UsageError(
f"Converting the tensor image into an PIL image changed the colorspace "
f"from {old_color_space} to {color_space_pil}"
)

return F.convert_color_space_image_pil(image_pil, color_space=new_color_space, copy=copy)
return F.convert_color_space_image_pil(image_pil, color_space=new_color_space)


def reference_inputs_convert_color_space_image_tensor():
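Note on the hunks above: the new `sample_inputs_convert_format_bounding_box` enumerates every `(old_format, new_format)` pair with `itertools.product`, including the identity conversions that the old `formats - {old_format}` grid skipped, and the `copy` axis drops out of the grid entirely. A minimal sketch of the enumeration, with the loader machinery stubbed out (the `BoundingBoxFormat` stub mirrors the torchvision enum; the rest is hypothetical scaffolding):

```python
import itertools
from enum import Enum


class BoundingBoxFormat(Enum):
    # Stub mirroring torchvision.prototype.features.BoundingBoxFormat
    XYXY = "XYXY"
    XYWH = "XYWH"
    CXCYWH = "CXCYWH"


formats = list(BoundingBoxFormat)

# Every ordered pair is yielded, so identity conversions
# (old_format == new_format) are now exercised as well.
for old_format, new_format in itertools.product(formats, formats):
    print(f"{old_format.value} -> {new_format.value}")
```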
20 changes: 9 additions & 11 deletions test/test_prototype_transforms_functional.py
@@ -478,9 +478,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
device=bbox.device,
)
return (
convert_format_bounding_box(
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
),
convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format),
(height, width),
)

@@ -733,14 +731,16 @@ def _compute_expected_bbox(bbox, padding_):

bbox_format = bbox.format
bbox_dtype = bbox.dtype
bbox = convert_format_bounding_box(bbox, old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY)
bbox = (
bbox.clone()
if bbox_format == features.BoundingBoxFormat.XYXY
else convert_format_bounding_box(bbox, bbox_format, features.BoundingBoxFormat.XYXY)
)

bbox[0::2] += pad_left
bbox[1::2] += pad_up

bbox = convert_format_bounding_box(
bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
)
bbox = convert_format_bounding_box(bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format)
if bbox.dtype != bbox_dtype:
# Temporary cast to original dtype
# e.g. float32 -> int
@@ -840,9 +840,7 @@ def _compute_expected_bbox(bbox, pcoeffs_):
dtype=bbox.dtype,
device=bbox.device,
)
return convert_format_bounding_box(
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
)
return convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format)

spatial_size = (32, 38)

@@ -903,7 +901,7 @@ def _compute_expected_bbox(bbox, output_size_):
dtype=bbox.dtype,
device=bbox.device,
)
return convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_, copy=False)
return convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_)

for bboxes in make_bounding_boxes(extra_dims=((4,),)):
bboxes = bboxes.to(device)
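Why the `bbox.clone()` in the pad helper above: the helper shifts coordinates in place right after the conversion, and when the box is already XYXY there is no conversion to force an allocation. Assuming `convert_format_bounding_box` hands back its input unchanged in the identity case (the behavior the removed `copy=False` used to make explicit), skipping the clone would mutate the caller's tensor. A small sketch of the hazard:

```python
import torch


def to_xyxy(bbox: torch.Tensor) -> torch.Tensor:
    # Stand-in for convert_format_bounding_box in the identity case:
    # assumed to return the input tensor itself, not a copy.
    return bbox


bbox = torch.tensor([1.0, 2.0, 3.0, 4.0])

aliased = to_xyxy(bbox)
aliased[0::2] += 5  # in-place shift moves the caller's tensor too
assert bbox[0].item() == 6.0

safe = to_xyxy(bbox).clone()
safe[0::2] += 5  # the clone isolates the mutation
assert bbox[0].item() == 6.0  # caller's tensor untouched this time
```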
12 changes: 0 additions & 12 deletions torchvision/prototype/features/_image.py
@@ -110,18 +110,6 @@ def spatial_size(self) -> Tuple[int, int]:
def num_channels(self) -> int:
return self.shape[-3]

def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Image:
if isinstance(color_space, str):
color_space = ColorSpace.from_str(color_space.upper())

return Image.wrap_like(
self,
self._F.convert_color_space_image_tensor(
self, old_color_space=self.color_space, new_color_space=color_space, copy=copy
),
color_space=color_space,
)

def horizontal_flip(self) -> Image:
output = self._F.horizontal_flip_image_tensor(self)
return Image.wrap_like(self, output)
12 changes: 0 additions & 12 deletions torchvision/prototype/features/_video.py
@@ -66,18 +66,6 @@ def num_channels(self) -> int:
def num_frames(self) -> int:
return self.shape[-4]

def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Video:
if isinstance(color_space, str):
color_space = ColorSpace.from_str(color_space.upper())

return Video.wrap_like(
self,
self._F.convert_color_space_video(
self, old_color_space=self.color_space, new_color_space=color_space, copy=copy
),
color_space=color_space,
)

def horizontal_flip(self) -> Video:
output = self._F.horizontal_flip_video(self)
return Video.wrap_like(self, output)
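With `Image.to_color_space` and `Video.to_color_space` removed (the two hunks above), the functional dispatcher becomes the way to convert, which is what `ConvertColorSpace` in `_meta.py` below now calls. A hedged migration sketch, using prototype-namespace names as of this PR; the exact `features.Image` constructor signature is assumed:

```python
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

image = features.Image(torch.rand(3, 32, 32), color_space=features.ColorSpace.RGB)

# Before: image.to_color_space("GRAY", copy=True)
# After: go through the dispatcher, which accepts Image/Video/plain tensors.
gray = F.convert_color_space(
    image,
    color_space=features.ColorSpace.GRAY,
    old_color_space=features.ColorSpace.RGB,
)
```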
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_augment.py
@@ -265,7 +265,7 @@ def _copy_paste(
# https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
xyxy_boxes[:, 2:] += 1
boxes = F.convert_format_bounding_box(
xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format
)
out_target["boxes"] = torch.cat([boxes, paste_boxes])

23 changes: 12 additions & 11 deletions torchvision/prototype/transforms/_geometry.py
@@ -655,8 +655,10 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
continue

# check for any valid boxes with centers within the crop area
xyxy_bboxes = F.convert_format_bounding_box(
bboxes, old_format=bboxes.format, new_format=features.BoundingBoxFormat.XYXY, copy=True
xyxy_bboxes = (
bboxes.clone()
if bboxes.format == features.BoundingBoxFormat.XYXY
else F.convert_format_bounding_box(bboxes, bboxes.format, features.BoundingBoxFormat.XYXY)
)
cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
@@ -801,22 +803,21 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
top = int(offset_height * r)
left = int(offset_width * r)

bounding_boxes: Optional[torch.Tensor]
try:
bounding_boxes = query_bounding_box(flat_inputs)
except ValueError:
bounding_boxes = None

if needs_crop and bounding_boxes is not None:
bounding_boxes = cast(
features.BoundingBox, F.crop(bounding_boxes, top=top, left=left, height=new_height, width=new_width)
format = bounding_boxes.format
bounding_boxes, spatial_size = F.crop_bounding_box(
bounding_boxes, format=format, top=top, left=left, height=new_height, width=new_width
)
bounding_boxes = features.BoundingBox.wrap_like(
bounding_boxes,
F.clamp_bounding_box(
bounding_boxes, format=bounding_boxes.format, spatial_size=bounding_boxes.spatial_size
),
)
height_and_width = bounding_boxes.to_format(features.BoundingBoxFormat.XYWH)[..., 2:]
bounding_boxes = F.clamp_bounding_box(bounding_boxes, format=format, spatial_size=spatial_size)
height_and_width = F.convert_format_bounding_box(
bounding_boxes, old_format=format, new_format=features.BoundingBoxFormat.XYWH
)[..., 2:]
is_valid = torch.all(height_and_width > 0, dim=-1)
else:
is_valid = None
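Distilled from the `FixedSizeCrop._get_params` hunk above: the box validity test now runs on plain tensors via the functional kernels (crop, clamp, format conversion to XYWH) instead of round-tripping through `features.BoundingBox` wrappers and `BoundingBox.to_format`. The check itself reduces to the following, shown with made-up data:

```python
import torch


def valid_after_crop(xywh_boxes: torch.Tensor) -> torch.Tensor:
    # A cropped-and-clamped box survives only if both its width and
    # height remain strictly positive.
    height_and_width = xywh_boxes[..., 2:]
    return torch.all(height_and_width > 0, dim=-1)


boxes = torch.tensor([[0.0, 0.0, 10.0, 5.0], [3.0, 3.0, 0.0, 7.0]])
print(valid_after_crop(boxes))  # tensor([ True, False])
```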
7 changes: 1 addition & 6 deletions torchvision/prototype/transforms/_meta.py
@@ -44,7 +44,6 @@ def __init__(
self,
color_space: Union[str, features.ColorSpace],
old_color_space: Optional[Union[str, features.ColorSpace]] = None,
copy: bool = True,
) -> None:
super().__init__()

@@ -56,14 +55,10 @@
old_color_space = features.ColorSpace.from_str(old_color_space)
self.old_color_space = old_color_space

self.copy = copy

def _transform(
self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
) -> Union[features.ImageType, features.VideoType]:
return F.convert_color_space(
inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy
)
return F.convert_color_space(inpt, color_space=self.color_space, old_color_space=self.old_color_space)


class ClampBoundingBoxes(Transform):
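Usage of the slimmed-down transform, assuming it is exported as `transforms.ConvertColorSpace` (the class name is not visible in this hunk) and that `features.Image` accepts a `color_space` kwarg; callers that passed `copy=False` for an aliasing guarantee should no longer assume one either way:

```python
import torch
from torchvision.prototype import features, transforms

# The copy kwarg is gone; only the color spaces remain.
transform = transforms.ConvertColorSpace("RGB", old_color_space="GRAY")
image = features.Image(torch.rand(1, 16, 16), color_space=features.ColorSpace.GRAY)
rgb = transform(image)
```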
65 changes: 40 additions & 25 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -36,14 +36,18 @@ def horizontal_flip_bounding_box(
) -> torch.Tensor:
shape = bounding_box.shape

bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
# BoundingBoxFormat instead of converting back and forth
bounding_box = (
bounding_box.clone()
if format == features.BoundingBoxFormat.XYXY
else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
).reshape(-1, 4)

bounding_box[:, [0, 2]] = spatial_size[1] - bounding_box[:, [2, 0]]

return convert_format_bounding_box(
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
).reshape(shape)


@@ -73,14 +77,18 @@ def vertical_flip_bounding_box(
) -> torch.Tensor:
shape = bounding_box.shape

bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
# BoundingBoxFormat instead of converting back and forth
bounding_box = (
bounding_box.clone()
if format == features.BoundingBoxFormat.XYXY
else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
).reshape(-1, 4)

bounding_box[:, [1, 3]] = spatial_size[0] - bounding_box[:, [3, 1]]

return convert_format_bounding_box(
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
).reshape(shape)


@@ -394,16 +402,21 @@ def affine_bounding_box(
center: Optional[List[float]] = None,
) -> torch.Tensor:
original_shape = bounding_box.shape
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY

# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
# BoundingBoxFormat instead of converting back and forth
bounding_box = (
bounding_box.clone()
if format == features.BoundingBoxFormat.XYXY
else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
).reshape(-1, 4)

out_bboxes, _ = _affine_bounding_box_xyxy(bounding_box, spatial_size, angle, translate, scale, shear, center)

# out_bboxes should be of shape [N boxes, 4]

return convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
).reshape(original_shape)


@@ -583,8 +596,8 @@ def rotate_bounding_box(
center = None

original_shape = bounding_box.shape
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
bounding_box = (
convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
).reshape(-1, 4)

out_bboxes, spatial_size = _affine_bounding_box_xyxy(
Expand All @@ -599,9 +612,9 @@ def rotate_bounding_box(
)

return (
convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).reshape(original_shape),
convert_format_bounding_box(out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format).reshape(
original_shape
),
spatial_size,
)

@@ -818,18 +831,20 @@ def crop_bounding_box(
height: int,
width: int,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
# BoundingBoxFormat instead of converting back and forth
bounding_box = (
bounding_box.clone()
if format == features.BoundingBoxFormat.XYXY
else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
)

# Crop or implicit pad if left and/or top have negative values:
bounding_box[..., 0::2] -= left
bounding_box[..., 1::2] -= top

return (
convert_format_bounding_box(
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
),
convert_format_bounding_box(bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format),
(height, width),
)

@@ -896,8 +911,8 @@ def perspective_bounding_box(
raise ValueError("Argument perspective_coeffs should have 8 float values")

original_shape = bounding_box.shape
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
bounding_box = (
convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
).reshape(-1, 4)

dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
Expand Down Expand Up @@ -967,7 +982,7 @@ def perspective_bounding_box(
# out_bboxes should be of shape [N boxes, 4]

return convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
).reshape(original_shape)


@@ -1061,8 +1076,8 @@ def elastic_bounding_box(
displacement = displacement.to(bounding_box.device)

original_shape = bounding_box.shape
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
bounding_box = (
convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
).reshape(-1, 4)

# Question (vfdev-5): should we rely on good displacement shape and fetch image size from it
Expand All @@ -1088,7 +1103,7 @@ def elastic_bounding_box(
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)

return convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
).reshape(original_shape)


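Review note on the recurring pattern in this file: several of the bounding-box kernels (horizontal/vertical flip, affine, crop) now open with the same clone-or-convert guard, each carrying the same TODO. A hypothetical consolidation (`_to_xyxy` is an invented name, not part of the PR), assuming `convert_format_bounding_box` allocates a fresh tensor whenever the formats actually differ:

```python
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms.functional import convert_format_bounding_box


def _to_xyxy(bounding_box: torch.Tensor, format: features.BoundingBoxFormat) -> torch.Tensor:
    # Identity case: conversion would hand the input back, so clone
    # explicitly before the kernels mutate coordinates in place.
    if format == features.BoundingBoxFormat.XYXY:
        return bounding_box.clone()
    # Differing formats: the conversion is assumed to allocate.
    return convert_format_bounding_box(
        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
    )
```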