🚀 Feature
There are a number of advantages to converting a model with TorchScript (e.g. static optimizations, better saving/loading, especially into non-Python environments for deployment). However, no `LightningModule`s can be converted using `torch.jit.script`. Here's a simple example with the error produced (note that this works as-is if we inherit from `nn.Module` instead of `pl.LightningModule`):
```python
import pytorch_lightning as pl
import torch

# class Model(torch.nn.Module):  # works fine
class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.layer(x)

torch.jit.script(Model())
```
```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-70-1fe19c1470da> in <module>
     10         return self.layer(x)
     11
---> 12 torch.jit.script(Model())

~/anaconda/envs/vsrl/lib/python3.7/site-packages/torch/jit/__init__.py in script(obj, optimize, _frames_up, _rcb)
   1259
   1260     if isinstance(obj, torch.nn.Module):
-> 1261         return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile)
   1262
   1263     qualified_name = _qualified_name(obj)

~/anaconda/envs/vsrl/lib/python3.7/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types)
    295     if share_types:
    296         # Look into the store of cached JIT types
--> 297         concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
    298     else:
    299         # Get a concrete type directly, without trying to re-use an existing JIT

~/anaconda/envs/vsrl/lib/python3.7/site-packages/torch/jit/_recursive.py in get_or_create_concrete_type(self, nn_module)
    254         return nn_module._concrete_type
    255
--> 256     concrete_type_builder = infer_concrete_type_builder(nn_module)
    257
    258     nn_module_type = type(nn_module)

~/anaconda/envs/vsrl/lib/python3.7/site-packages/torch/jit/_recursive.py in infer_concrete_type_builder(nn_module)
    133     # Constants annotated via `Final[T]` rather than being added to `__constants__`
    134     for name, ann in class_annotations.items():
--> 135         if torch._jit_internal.is_final(ann):
    136             constants_set.add(name)
    137

~/anaconda/envs/vsrl/lib/python3.7/site-packages/torch/_jit_internal.py in is_final(ann)
    681
    682     def is_final(ann):
--> 683         return ann.__module__ == 'typing_extensions' and \
    684             (getattr(ann, '__origin__', None) is typing_extensions.Final)
    685 except ImportError:

AttributeError: 'ellipsis' object has no attribute '__module__'
```
Digging into this a little, we have

```python
print(Model.__annotations__)
# {'_device': Ellipsis, '_dtype': typing.Union[str, torch.dtype]}
```

and the `_device` annotation comes from `DeviceDtypeModuleMixin` (one of the superclasses of `LightningModule`). Here's the relevant snippet:
```python
class DeviceDtypeModuleMixin(torch.nn.Module):
    _device: ...
```
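Two tiny stand-in classes (hypothetical, no torch or Lightning needed) illustrate the mechanics: class attribute lookup makes `Model.__annotations__` resolve to the mixin's dict, and the `Ellipsis` singleton has no `__module__` attribute, which is exactly what `is_final` touches:

```python
# Stand-ins for DeviceDtypeModuleMixin / LightningModule (no torch required).

class Mixin:
    _device: ...  # same Ellipsis annotation as in the real mixin

class Model(Mixin):
    pass  # defines no annotations of its own

# Attribute lookup walks the MRO, so Model.__annotations__ resolves to the
# mixin's dict -- including the Ellipsis entry.
print(Model.__annotations__)  # {'_device': Ellipsis}

# torch._jit_internal.is_final evaluates `ann.__module__ == ...`, but the
# Ellipsis singleton has no __module__ attribute, hence the AttributeError.
ann = Model.__annotations__["_device"]
print(hasattr(ann, "__module__"))  # False
```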
This seems to be the only issue, because this code works:
```python
import pytorch_lightning as pl
import torch

class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.layer(x)

# Model.__annotations__ = {}  # this works too but doesn't seem as nice
Model.__annotations__["_device"] = torch.device
torch.jit.script(Model())
```
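If patching `__annotations__` inline at the call site feels fragile, the same idea can be wrapped in a small helper. This is only a hypothetical sketch (`fix_device_annotation` is not a Lightning or PyTorch API); it replaces an `Ellipsis`-valued annotation with a concrete type before scripting, and copies the dict so the base class's annotations aren't mutated in place:

```python
def fix_device_annotation(cls, concrete_type):
    """Replace an Ellipsis-valued `_device` annotation with `concrete_type`.

    Hypothetical helper: copies the (possibly inherited) annotations dict and
    installs the fixed copy on `cls` itself, leaving base classes untouched.
    """
    anns = dict(getattr(cls, "__annotations__", {}))
    if anns.get("_device") is Ellipsis:
        anns["_device"] = concrete_type
        cls.__annotations__ = anns  # now an own attribute of cls
    return cls

# Usage (with the real classes this would be torch.device):
# fix_device_annotation(Model, torch.device)
# torch.jit.script(Model())
```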
However, if I try to set the annotation to `typing.Union[str, torch.device]` (which seems to be the true type based on this line), then I get `ValueError: Unknown type annotation: 'typing.Union[str, torch.device]'` in `torch.jit.script`.
Is the `str` type for `_device` actually used? I don't see that anywhere, and I do see at least one place where there would be an error if `self.device` returned a string (here). I'll go ahead and submit a PR to update the annotations, but feel free to comment here or on the PR if there's something I'm missing about the type annotations here.