
Commit 4df546a

Enable inference mode for evaluation (#12715)
* Enable inference mode for evaluation
* better name
* Update CHANGELOG.md

Co-authored-by: Akihiro Nitta <[email protected]>
1 parent: a758d90

2 files changed: +22 −4 lines


CHANGELOG.md

Lines changed: 4 additions & 0 deletions
```diff
@@ -9,8 +9,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Enabled `torch.inference_mode` for evaluation and prediction ([#12715](https://github.com/PyTorchLightning/pytorch-lightning/pull/12715))
+
+
 - Added support for setting `val_check_interval` to a value higher than the amount of training batches when `check_val_every_n_epoch=None` ([#11993](https://github.com/PyTorchLightning/pytorch-lightning/pull/11993))
 
+
 - Include the `pytorch_lightning` version as a header in the CLI config files ([#12532](https://github.com/PyTorchLightning/pytorch-lightning/pull/12532))
 
 
```
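For readers skimming the changelog entry: `torch.inference_mode` (added in PyTorch 1.9) is a stricter variant of `torch.no_grad` that also disables view and version-counter tracking, trimming per-op overhead during evaluation. A minimal sketch, not part of this commit, illustrating the practical difference:

```python
import torch

x = torch.ones(3, requires_grad=True)

# no_grad: gradient tracking is paused, but the results are
# ordinary tensors that can still interact with autograd later.
with torch.no_grad():
    y = x * 2
print(y.is_inference())  # False

# inference_mode: results are "inference tensors" with no view or
# version tracking; faster, but unusable in later autograd graphs.
with torch.inference_mode():
    z = x * 2
print(z.is_inference())  # True
```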

pytorch_lightning/trainer/trainer.py

Lines changed: 18 additions & 4 deletions
```diff
@@ -19,13 +19,15 @@
 import traceback
 import warnings
 from argparse import ArgumentParser, Namespace
+from contextlib import contextmanager
 from copy import deepcopy
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Type, Union
+from typing import Any, Callable, cast, Dict, Generator, Iterable, List, Optional, Type, Union
 from weakref import proxy
 
 import torch
+import torch.distributed as dist
 from packaging.version import Version
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
@@ -97,7 +99,7 @@
 from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_len_all_ranks
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
-from pytorch_lightning.utilities.imports import _fault_tolerant_training
+from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_9
 from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
@@ -1316,7 +1318,7 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT:
         # reset trainer on this loop and all child loops in case user connected a custom loop
         self._evaluation_loop.trainer = self
 
-        with self.profiler.profile(f"run_{self.state.stage}_evaluation"), torch.no_grad():
+        with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context():
             eval_loop_results = self._evaluation_loop.run()
 
         # remove the tensors from the eval results
@@ -1332,7 +1334,7 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
         self.reset_predict_dataloader(self.lightning_module)
         # reset trainer on this loop and all child loops in case user connected a custom loop
         self.predict_loop.trainer = self
-        with torch.no_grad():
+        with _evaluation_context():
             return self.predict_loop.run()
 
     def _run_sanity_check(self) -> None:
@@ -2748,6 +2750,18 @@ def configure_optimizers(self):
         return max_estimated_steps
 
 
+@contextmanager
+def _evaluation_context() -> Generator:
+    # inference mode is not supported with gloo backend (#9431)
+    context_manager_class = (
+        torch.inference_mode
+        if _TORCH_GREATER_EQUAL_1_9 and not (dist.is_initialized() and dist.get_backend() == "gloo")
+        else torch.no_grad
+    )
+    with context_manager_class():
+        yield
+
+
 def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]:
     if batches is None:
         # batches is optional to know if the user passed a value so that we can show the above info messages only to the
```
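To see the new helper's selection logic end to end, here is a self-contained sketch (the names and the version check below are illustrative stand-ins, not Lightning's internals): it prefers `torch.inference_mode` and falls back to `torch.no_grad` on older PyTorch or when the `gloo` distributed backend is active, since inference tensors can break gloo collectives (see issue #9431).

```python
from contextlib import contextmanager
from typing import Generator

import torch
import torch.distributed as dist

# Stand-in for pytorch_lightning.utilities.imports._TORCH_GREATER_EQUAL_1_9
_HAS_INFERENCE_MODE = hasattr(torch, "inference_mode")


@contextmanager
def evaluation_context() -> Generator:
    # Prefer inference_mode; fall back to no_grad on old torch or gloo.
    use_inference_mode = _HAS_INFERENCE_MODE and not (
        dist.is_initialized() and dist.get_backend() == "gloo"
    )
    cm = torch.inference_mode if use_inference_mode else torch.no_grad
    with cm():
        yield


if __name__ == "__main__":
    with evaluation_context():
        out = torch.randn(2, 2) @ torch.randn(2, 2)
    # True whenever inference_mode was selected above.
    print(out.is_inference())
```

Because both `torch.inference_mode` and `torch.no_grad` are usable as context managers with the same call shape, swapping between them at the class level keeps the `with` site unchanged, which is why the commit can touch `_run_evaluate` and `_run_predict` with a one-line edit each.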
