🚀 The feature
Note: This is an exploratory proof-of-concept to discuss potential workarounds for offering limited JIT support in our Transforms V2 classes. I am NOT advocating for this approach; I'm hoping we can kick off a discussion about alternative, simpler approaches.
Currently the Transforms V2 classes are not JIT-scriptable. This breaks BC and will make the rollout of the new API harder. Here are some of the choices that are incompatible with JIT:

1. We wanted to support an arbitrary number of inputs.
2. We rely on Tensor subclassing to dispatch to the right kernel.
3. We use real typing information on the inputs, which often includes types that are not scriptable.
4. We opted for more Pythonic idioms (such as `for ... else`).

Points 3 & 4 could be addressed by (painful) refactoring, but points 1 & 2 are our main blockers.
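To make point 4 concrete, here is a minimal sketch (the `first_positive` helper is hypothetical, purely for illustration): scripting a function that uses `for ... else` is expected to fail, since the construct is outside the TorchScript language subset.

```python
import torch


def first_positive(xs: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper, only to demonstrate the unsupported idiom
    for x in xs:
        if bool(x > 0):
            return x
    else:  # `for ... else` is not part of the TorchScript subset
        return torch.tensor(-1.0)


# Expected to raise an UnsupportedNodeError from the TorchScript frontend
try:
    torch.jit.script(first_positive)
except Exception as e:
    print(type(e).__name__)
```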
To ensure our users can still do inference using JIT, we offer presets/transforms attached to each model's weights. Those will remain JIT-scriptable. In addition, we applied a workaround (#6553) to keep the `F` dispatcher JIT-scriptable for plain Tensors. Hopefully these mitigations will make migrating to the new API easier for most users.
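For instance, a minimal sketch of what that workaround enables (assuming the prototype functional namespace exposes `resize` with a `size` list argument, as on current main):

```python
import torch
from torchvision.prototype.transforms import functional as F

# The dispatcher stays scriptable as long as only plain Tensors flow through it
scripted_resize = torch.jit.script(F.resize)
out = scripted_resize(torch.randn(1, 3, 224, 224), size=[32, 32])
print(out.shape)  # torch.Size([1, 3, 32, 32])
```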
But what if these mitigations don't suffice? Many downstream users might want to keep relying on transforms such as `Resize`, `CenterCrop`, `Pad` etc. for inference. In that case, one option could be to offer JIT-scriptable alternatives that work only for pure tensors (see the sketch right below). Another alternative is to write a utility that modifies the existing implementations on-the-fly to update key methods and make them JIT-scriptable.
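To illustrate the first option, a minimal sketch of a tensor-only alternative (the `TensorOnlyCenterCrop` name and implementation are hypothetical, not proposed API):

```python
import torch
from torch import nn, Tensor
from torchvision.transforms import functional as F_v1  # V1 functional, already scriptable


class TensorOnlyCenterCrop(nn.Module):
    # Accepts plain Tensors only, so forward has a single concrete input type
    def __init__(self, size: int):
        super().__init__()
        self.size = [size, size]

    def forward(self, inpt: Tensor) -> Tensor:
        return F_v1.center_crop(inpt, self.size)


t = torch.jit.script(TensorOnlyCenterCrop(32))
print(t(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 3, 32, 32])
```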
Motivation, pitch
This is a proof-of-concept of how such a utility can work. It only supports a handful of transforms (due to points 3 & 4 from above) but it can be extended to support more.
Two approaches are showcased below:

- We use `ast` to replace problematic idioms in the Transform classes on-the-fly. Since JIT also uses `ast` internally, we need to make the updated code available to JIT during scripting.
- We replace `forward()` to remove the packing/unpacking of an arbitrary number of inputs. We also hardcode plain tensors as the only accepted input type.
```python
import ast
import inspect
import tempfile
import types

import torch
from torchvision import transforms as V1
from torchvision.prototype import transforms as V2
from torchvision.prototype import features


class JITWrapper(torch.nn.Module):
    def __init__(self, cls, *args, **kwargs):
        super().__init__()

        # Patch _transform types, can be avoided by defining directly JIT-scriptable types
        code = inspect.getsource(cls)
        tree = ast.parse(code)
        for node in ast.walk(tree):
            if isinstance(node, ast.ClassDef):
                node.name = f"{cls.__name__}JIT"
            elif isinstance(node, ast.FunctionDef):
                if node.name == "_transform":
                    node.args.args[1].annotation.id = "features.InputTypeJIT"
                    node.returns.id = "features.InputTypeJIT"
        source = ast.unparse(tree)

        # Writes the source to a temp file. Needed for JIT's inspect calls to work properly.
        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp:
            temp.write(source)
            filename = temp.name

        # Compiles the new modified class from source
        code = compile(source, filename, "exec")
        mod = {}
        exec(code, vars(inspect.getmodule(cls)), mod)
        cls = next(iter(mod.values()))

        # Initialize the transform
        transform = cls(*args, **kwargs)

        # Patch forward
        if hasattr(transform, "_jit_forward"):
            # Use the one defined in the class if available
            transform.forward = transform._jit_forward
        else:
            # Use the default implementation
            setattr(transform, "forward", types.MethodType(JITWrapper.__default_jit_forward, transform))

        self._wrapped = transform

    @staticmethod
    def __default_jit_forward(self, inputs: features.InputTypeJIT) -> features.InputTypeJIT:
        params = self._get_params(inputs)
        result = self._transform(inputs, params)
        return result

    def forward(self, inputs: features.InputTypeJIT) -> features.InputTypeJIT:
        return self._wrapped.forward(inputs)


def assert_jit_scriptable(t, inpt):
    torch.manual_seed(0)
    eager_out = t(inpt)

    t_scripted = torch.jit.script(t)
    with tempfile.NamedTemporaryFile(delete=False) as temp:
        t_scripted.save(temp.name)
        t_scripted = torch.jit.load(temp.name)

    torch.manual_seed(0)
    script_out = t_scripted(inpt)

    torch.testing.assert_close(eager_out, script_out)
    return script_out


img = torch.randn((1, 3, 224, 224))

t = V1.Resize((32, 32))
out1 = assert_jit_scriptable(t, img)
print("T1: OK")

t = JITWrapper(V2.Resize, (32, 32))
out2 = assert_jit_scriptable(t, img)
print("T2: OK")

torch.testing.assert_close(out1, out2)
print("T1 == T2: OK")
```
The above works on our latest `main` without modifications:

```
T1: OK
T2: OK
T1 == T2: OK
```
This approach can currently only support a handful of simple transforms: those that don't require overwriting `forward()` and that contain most of their logic inside their `_get_params()` and `_transform()` methods. Many otherwise simple transforms are still not supported because they inherit from `_RandomApplyTransform`, which does the random call in its `forward()` (this could be refactored to move into `_get_params()`; see the sketch below). The rest of the existing inference transforms can be supported by addressing points 3 & 4 from above.
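A minimal sketch of that refactoring, assuming a `p` probability attribute as in the current `_RandomApplyTransform` (the method bodies are illustrative, not the actual implementation):

```python
from typing import Any, Dict

import torch


class _RandomApplyTransform(torch.nn.Module):
    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def _get_params(self, inputs: torch.Tensor) -> Dict[str, Any]:
        # The coin flip moves here, out of forward(), so a generic
        # forward() only needs to call _get_params() and _transform()
        return {"apply": float(torch.rand(1)) < self.p}

    def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> torch.Tensor:
        raise NotImplementedError

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        params = self._get_params(inputs)
        if not params["apply"]:
            return inputs
        return self._transform(inputs, params)
```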
The `ast`-based approach above is very over-engineered, brittle, and opaque, because it tries to fix the JIT-scriptability issues without any modifications to the code base for the selected example. If we accept minor refactoring of the existing classes, we can remove the `ast` logic. We could also avoid defining a default JIT-compatible `forward()` by explicitly defining such a method on the original class when available. Here is one potential simplified version that would require changes to our current API:
```python
from typing import Any, Dict

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform


class JITWrapper(torch.nn.Module):
    def __init__(self, transform: Transform):
        super().__init__()

        # Patch forward
        if hasattr(transform, "_jit_forward"):
            # Use the one defined in the class if available, should reuse `_get_params` and `_transform`
            transform.forward = transform._jit_forward
        else:
            raise Exception(f"The {type(transform).__name__} transform doesn't support scripting")

        self._wrapped = transform

    def forward(self, inputs: features.InputTypeJIT) -> features.InputTypeJIT:
        return self._wrapped.forward(inputs)


class Resize(Transform):
    # __init__() and _get_params() go here

    def _transform(self, inpt: features.InputTypeJIT, params: Dict[str, Any]) -> features.InputTypeJIT:
        # we changed the types; everything else in the method stays the same
        ...

    def _jit_forward(self, inputs: features.InputTypeJIT) -> features.InputTypeJIT:
        params = self._get_params(inputs)
        result = self._transform(inputs, params)
        return result
```
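Usage would then look something like this (a sketch, assuming the `Resize` above defines the usual `__init__` and `_get_params()`):

```python
t = JITWrapper(Resize((32, 32)))
t_scripted = torch.jit.script(t)
out = t_scripted(torch.randn(1, 3, 224, 224))
```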
Alternatives
There are several other alternatives we could follow. One of them could be to offer JIT-scriptable versions of a limited number of transforms that are commonly used during inference. Another one could be to make some of our transforms FX-traceable instead of JIT-scriptable. Though not all classes can become traceable (because their behaviour branches based on the input), making the rest compatible where possible would future-proof us for PyTorch 2.
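As an example, a minimal sketch of FX tracing on a branch-free, tensor-only transform (the `SimpleNormalize` module is hypothetical, chosen because its forward has no input-dependent control flow):

```python
import torch
import torch.fx
from torch import nn, Tensor


class SimpleNormalize(nn.Module):
    # Hypothetical tensor-only transform with no input-dependent branching
    def forward(self, x: Tensor) -> Tensor:
        mean = x.new_tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
        std = x.new_tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)
        return (x - mean) / std


gm = torch.fx.symbolic_trace(SimpleNormalize())  # records the ops into a Graph
print(gm.code)  # the regenerated forward, now a flat sequence of calls
out = gm(torch.randn(1, 3, 224, 224))
```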
Additional context
No response