
RFC: Make prototype F jit-scriptable again? #6553

@datumbox

🚀 The feature

The mid-level kernels F of the new Transforms API are currently not JIT-scriptable because JIT doesn't support Tensor subclassing. This is a major BC-breaking change that could potentially be avoided by refactoring our mid-level kernels.
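For context, here is a minimal sketch of the kind of dispatch that cannot be scripted (the subclass_kernel name is illustrative, not actual prototype code): a mid-level kernel that checks isinstance against one of the new Tensor subclasses is expected to fail compilation, since TorchScript has no notion of those subclass types.

import torch
from torchvision.prototype import features


# Illustrative only: a mid-level kernel that dispatches on the Tensor subclass directly.
def subclass_kernel(x: torch.Tensor) -> torch.Tensor:
    if isinstance(x, features.Image):  # TorchScript cannot resolve this Tensor subclass
        return x.invert()
    return x.neg()


try:
    torch.jit.script(subclass_kernel)
except Exception as e:
    print(f"scripting failed as expected: {type(e).__name__}")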

Here is one way the refactoring could be achieved:

def kernel(x: torch.Tensor) -> torch.Tensor:
    if isinstance(x, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(x, features._Feature)):
        return _FT.kernel(x)  # image tensor dispatch
    elif isinstance(x, features._Feature):
        return x.kernel()  # _Feature dispatch
    else:
        return _FP.kernel(x)  # PIL dispatch

Effectively, the above makes kernel() operate exactly like before on the stable F: if a plain torch.Tensor is provided, it is dispatched directly to the _FT kernel. The assumption is that while JIT-scripting, only image tensors are allowed. As long as the features.Image and the torch.Tensor kernels are identical (which is true in our implementation), there won't be any discrepancies between the two types. The extended behaviour for the new tensor subclasses (Image, BoundingBox, Mask, Label, etc.) is only available in eager Python mode.
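To illustrate the "identical kernels" point, here is a hedged sketch (FakeImage and its kernel() method are hypothetical names, not the prototype classes): if the subclass method simply defers to the same low-level tensor kernel, the eager subclass path and the plain-tensor path cannot diverge.

import torch
from torchvision.transforms import functional_tensor as _FT


class FakeImage(torch.Tensor):
    # Hypothetical stand-in for features.Image: unwrap to a plain tensor,
    # reuse the exact same low-level kernel, then re-wrap the result.
    def kernel(self) -> "FakeImage":
        out = _FT.invert(self.as_subclass(torch.Tensor))
        return out.as_subclass(FakeImage)


img = torch.randint(0, 256, (3, 30, 20), dtype=torch.uint8).as_subclass(FakeImage)
assert torch.equal(img.kernel(), _FT.invert(img.as_subclass(torch.Tensor)))  # same result on both paths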

Unfortunately, it's not possible to check the exact type of the Tensor at runtime and throw the appropriate error if a user accidentally passes a BoundingBox or Label while JIT-scripting. But for these types, other checks (such as the dimension checks) will raise errors.
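As a concrete illustration of that safety net (the names below are hypothetical, not the actual torchvision assertions, though the low-level tensor kernels perform shape checks in the same spirit): a simple dimensionality guard at the top of an image kernel rejects, for example, an (N, 4) bounding-box tensor that was passed by accident.

import torch


def _check_image_dims(x: torch.Tensor) -> None:
    # Image tensors are expected to have shape (..., C, H, W), i.e. at least 3 dims.
    if x.ndim < 3:
        raise TypeError(f"Expected an image tensor with at least 3 dims, got {x.ndim}")


boxes = torch.tensor([[12, 13, 19, 18], [1, 15, 8, 19]])  # shape (2, 4), not an image
try:
    _check_image_dims(boxes)
except TypeError as e:
    print(e)  # Expected an image tensor with at least 3 dims, got 2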

Proof of Concept

The implementation below is a proof of concept. We intentionally make the kernels simulate different functionality in order to test that the dispatches work as expected.

import numpy as np
import torch
from PIL import Image
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from torchvision.transforms.functional import to_pil_image


# Dummy low-level kernels; we only provide them for image_* here, but we could have added them for all types
def invertedmeanpix_image_tensor(x: torch.Tensor) -> float:
    return -_FT.invert(x).float().mean().item()  # return negative mean to simulate different functionality


@torch.jit.unused
def invertedmeanpix_image_pil(x: Image.Image) -> float:
    return np.mean(_FP.invert(x)) / 255.0  # return 0-1 scaled mean to simulate different functionality


# Dummy mid-level kernel; the torch.Tensor annotation fakes the input type to keep the kernel JIT-scriptable
def invertedmeanpix(x: torch.Tensor) -> float:
    if isinstance(x, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(x, features._Feature)):
        return invertedmeanpix_image_tensor(x)
    elif isinstance(x, features._Feature):
        return x.invert().float().mean().item()  # compute the standard mean to differentiate from the other two cases
    else:
        return invertedmeanpix_image_pil(x)


# JIT scripted kernel
invertedmeanpix_scripted = torch.jit.script(invertedmeanpix)
print(invertedmeanpix_scripted.code)

# Create dummy data
image_tensor = torch.randint(0, 256, (3, 30, 20), dtype=torch.uint8)
image_feature = features.Image(image_tensor)
image_pil = to_pil_image(image_tensor)

bbox_feature = features.BoundingBox(
    torch.tensor([[12, 13, 19, 18], [1, 15, 8, 19]]),
    format=features.BoundingBoxFormat.XYXY,
    image_size=image_tensor.shape[-2:],
)

label_data = torch.tensor([1, 2])
label_feature = features.Label(label_data)
ohelabel_data = torch.nn.functional.one_hot(label_data, 3)
ohelabel_feature = features.OneHotLabel(ohelabel_data)

segmask_data = torch.zeros(1, 30, 20)
segmask_data[:, 13:19, 12:18] = 1
segmask_data[:, 15:19, 1:8] = 2
segmask_feature = features.SegmentationMask(segmask_data)

detmask_data = torch.zeros(2, 30, 20)
detmask_data[0, 13:19, 12:18] = 1
detmask_data[1, 15:19, 1:8] = 1
detmask_feature = features.SegmentationMask(detmask_data)


# Assertions
value = invertedmeanpix(image_tensor)
assert value < 0
torch.testing.assert_close(invertedmeanpix(image_feature), -value)
torch.testing.assert_close(invertedmeanpix(image_pil), -value / 255.0)
torch.testing.assert_close(invertedmeanpix_scripted(image_tensor), value)
torch.testing.assert_close(invertedmeanpix(label_feature), label_data.float().mean().item())
torch.testing.assert_close(invertedmeanpix(ohelabel_feature), ohelabel_data.float().mean().item())
torch.testing.assert_close(invertedmeanpix(segmask_feature), segmask_data.mean().item())
torch.testing.assert_close(invertedmeanpix(detmask_feature), detmask_data.mean().item())
print("OK")

Output:

def invertedmeanpix(x: Tensor) -> float:
  _0 = __torch__.invertedmeanpix_image_tensor
  x0 = unchecked_cast(Tensor, x)
  x1 = unchecked_cast(Tensor, x0)
  return _0(x1, )

OK

cc @vfdev-5 @pmeier
