
TorchScript export incompatible with DeviceDtypeModuleMixin properties #13887

@mattm458

Description


🐛 Bug

My model code makes extensive use of tensors created during the forward pass, so I have to specify the device for each of them to avoid device-mismatch errors. I want to export my model to TorchScript for deployment elsewhere, but the DeviceDtypeModuleMixin properties do not appear to be compatible with TorchScript.

This appears to affect the dtype property as well, but in my use case the problem is self.device: using it in a LightningModule prevents a successful TorchScript export.

To Reproduce

import pytorch_lightning as pl
import torch


class TestModel(pl.LightningModule):
    def forward(self):
        return torch.zeros((2, 2), device=self.device)


model = TestModel()
model.to_torchscript()

Running this results in the following error:

RuntimeError: 
Module 'TestModel' has no attribute 'device' :
  File "<snip>/test.py", line 7
    def forward(self):
        return torch.zeros((2, 2), device=self.device)
                                          ~~~~~~~~~~~ <--- HERE

I am not sure how this can be fixed without changing how DeviceDtypeModuleMixin works. The mixin exposes device through the @property decorator, which is implemented with descriptors, and TorchScript does not support descriptors (sections 3.3.2.2 and 3.3.2.3).
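For readers less familiar with descriptors, the plain-Python sketch below (no PyTorch involved; the class name WithProperty is hypothetical) shows what the @property decorator actually produces. The name device ends up bound to a property object stored on the class, which computes the value through its __get__ method on every access. Nothing named device is stored on the instance itself.

```python
class WithProperty:
    """Hypothetical class mimicking a mixin that exposes `device` via @property."""

    def __init__(self) -> None:
        # The real value lives in a "private" instance attribute.
        self._device = "cpu"

    @property
    def device(self) -> str:
        # Runs on every read of obj.device, via property.__get__.
        return self._device


obj = WithProperty()
descriptor = type(obj).__dict__["device"]

print(isinstance(descriptor, property))    # True: a property object on the class
print(hasattr(descriptor, "__get__"))      # True: it implements the descriptor protocol
print("device" in vars(obj))               # False: not an instance attribute
print(descriptor.__get__(obj, type(obj)))  # cpu: same as reading obj.device
```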

Additionally, the device property is annotated as returning Union[str, torch.device], so even if it worked as a property in TorchScript, its return type would not match the type expected by the device keyword argument (Optional[Device]). This appears to have been discussed in an earlier issue, but I'm not sure how it was resolved.

I am happy to try contributing a solution, but wanted to discuss it here first. One possible solution is to drop the @property decorator and store the current device as a plain device attribute on the model object. The logic currently in the device() getter would then run when the attribute is set, rather than every time it is read.
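The proposal above can be sketched in plain Python. This is a hypothetical illustration, not Lightning's API: Device is a stand-in for torch.device so the sketch runs without PyTorch, DeviceAttributeMixin, parse_device and set_device are made-up names, and filling in a missing CUDA index (with current_cuda_index standing in for torch.cuda.current_device()) is an assumption about what the current getter does. The point is that the normalization runs once, when the device is set, leaving an ordinary instance attribute of a single type instead of a Union.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass(frozen=True)
class Device:
    """Hypothetical stand-in for torch.device, so the sketch runs without PyTorch."""
    type: str
    index: Optional[int] = None


def parse_device(value: Union[str, Device]) -> Device:
    """Accept "cpu", "cuda" or "cuda:1" as well as an existing Device."""
    if isinstance(value, Device):
        return value
    kind, _, idx = value.partition(":")
    return Device(kind, int(idx) if idx else None)


class DeviceAttributeMixin:
    """Hypothetical mixin: `device` is a plain instance attribute, not a @property."""

    # Assumption: stands in for torch.cuda.current_device().
    current_cuda_index: int = 0

    def __init__(self) -> None:
        self.device: Device = Device("cpu")

    def set_device(self, value: Union[str, Device]) -> None:
        device = parse_device(value)
        # The normalization a property getter would repeat on every read
        # (filling in a missing CUDA index) happens once, here, on write.
        if device.type == "cuda" and device.index is None:
            device = Device("cuda", self.current_cuda_index)
        self.device = device


m = DeviceAttributeMixin()
m.set_device("cuda")
print(m.device)                      # Device(type='cuda', index=0)
print("device" in vars(m))           # True: an ordinary instance attribute
print(isinstance(m.device, Device))  # True: always a single type
```

In a real LightningModule the setter logic would presumably need to run wherever the module is moved (for example from its to() override), which this sketch does not attempt.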

Expected behavior

The code shown above is successfully saved as TorchScript.

Environment

  • Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): LightningModule
  • PyTorch Lightning Version (e.g., 1.5.0): 1.6.5
  • PyTorch Version (e.g., 1.10): 1.12.0
  • Python version (e.g., 3.9): 3.9.13
  • OS (e.g., Linux): Linux
  • CUDA/cuDNN version: 11.6
  • How you installed PyTorch (conda, pip, source): pip

cc @carmocca @justusschock @awaelchli @Borda @ananthsub @ninginthecloud @jjenniferdai @rohitgr7

Labels: bug, lightningmodule, priority: 2 (low priority)
