Skip to content

Commit ad7a958

Browse files
Deprecate method to_torchscript (#21397)
* deprecate method * deprecate method * add deprecation to tests * remove example from readme * remove example from readme * changelog * remove readme changes --------- Co-authored-by: Deependu <[email protected]>
1 parent f3f6605 commit ad7a958

File tree

4 files changed

+58
-15
lines changed

4 files changed

+58
-15
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1616

1717
-
1818

19+
### Deprecated
20+
21+
- Deprecated `to_torchscript` method due to deprecation of TorchScript in PyTorch ([#21397](https://github.com/Lightning-AI/pytorch-lightning/pull/21397))
22+
1923
### Removed
2024

2125
- Removed support for Python 3.9 due to end-of-life status ([#21398](https://github.com/Lightning-AI/pytorch-lightning/pull/21398))

src/lightning/pytorch/core/module.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
from lightning.pytorch.utilities.exceptions import MisconfigurationException
6565
from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_2_6, _TORCHMETRICS_GREATER_EQUAL_0_9_1
6666
from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
67-
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
67+
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_deprecation, rank_zero_warn
6868
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
6969
from lightning.pytorch.utilities.types import (
7070
_METRIC,
@@ -1498,6 +1498,11 @@ def to_torchscript(
14981498
scripted you should override this method. In case you want to return multiple modules, we recommend using a
14991499
dictionary.
15001500
1501+
.. deprecated::
1502+
``LightningModule.to_torchscript`` has been deprecated in v2.7 and will be removed in v2.8.
1503+
TorchScript is deprecated in PyTorch. Use ``torch.export.export()`` for model exporting instead.
1504+
See https://pytorch.org/docs/stable/export.html for more information.
1505+
15011506
Args:
15021507
file_path: Path where to save the torchscript. Default: None (no file saved).
15031508
method: Whether to use TorchScript's script or trace method. Default: 'script'
@@ -1536,6 +1541,11 @@ def forward(self, x):
15361541
defined or not.
15371542
15381543
"""
1544+
rank_zero_deprecation(
1545+
"`LightningModule.to_torchscript` has been deprecated in v2.7 and will be removed in v2.8. "
1546+
"TorchScript is deprecated in PyTorch. Use `torch.export.export()` for model exporting instead. "
1547+
"See https://pytorch.org/docs/stable/export.html for more information."
1548+
)
15391549
mode = self.training
15401550

15411551
if method == "script":

tests/tests_pytorch/helpers/test_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ def test_models(tmp_path, data_class, model_class):
4646
if dm is not None:
4747
trainer.test(model, datamodule=dm)
4848

49-
model.to_torchscript()
49+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
50+
model.to_torchscript()
5051
if data_class:
5152
model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample)
5253

tests/tests_pytorch/models/test_torchscript.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from lightning.fabric.utilities.cloud_io import get_filesystem
2323
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_4
24+
from lightning.fabric.utilities.rank_zero import LightningDeprecationWarning
2425
from lightning.pytorch.core.module import LightningModule
2526
from lightning.pytorch.demos.boring_classes import BoringModel
2627
from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleRNN
@@ -36,7 +37,8 @@ def test_torchscript_input_output(modelclass):
3637
if isinstance(model, BoringModel):
3738
model.example_input_array = torch.randn(5, 32)
3839

39-
script = model.to_torchscript()
40+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
41+
script = model.to_torchscript()
4042
assert isinstance(script, torch.jit.ScriptModule)
4143

4244
model.eval()
@@ -59,7 +61,8 @@ def test_torchscript_example_input_output_trace(modelclass):
5961
if isinstance(model, BoringModel):
6062
model.example_input_array = torch.randn(5, 32)
6163

62-
script = model.to_torchscript(method="trace")
64+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
65+
script = model.to_torchscript(method="trace")
6366
assert isinstance(script, torch.jit.ScriptModule)
6467

6568
model.eval()
@@ -74,7 +77,8 @@ def test_torchscript_input_output_trace():
7477
"""Test that traced LightningModule forward works with example_inputs."""
7578
model = BoringModel()
7679
example_inputs = torch.randn(1, 32)
77-
script = model.to_torchscript(example_inputs=example_inputs, method="trace")
80+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
81+
script = model.to_torchscript(example_inputs=example_inputs, method="trace")
7882
assert isinstance(script, torch.jit.ScriptModule)
7983

8084
model.eval()
@@ -99,7 +103,8 @@ def test_torchscript_device(device_str):
99103
model = BoringModel().to(device)
100104
model.example_input_array = torch.randn(5, 32)
101105

102-
script = model.to_torchscript()
106+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
107+
script = model.to_torchscript()
103108
assert next(script.parameters()).device == device
104109
script_output = script(model.example_input_array.to(device))
105110
assert script_output.device == device
@@ -121,19 +126,22 @@ def test_torchscript_device_with_check_inputs(device_str):
121126

122127
check_inputs = torch.rand(5, 32)
123128

124-
script = model.to_torchscript(method="trace", check_inputs=check_inputs)
129+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
130+
script = model.to_torchscript(method="trace", check_inputs=check_inputs)
125131
assert isinstance(script, torch.jit.ScriptModule)
126132

127133

128134
def test_torchscript_retain_training_state():
129135
"""Test that torchscript export does not alter the training mode of original model."""
130136
model = BoringModel()
131137
model.train(True)
132-
script = model.to_torchscript()
138+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
139+
script = model.to_torchscript()
133140
assert model.training
134141
assert not script.training
135142
model.train(False)
136-
_ = model.to_torchscript()
143+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
144+
_ = model.to_torchscript()
137145
assert not model.training
138146
assert not script.training
139147

@@ -142,7 +150,8 @@ def test_torchscript_retain_training_state():
142150
def test_torchscript_properties(modelclass):
143151
"""Test that scripted LightningModule has unnecessary methods removed."""
144152
model = modelclass()
145-
script = model.to_torchscript()
153+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
154+
script = model.to_torchscript()
146155
assert not hasattr(model, "batch_size") or hasattr(script, "batch_size")
147156
assert not hasattr(model, "learning_rate") or hasattr(script, "learning_rate")
148157
assert not callable(getattr(script, "training_step", None))
@@ -153,7 +162,8 @@ def test_torchscript_save_load(tmp_path, modelclass):
153162
"""Test that scripted LightningModule is correctly saved and can be loaded."""
154163
model = modelclass()
155164
output_file = str(tmp_path / "model.pt")
156-
script = model.to_torchscript(file_path=output_file)
165+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
166+
script = model.to_torchscript(file_path=output_file)
157167
loaded_script = torch.jit.load(output_file)
158168
assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))
159169

@@ -170,7 +180,8 @@ class DummyFileSystem(LocalFileSystem): ...
170180

171181
model = modelclass()
172182
output_file = os.path.join(_DUMMY_PRFEIX, _PREFIX_SEPARATOR, tmp_path, "model.pt")
173-
script = model.to_torchscript(file_path=output_file)
183+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
184+
script = model.to_torchscript(file_path=output_file)
174185

175186
fs = get_filesystem(output_file)
176187
with fs.open(output_file, "rb") as f:
@@ -184,7 +195,10 @@ def test_torchcript_invalid_method():
184195
model = BoringModel()
185196
model.train(True)
186197

187-
with pytest.raises(ValueError, match="only supports 'script' or 'trace'"):
198+
with (
199+
pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"),
200+
pytest.raises(ValueError, match="only supports 'script' or 'trace'"),
201+
):
188202
model.to_torchscript(method="temp")
189203

190204

@@ -193,7 +207,10 @@ def test_torchscript_with_no_input():
193207
model = BoringModel()
194208
model.example_input_array = None
195209

196-
with pytest.raises(ValueError, match="requires either `example_inputs` or `model.example_input_array`"):
210+
with (
211+
pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"),
212+
pytest.raises(ValueError, match="requires either `example_inputs` or `model.example_input_array`"),
213+
):
197214
model.to_torchscript(method="trace")
198215

199216

@@ -224,6 +241,17 @@ def forward(self, inputs):
224241

225242
lm = Parent()
226243
assert not lm._jit_is_scripting
227-
script = lm.to_torchscript(method="script")
244+
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
245+
script = lm.to_torchscript(method="script")
228246
assert not lm._jit_is_scripting
229247
assert isinstance(script, torch.jit.RecursiveScriptModule)
248+
249+
250+
def test_to_torchscript_deprecation():
251+
"""Test that to_torchscript raises a deprecation warning."""
252+
model = BoringModel()
253+
model.example_input_array = torch.randn(5, 32)
254+
255+
with pytest.warns(LightningDeprecationWarning, match="has been deprecated in v2.7 and will be removed in v2.8"):
256+
script = model.to_torchscript()
257+
assert isinstance(script, torch.jit.ScriptModule)

0 commit comments

Comments
 (0)