🚀 Feature
PyTorch 1.9 introduced a new decorator/context manager specifically for model inference: `torch.inference_mode()`. Like `torch.no_grad()`, it disables gradient tracking for faster forward passes through a model and is, as the name implies, useful for model inference. Unlike `torch.no_grad()`, it also disables view tracking, so tensors created under `torch.inference_mode()` cannot later be used in computations that require gradients. Disabling view tracking further speeds up model forward passes.
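The behavioral difference can be sketched in a minimal standalone example (plain PyTorch, not Lightning code):

```python
import torch

x = torch.ones(3, requires_grad=True)

# Under torch.no_grad(), the result does not track gradients, but it can
# still participate in a later computation that requires gradients.
with torch.no_grad():
    y = x * 2          # y.requires_grad is False; y is an ordinary tensor
z = (y * x).sum()
z.backward()           # fine: y is treated as a constant

# Under torch.inference_mode(), the result is an "inference tensor" and
# autograd refuses to record any operation that uses it.
with torch.inference_mode():
    w = x * 2
assert w.is_inference()
try:
    (w * x).sum().backward()
except RuntimeError:
    print("inference tensors cannot be used in autograd")
```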
Motivation
The motivation is to use `torch.inference_mode()` inside the `Trainer.predict()` logic, replacing any occurrences of `torch.no_grad()`. Because this code path is used exclusively for inference, I don't believe the lack of view tracking represents a significant drawback or limitation for downstream code. Conceptually, clients of `Trainer.predict()` should not rely on these tensors for downstream code that requires gradient calculation. However, clients of this code are likely relying on PyTorch Lightning to transparently scale and speed up their own code, and using `torch.inference_mode()` would further that goal.
Pitch
Change

```python
def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
    self.reset_predict_dataloader(self.lightning_module)
    # reset trainer on this loop and all child loops in case user connected a custom loop
    self.predict_loop.trainer = self
    with torch.no_grad():
        return self.predict_loop.run()
```

to

```python
def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
    self.reset_predict_dataloader(self.lightning_module)
    # reset trainer on this loop and all child loops in case user connected a custom loop
    self.predict_loop.trainer = self
    with torch.inference_mode():
        return self.predict_loop.run()
```
I think this would be the only change needed.
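For reference, `torch.inference_mode()` also works as a decorator, which may be convenient if the swap is ever wanted elsewhere. A quick standalone sketch (the helper below is hypothetical, not existing Lightning code):

```python
import torch

# Hypothetical helper: decorating a function is equivalent to wrapping
# its whole body in the inference_mode context manager.
@torch.inference_mode()
def predict_batch(model: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
    return model(batch)

model = torch.nn.Linear(4, 2)
out = predict_batch(model, torch.randn(8, 4))
assert out.is_inference() and not out.requires_grad
```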
cc @Borda