Description
🐛 Bug
My model code makes extensive use of tensors created during the forward pass, so I have to specify the device for each of these tensors to avoid device mismatches. I want to export my model to TorchScript for deployment elsewhere, but the `DeviceDtypeModuleMixin` properties do not appear to be compatible with TorchScript.
This appears to affect the `dtype` property too, but in my case it is using `self.device` in a LightningModule that prevents a successful TorchScript export.
To Reproduce
```python
import pytorch_lightning as pl
import torch


class TestModel(pl.LightningModule):
    def forward(self):
        return torch.zeros((2, 2), device=self.device)


model = TestModel()
model.to_torchscript()
```
This results in the following error:
```
RuntimeError:
Module 'TestModel' has no attribute 'device' :
  File "<snip>/test.py", line 7
    def forward(self):
        return torch.zeros((2, 2), device=self.device)
                                          ~~~~~~~~~~~ <--- HERE
```
I am not sure how this can be fixed without changing the way `DeviceDtypeModuleMixin` works. It relies on the `@property` decorator, which uses descriptors, and TorchScript is not compatible with descriptors (3.3.2.2, 3.3.2.3).
Additionally, the `device` property returns a `Union[str, torch.device]`, so even if it worked as a property in TorchScript, its return value would not match the type expected by the `device` keyword argument (`Optional[Device]`). There appears to have been some discussion of this in this issue, but I'm not sure how it was resolved.
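For illustration, here is a minimal sketch of the narrowing step that the type mismatch implies, assuming PyTorch 1.10 or later (where TorchScript supports `Union` types refined with `isinstance`). The helper name `zeros_on` is hypothetical and not part of Lightning; it does not make `self.device` scriptable, it only shows how a `Union[str, torch.device]` value has to be reduced to a `torch.device` before it can be passed as `device=`.

```python
from typing import Union

import torch


@torch.jit.script
def zeros_on(device: Union[str, torch.device]) -> torch.Tensor:
    # Hypothetical helper: refine the Union so torch.zeros only ever receives
    # a torch.device, which matches its Optional[Device] keyword argument.
    if isinstance(device, str):
        return torch.zeros((2, 2), device=torch.device(device))
    # In this branch TorchScript has refined `device` to torch.device.
    return torch.zeros((2, 2), device=device)


print(zeros_on("cpu").shape)
print(zeros_on(torch.device("cpu")).device)
```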
I am happy to try contributing a solution, but I wanted to discuss it here first. One possible solution would be to remove the `@property` decorator and store the current device as a plain `device` attribute on the model object. The work currently done in the body of the `device()` function would then be performed when the `device` attribute is set, rather than when it is read.
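A minimal sketch of that idea, using a plain `nn.Module` rather than a `LightningModule` (in Lightning 1.6.5, `device` is a read-only property on the mixin, so assigning to it there would fail). The class name `PlainDeviceModule` is hypothetical, and `torch._C._nn._parse_to` is a private PyTorch helper (the parser `nn.Module.to` uses internally), so treat this as an assumption-laden illustration rather than a proposed implementation. It also does not handle `.cuda()`/`.cpu()`, which bypass `to()`, and it does not update the attribute after scripting.

```python
import torch
from torch import nn


class PlainDeviceModule(nn.Module):
    """Hypothetical sketch: `device` is a plain attribute, not a @property."""

    def __init__(self) -> None:
        super().__init__()
        # Plain torch.device attribute; TorchScript infers its type when scripting.
        self.device = torch.device("cpu")

    def to(self, *args, **kwargs):
        # Do the device bookkeeping when the module is moved, not when it is read.
        # Assumption: relies on the private parser that nn.Module.to uses.
        # Note: .cuda() and .cpu() do not go through this method.
        device, _, _, _ = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.device = device
        return super().to(*args, **kwargs)

    def forward(self) -> torch.Tensor:
        return torch.zeros((2, 2), device=self.device)


model = PlainDeviceModule().to("cpu")
# Only forward() is compiled; the Python-side to() override is not scripted.
scripted = torch.jit.script(model)
print(scripted().shape)
```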
Expected behavior
The code shown above is successfully saved as TorchScript.
Environment
- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): LightningModule
- PyTorch Lightning Version (e.g., 1.5.0): 1.6.5
- PyTorch Version (e.g., 1.10): 1.12.0
- Python version (e.g., 3.9): 3.9.13
- OS (e.g., Linux): Linux
- CUDA/cuDNN version: 11.6
- How you installed PyTorch (`conda`, `pip`, source): pip
cc @carmocca @justusschock @awaelchli @Borda @ananthsub @ninginthecloud @jjenniferdai @rohitgr7