[POC] Potential ways for making Transforms V2 classes JIT-scriptable #6711

Closed
@datumbox

Description

🚀 The feature

Note: This is an exploratory proof-of-concept to discuss potential workarounds for offering limited JIT support in our Transforms V2 classes. I am NOT advocating for following this approach. I'm hoping it kicks off a discussion of alternative, simpler approaches.

Currently, the Transforms V2 classes are not JIT-scriptable. This breaks BC and will make the rollout of the new API harder. Here are some of the design choices that are incompatible with JIT:

  1. We wanted to support an arbitrary number of inputs.
  2. We rely on Tensor subclassing to dispatch to the right kernel.
  3. We use real typing information on the inputs, which often includes types that are not scriptable.
  4. We opted for more Pythonic idioms (such as for ... else).

Points 3 & 4 could be addressed by (painful) refactoring; nevertheless, points 1 & 2 are our main blockers.
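
To make point 4 concrete, here is a minimal sketch (not taken from the code-base) of the kind of idiom that trips up the scripting frontend, since TorchScript does not support the else clause on loops:

import torch

class FindFirstPositive(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for i in range(x.shape[0]):
            if bool(x[i].sum() > 0):
                idx = i
                break
        else:  # executed only when the loop finishes without `break`
            idx = -1
        return x[idx]

# Scripting this module should fail with a frontend error on the `else` clause:
# torch.jit.script(FindFirstPositive())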

To ensure our users can still do inference using JIT, we offer presets/transforms attached to each model's weights; those will remain JIT-scriptable. In addition, we applied a workaround (#6553) to keep the F dispatchers JIT-scriptable for plain Tensors. Hopefully these mitigations will help most users migrate more easily to the new API.
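
As an illustration of the first mitigation (using the stable torchvision weights API; the exact preset depends on the model):

import torch
from torchvision.models import ResNet50_Weights

weights = ResNet50_Weights.DEFAULT
preset = weights.transforms()        # inference preset bundled with the weights
scripted = torch.jit.script(preset)  # per the above, these presets remain JIT-scriptable
out = scripted(torch.randn(1, 3, 256, 256))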

But what if they don't? Many downstream users might want to keep relying on transforms such as Resize, CenterCrop, Pad etc. for inference. In that case, one option could be to offer JIT-scriptable alternatives that work only for plain tensors. Another is to write a utility that modifies the existing implementations on-the-fly, updating key methods to make them JIT-scriptable.

Motivation, pitch

This is a proof-of-concept of how such a utility could work. It only supports a handful of transforms (due to points 3 & 4 from above), but it can be extended to support more.

Two approaches are showcased below:

  1. We use ast to replace problematic idioms in the Transform classes on-the-fly. Since JIT also uses ast internally, we need to make the updated code available to JIT during scripting.
  2. We replace forward() to remove the packing/unpacking of an arbitrary number of inputs. We also hardcode plain tensors as the only accepted input type.

import ast
import inspect
import tempfile
import torch
import types

from torchvision import transforms as V1
from torchvision.prototype import transforms as V2
from torchvision.prototype import features


class JITWrapper(torch.nn.Module):

    def __init__(self, cls, *args, **kwargs):
        super().__init__()
        # Patch the _transform type annotations; this could be avoided by defining JIT-scriptable types directly
        code = inspect.getsource(cls)
        tree = ast.parse(code)
        for node in ast.walk(tree):
            if isinstance(node, ast.ClassDef):
                node.name = f"{cls.__name__}JIT"
            elif isinstance(node, ast.FunctionDef):
                if node.name == "_transform":
                    node.args.args[1].annotation.id = "features.InputTypeJIT"
                    node.returns.id = "features.InputTypeJIT"
        source = ast.unparse(tree)

        # Write the source to a temp file; needed for JIT's inspect calls to work properly
        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp:
            temp.write(source)
            filename = temp.name

        # Compiles the new modified Class from source
        code = compile(source, filename, "exec")
        mod = {}
        exec(code, vars(inspect.getmodule(cls)), mod)
        cls = next(iter(mod.values()))

        # initialize transform
        transform = cls(*args, **kwargs)

        # Patch forward
        if hasattr(transform, "_jit_forward"):
            # Use the one defined in the class if available
            transform.forward = transform._jit_forward
        else:
            # Use the default implementation
            setattr(transform, "forward", types.MethodType(JITWrapper.__default_jit_forward, transform))

        self._wrapped = transform

    @staticmethod
    def __default_jit_forward(self, inputs: features.InputTypeJIT) -> features.InputTypeJIT:
        params = self._get_params(inputs)
        result = self._transform(inputs, params)
        return result

    def forward(self, inputs: features.InputTypeJIT) -> features.InputTypeJIT:
        return self._wrapped.forward(inputs)


def assert_jit_scriptable(t, inpt):
    # Run the transform eagerly and after a script + save/load round-trip,
    # using the same seed, and check that the two outputs match.
    torch.manual_seed(0)
    eager_out = t(inpt)

    t_scripted = torch.jit.script(t)
    with tempfile.NamedTemporaryFile(delete=False) as temp:
        t_scripted.save(temp.name)
        t_scripted = torch.jit.load(temp.name)

    torch.manual_seed(0)
    script_out = t_scripted(inpt)
    torch.testing.assert_close(eager_out, script_out)
    return script_out


img = torch.randn((1, 3, 224, 224))

t = V1.Resize((32, 32))
out1 = assert_jit_scriptable(t, img)
print("T1: OK")

t = JITWrapper(V2.Resize, (32, 32))
out2 = assert_jit_scriptable(t, img)
print("T2: OK")

torch.testing.assert_close(out1, out2)
print("T1 == T2: OK")

The above works on our latest main without modifications:

T1: OK
T2: OK
T1 == T2: OK

This approach currently supports only a handful of simple Transforms: those that don't require overriding forward() and that keep most of their logic inside their _get_params() and _transform() methods. Many otherwise simple transforms are still not supported because they inherit from _RandomApplyTransform, which makes the random call in its forward() (this could be refactored to move it into _get_params(); see the sketch below). The rest of the existing inference transforms can be supported by addressing points 3 & 4 from above.
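
A rough sketch of that refactoring (hypothetical; the names mirror the prototype's _RandomApplyTransform and its p attribute): the random draw moves into _get_params(), so the generic _get_params() -> _transform() flow used by the default JIT forward covers these transforms too:

from typing import Any, Dict

class _RandomApplyTransform(Transform):
    def __init__(self, p: float = 0.5) -> None:
        super().__init__()
        self.p = p

    def _get_params(self, sample: Any) -> Dict[str, Any]:
        # The coin flip moves here, out of forward(); subclasses' _transform()
        # would early-return the input unchanged when params["apply"] is False.
        return {"apply": float(torch.rand(1)) < self.p}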

The AST-based approach above is very over-engineered, brittle and opaque, because it tries to fix the JIT-scriptability issues without any modifications to the code-base for the selected example. If we accept minor refactoring of the existing classes, we can remove the ast logic. We could also avoid defining a default JIT-compatible forward by explicitly defining such a method on the original class where available. Here is one potential simplified version that would require changes to our current API:

class JITWrapper(torch.nn.Module):

    def __init__(self, transform: Transform):
        super().__init__()
        # Patch forward
        if hasattr(transform, "_jit_forward"):
            # Use the one defined in the class if available, should reuse `_get_params` and `_transform`
            transform.forward = transform._jit_forward
        else:
            raise Exception(f"The {type(transform).__name__} transform doesn't support scripting")

        self._wrapped = transform

    def forward(self, inputs: features.InputTypeJIT) -> features.InputTypeJIT:
        return self._wrapped.forward(inputs)


class Resize(Transform):
    # __init__() and _get_params() go here

    def _transform(self, inpt: features.InputTypeJIT, params: Dict[str, Any]) -> features.InputTypeJIT:
        # We changed the types; everything else in the method stays the same.
        ...

    def _jit_forward(self, inputs: features.InputTypeJIT) -> features.InputTypeJIT:
        params = self._get_params(inputs)
        result = self._transform(inputs, params)
        return result
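
Usage would then look something like this (assuming the refactored Resize above):

t = JITWrapper(Resize((32, 32)))
t_scripted = torch.jit.script(t)
out = t_scripted(torch.randn(1, 3, 224, 224))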

Alternatives

There are several other alternatives we could follow. One is to offer JIT-scriptable versions of a limited number of Transforms that are commonly used during inference. Another is to make some of our transforms FX-traceable instead of JIT-scriptable. Though not all classes can become traceable (because their behaviour branches based on the input), making them compatible where possible will future-proof us for PyTorch 2.
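
As a minimal sketch of the FX alternative, a tensor-only transform whose control flow does not depend on its input can already be traced:

import torch
import torch.fx

class Scale(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x / 255.0

graph_module = torch.fx.symbolic_trace(Scale())  # succeeds: no input-dependent branching
print(graph_module.code)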

Additional context

No response

cc @vfdev-5 @bjuncek @pmeier
