
Support torch.jit.script on LightningModules #1951

@neighthan


🚀 Feature

There are a number of advantages to converting a model with TorchScript (e.g. static optimizations, better saving/loading, especially into non-Python environments for deployment). However, currently no LightningModule can be converted with torch.jit.script. Here's a simple example and the error it produces (note that it works as-is if the model inherits from nn.Module instead of pl.LightningModule):

import pytorch_lightning as pl
import torch

# class Model(torch.nn.Module):  # works fine
class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(5, 10)
    
    def forward(self, x):
        return self.layer(x)

torch.jit.script(Model())
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-70-1fe19c1470da> in <module>
     10         return self.layer(x)
     11 
---> 12 torch.jit.script(Model())

~/anaconda/envs/vsrl/lib/python3.7/site-packages/torch/jit/__init__.py in script(obj, optimize, _frames_up, _rcb)
   1259 
   1260     if isinstance(obj, torch.nn.Module):
-> 1261         return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile)
   1262 
   1263     qualified_name = _qualified_name(obj)

~/anaconda/envs/vsrl/lib/python3.7/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types)
    295     if share_types:
    296         # Look into the store of cached JIT types
--> 297         concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
    298     else:
    299         # Get a concrete type directly, without trying to re-use an existing JIT

~/anaconda/envs/vsrl/lib/python3.7/site-packages/torch/jit/_recursive.py in get_or_create_concrete_type(self, nn_module)
    254             return nn_module._concrete_type
    255 
--> 256         concrete_type_builder = infer_concrete_type_builder(nn_module)
    257 
    258         nn_module_type = type(nn_module)

~/anaconda/envs/vsrl/lib/python3.7/site-packages/torch/jit/_recursive.py in infer_concrete_type_builder(nn_module)
    133     # Constants annotated via `Final[T]` rather than being added to `__constants__`
    134     for name, ann in class_annotations.items():
--> 135         if torch._jit_internal.is_final(ann):
    136             constants_set.add(name)
    137 

~/anaconda/envs/vsrl/lib/python3.7/site-packages/torch/_jit_internal.py in is_final(ann)
    681 
    682     def is_final(ann):
--> 683         return ann.__module__ == 'typing_extensions' and \
    684             (getattr(ann, '__origin__', None) is typing_extensions.Final)
    685 except ImportError:

AttributeError: 'ellipsis' object has no attribute '__module__'

Digging into this a little, we have

print(Model.__annotations__)
# {'_device': Ellipsis, '_dtype': typing.Union[str, torch.dtype]}

The _device annotation comes from DeviceDtypeModuleMixin (one of the superclasses of LightningModule). Here's the relevant snippet:

class DeviceDtypeModuleMixin(torch.nn.Module):
    _device: ...
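The failure can be reproduced without torch or Lightning. The sketch below is a minimal illustration that assumes only the behavior visible in the traceback: a class-level annotation of literally `...` stores the Ellipsis object, and reading `__module__` on it raises AttributeError, whereas annotating with a real class does not. `Device`, `EllipsisMixin`, `TypedMixin`, and `is_final_like` are hypothetical stand-ins, not the actual torch or Lightning code.

```python
# Torch-free reproduction of the failure path in the traceback above.
# `is_final_like` is a hypothetical stand-in for the first condition of
# torch._jit_internal.is_final; it is not the real function.


class EllipsisMixin:
    # An annotation of literally `...` stores the Ellipsis singleton.
    _device: ...


class Device:
    """Hypothetical stand-in for torch.device."""


class TypedMixin:
    # Annotating with an actual class stores that class object instead.
    _device: Device


def is_final_like(ann):
    # Same attribute access that raises in the traceback.
    return ann.__module__ == "typing_extensions"


def check(cls):
    """Return 'ok', or the AttributeError message for cls's _device annotation."""
    ann = cls.__annotations__["_device"]
    try:
        is_final_like(ann)
        return "ok"
    except AttributeError as exc:
        return str(exc)


print(EllipsisMixin.__annotations__["_device"])  # Ellipsis
print(check(EllipsisMixin))  # 'ellipsis' object has no attribute '__module__'
print(check(TypedMixin))     # ok
```

Class objects always carry a `__module__`, so swapping `...` for a concrete type is enough to get past this particular check.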

This seems to be the only issue, because the following code works:

import pytorch_lightning as pl
import torch

class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(5, 10)
    
    def forward(self, x):
        return self.layer(x)

# Model.__annotations__ = {} # this works too but doesn't seem as nice
Model.__annotations__["_device"] = torch.device
torch.jit.script(Model())

However, if I try to set the annotation to typing.Union[str, torch.device] (which seems to be the true type based on this line), then torch.jit.script fails with ValueError: Unknown type annotation: 'typing.Union[str, torch.device]'.

Is the str type for _device actually used? I don't see it used anywhere, and I do see at least one place where there would be an error if self.device returned a string (here). I'll go ahead and submit a PR to update the annotations, but feel free to comment here or on the PR if I'm missing something about these type annotations.

Metadata


Assignees: No one assigned
Labels: feature (Is an improvement or enhancement), help wanted (Open to be worked on), won't fix (This will not be worked on)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
