-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Enable inference mode for evaluation #12715
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3208a59
5392d63
e075bc2
e306750
3317faa
a2ec545
f67510c
fbcde50
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,13 +19,15 @@ | |
import traceback | ||
import warnings | ||
from argparse import ArgumentParser, Namespace | ||
from contextlib import contextmanager | ||
from copy import deepcopy | ||
from datetime import timedelta | ||
from pathlib import Path | ||
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Type, Union | ||
from typing import Any, Callable, cast, Dict, Generator, Iterable, List, Optional, Type, Union | ||
from weakref import proxy | ||
|
||
import torch | ||
import torch.distributed as dist | ||
from packaging.version import Version | ||
from torch.optim import Optimizer | ||
from torch.utils.data import DataLoader | ||
|
@@ -97,7 +99,7 @@ | |
from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_len_all_ranks | ||
from pytorch_lightning.utilities.distributed import distributed_available | ||
from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException | ||
from pytorch_lightning.utilities.imports import _fault_tolerant_training | ||
from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_9 | ||
from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module | ||
from pytorch_lightning.utilities.model_helpers import is_overridden | ||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn | ||
|
@@ -1316,7 +1318,7 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT: | |
# reset trainer on this loop and all child loops in case user connected a custom loop | ||
self._evaluation_loop.trainer = self | ||
|
||
with self.profiler.profile(f"run_{self.state.stage}_evaluation"), torch.no_grad(): | ||
with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context(): | ||
eval_loop_results = self._evaluation_loop.run() | ||
|
||
# remove the tensors from the eval results | ||
|
@@ -1332,7 +1334,7 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]: | |
self.reset_predict_dataloader(self.lightning_module) | ||
# reset trainer on this loop and all child loops in case user connected a custom loop | ||
self.predict_loop.trainer = self | ||
with torch.no_grad(): | ||
with _evaluation_context(): | ||
return self.predict_loop.run() | ||
|
||
def _run_sanity_check(self) -> None: | ||
|
@@ -2748,6 +2750,18 @@ def configure_optimizers(self): | |
return max_estimated_steps | ||
|
||
|
||
@contextmanager
def _evaluation_context() -> Generator:
    """Disable autograd for the duration of an evaluation or predict run.

    Uses ``torch.inference_mode`` when available (PyTorch >= 1.9) for lower
    overhead than ``torch.no_grad``, falling back to ``torch.no_grad`` on
    older PyTorch versions or when the distributed gloo backend is active.

    Yields:
        None. Code executed inside the ``with`` block runs without gradient
        tracking.
    """
    # inference mode is not supported with gloo backend (#9431)
    # NOTE(review): this gloo/inference_mode incompatibility is not documented
    # upstream; reviewers could not find a matching PyTorch issue and opened
    # one to track it — confirm before removing the no_grad fallback.
    # Guard with dist.is_available() first: on torch builds compiled without
    # the distributed package, dist.is_initialized() may not be defined.
    context_manager_class = (
        torch.inference_mode
        if _TORCH_GREATER_EQUAL_1_9
        and not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo")
        else torch.no_grad
    )
    with context_manager_class():
        yield
|
||
|
||
def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]: | ||
if batches is None: | ||
# batches is optional to know if the user passed a value so that we can show the above info messages only to the | ||
|
Uh oh!
There was an error while loading. Please reload this page.