diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index a9917a80e7a..47860451774 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -230,8 +230,10 @@ def convert_format_bounding_box(
     elif isinstance(inpt, datapoints.BoundingBox):
         if old_format is not None:
             raise ValueError("For bounding box datapoint inputs, `old_format` must not be passed.")
-        output = _convert_format_bounding_box(inpt, old_format=inpt.format, new_format=new_format, inplace=inplace)
-        return datapoints.BoundingBox.wrap_like(inpt, output)
+        output = _convert_format_bounding_box(
+            inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace
+        )
+        return datapoints.BoundingBox.wrap_like(inpt, output, format=new_format)
     else:
         raise TypeError(
             f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
@@ -266,7 +268,7 @@ def clamp_bounding_box(
     elif isinstance(inpt, datapoints.BoundingBox):
         if format is not None or spatial_size is not None:
             raise ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.")
-        output = _clamp_bounding_box(inpt, format=inpt.format, spatial_size=inpt.spatial_size)
+        output = _clamp_bounding_box(inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size)
         return datapoints.BoundingBox.wrap_like(inpt, output)
     else:
         raise TypeError(
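
Both hunks apply the same pattern: unwrap the `BoundingBox` datapoint to a plain `torch.Tensor` via `as_subclass` before calling the low-level kernel, so the tensor ops inside the kernel skip `__torch_function__` dispatch, then re-wrap the result with `wrap_like`. The first hunk additionally passes `format=new_format` when re-wrapping, so the output datapoint's metadata matches its converted coordinates instead of retaining the input's old format. Below is a minimal sketch of the pattern, not part of the patch itself; it assumes the prototype `datapoints` API of this revision (in particular the `BoundingBox` constructor and `BoundingBoxFormat` names).

import torch
from torchvision.prototype import datapoints

box = datapoints.BoundingBox(
    [[10.0, 20.0, 30.0, 40.0]],
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(100, 100),
)

# as_subclass reinterprets the datapoint as a plain tensor sharing the same
# storage, so subsequent ops bypass __torch_function__ dispatch.
plain = box.as_subclass(torch.Tensor)
assert type(plain) is torch.Tensor
assert plain.data_ptr() == box.data_ptr()

# wrap_like re-attaches the datapoint metadata afterwards; passing format=
# keeps .format consistent with the coordinates (here left unchanged), which
# is what the first hunk's `format=new_format` fixes.
rewrapped = datapoints.BoundingBox.wrap_like(
    box, plain, format=datapoints.BoundingBoxFormat.XYXY
)
assert isinstance(rewrapped, datapoints.BoundingBox)
assert rewrapped.format is datapoints.BoundingBoxFormat.XYXY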