Description
🚀 The feature
The mid-level kernels of the new Transforms API (F) are currently not JIT-scriptable because JIT doesn't support Tensor subclassing. This is a major BC-breaking change that could potentially be avoided by refactoring our mid-level kernels.
Here is one way this could be achieved:
def kernel(x: torch.Tensor) -> torch.Tensor:
    if isinstance(x, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(x, features._Feature)):
        return _FT.kernel(x)  # image tensor dispatch
    elif isinstance(x, features._Feature):
        return x.kernel()  # _Feature dispatch
    else:
        return _FP.kernel(x)  # PIL dispatch
Effectively, the above makes kernel() operate exactly like before on the stable F: if a plain torch.Tensor is provided, it is dispatched directly to the _FT kernel. The assumption is that while JIT-scripting, only image tensors are allowed. As long as the features.Image and the torch.Tensor kernels are identical (which is true in our implementation), we won't have any discrepancies between the two types. The extended behaviour for the new tensor subclasses (Image, BBox, Mask, Label, etc.) is only available in eager Python mode.
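To see why the extra torch.jit.is_scripting() guard is needed, note that features.Image is itself a torch.Tensor subclass, so a plain isinstance(x, torch.Tensor) check cannot tell the two apart in eager mode. A minimal eager-mode sketch, assuming the prototype features API used in the PoC below:
import torch
from torchvision.prototype import features

img = features.Image(torch.randint(0, 256, (3, 30, 20), dtype=torch.uint8))

print(isinstance(img, torch.Tensor))       # True: Image subclasses torch.Tensor
print(isinstance(img, features._Feature))  # True: so in eager mode the _Feature branch is taken
print(torch.jit.is_scripting())            # False in eager mode; True only inside scripted code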
Unfortunately it's not possible to check the exact type of a Tensor at run-time and throw an appropriate error if a user accidentally passes a BBox or Label while JIT-scripting. For these types, however, other checks (such as the dimension checks) will raise errors.
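As an illustration, here is a hedged sketch of how such a dimension check could reject a bounding-box tensor even when scripted; check_image_tensor is a hypothetical helper that only mimics the kind of shape validation the real _FT kernels perform:
import torch

def check_image_tensor(x: torch.Tensor) -> None:
    # Hypothetical guard (not torchvision code): image tensors are expected
    # to have shape (..., C, H, W) with 1 or 3 channels.
    if x.dim() < 3:
        raise ValueError("Expected an image tensor of shape (..., C, H, W)")
    c = x.shape[-3]
    if c != 1 and c != 3:
        raise ValueError("Expected 1 or 3 channels")

def kernel_image_tensor(x: torch.Tensor) -> torch.Tensor:
    check_image_tensor(x)  # an (N, 4) box tensor fails here, even when scripted
    return x  # a real kernel would process the image here

scripted = torch.jit.script(kernel_image_tensor)
try:
    scripted(torch.tensor([[12, 13, 19, 18]]))
except Exception as e:  # surfaces as a torch.jit.Error / RuntimeError when scripted
    print("rejected non-image input:", type(e).__name__)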
Proof of Concept
The implementation below is a proof of concept. We intentionally make the kernels simulate different functionality in order to verify that the dispatches work as expected.
import numpy as np
import torch
from PIL import Image
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from torchvision.transforms.functional import to_pil_image
# Dummy low-level kernels; we only provide them for image_* but we could have added them for all types
def invertedmeanpix_image_tensor(x: torch.Tensor) -> float:
    return -_FT.invert(x).float().mean().item()  # return negative mean to simulate different functionality

@torch.jit.unused
def invertedmeanpix_image_pil(x: Image.Image) -> float:
    return np.mean(_FP.invert(x)) / 255.0  # return 0-1 scaled mean to simulate different functionality
# Dummy mid-level kernel which fakes the input type to stay JIT-scriptable
def invertedmeanpix(x: torch.Tensor) -> float:
    if isinstance(x, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(x, features._Feature)):
        return invertedmeanpix_image_tensor(x)
    elif isinstance(x, features._Feature):
        return x.invert().float().mean().item()  # return the plain mean to differentiate from the other 2 cases
    else:
        return invertedmeanpix_image_pil(x)
# JIT scripted kernel
invertedmeanpix_scripted = torch.jit.script(invertedmeanpix)
print(invertedmeanpix_scripted.code)
# Create dummy data
image_tensor = torch.randint(0, 256, (3, 30, 20), dtype=torch.uint8)
image_feature = features.Image(image_tensor)
image_pil = to_pil_image(image_tensor)
bbox_feature = features.BoundingBox(
    torch.tensor([[12, 13, 19, 18], [1, 15, 8, 19]]),
    format=features.BoundingBoxFormat.XYXY,
    image_size=image_tensor.shape[-2:],  # spatial size (H, W)
)
label_data = torch.tensor([1, 2])
label_feature = features.Label(label_data)
ohelabel_data = torch.nn.functional.one_hot(label_data, 3)
ohelabel_feature = features.OneHotLabel(ohelabel_data)
segmask_data = torch.zeros(1, 30, 20)
segmask_data[0, 13:19, 12:18] = 1
segmask_data[0, 15:19, 1:8] = 2
segmask_feature = features.SegmentationMask(segmask_data)
detmask_data = torch.zeros(2, 30, 20)
detmask_data[0, 13:19, 12:18] = 1
detmask_data[1, 15:19, 1:8] = 1
detmask_feature = features.SegmentationMask(detmask_data)
# Assertions
value = invertedmeanpix(image_tensor)
assert value < 0
torch.testing.assert_close(invertedmeanpix(image_feature), -value)
torch.testing.assert_close(invertedmeanpix(image_pil), -value / 255.0)
torch.testing.assert_close(invertedmeanpix_scripted(image_tensor), value)
torch.testing.assert_close(invertedmeanpix(label_feature), label_data.float().mean().item())
torch.testing.assert_close(invertedmeanpix(ohelabel_feature), ohelabel_data.float().mean().item())
torch.testing.assert_close(invertedmeanpix(segmask_feature), segmask_data.mean().item())
torch.testing.assert_close(invertedmeanpix(detmask_feature), detmask_data.mean().item())
print("OK")
Output:
def invertedmeanpix(x: Tensor) -> float:
  _0 = __torch__.invertedmeanpix_image_tensor
  x0 = unchecked_cast(Tensor, x)
  x1 = unchecked_cast(Tensor, x0)
  return _0(x1, )
OK