
use torch.inference_mode() in Trainer.predict #11018

@davidegraff


🚀 Feature

PyTorch 1.9 introduced torch.inference_mode(), a context manager (also usable as a decorator) designed specifically for model inference. Like torch.no_grad(), it disables gradient tracking, which speeds up forward passes through a model. Unlike torch.no_grad(), it also disables view tracking and version-counter bumps. As a result, tensors created under torch.inference_mode() cannot later be saved for backward in computations that require gradients. Outputs computed under torch.no_grad() carry no such restriction. Skipping this extra bookkeeping makes forward passes faster still.
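A minimal sketch of the difference follows. The toy torch.nn.Linear model, the shapes, and the predict helper are illustrative assumptions, not Lightning code. Both modes produce outputs with requires_grad=False, but only the inference_mode output is an inference tensor, and autograd refuses to save it for backward:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 2)  # toy model for illustration
x = torch.randn(3, 4)

# Context-manager form of no_grad
with torch.no_grad():
    y_no_grad = model(x)

# Decorator form of inference_mode (hypothetical helper name)
@torch.inference_mode()
def predict(m, inputs):
    return m(inputs)

y_inf = predict(model, x)

print(y_no_grad.requires_grad, y_inf.requires_grad)    # False False
print(y_no_grad.is_inference(), y_inf.is_inference())  # False True

w = torch.randn(2, requires_grad=True)

# A no_grad output can take part in a later autograd computation.
(y_no_grad * w).sum().backward()

# An inference tensor cannot be saved for backward, so this raises.
rejected = False
try:
    (y_inf * w).sum()
except RuntimeError as err:
    rejected = True
    print("RuntimeError:", err)
```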

Motivation

The motivation is to use torch.inference_mode() inside the Trainer.predict() logic, replacing any occurrences of torch.no_grad(). This code path is only used for inference, so I don't believe the restrictions on the resulting tensors are a significant drawback for downstream code. Clients of Trainer.predict() should not rely on its output tensors in later computations that require gradients. At the same time, clients of this code are likely relying on PyTorch Lightning to scale and speed up their code transparently, and using torch.inference_mode() furthers that goal.
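A downstream client may still need to feed predictions into a gradient computation. One workaround is to clone the inference tensor outside inference mode, which yields an ordinary tensor that autograd can use. Below is a minimal sketch, again with an illustrative toy model rather than anything from Lightning:

```python
import torch

model = torch.nn.Linear(4, 2)  # toy model for illustration
x = torch.randn(3, 4)

with torch.inference_mode():
    preds = model(x)  # an inference tensor

# Cloning outside inference mode yields a normal tensor that autograd can save.
preds_normal = preds.clone()

w = torch.randn(2, requires_grad=True)
loss = (preds_normal * w).sum()
loss.backward()

print(preds.is_inference(), preds_normal.is_inference())  # True False
```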

Pitch

Change

def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
    self.reset_predict_dataloader(self.lightning_module)
    # reset trainer on this loop and all child loops in case user connected a custom loop
    self.predict_loop.trainer = self
    with torch.no_grad():
        return self.predict_loop.run()

to

def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
    self.reset_predict_dataloader(self.lightning_module)
    # reset trainer on this loop and all child loops in case user connected a custom loop
    self.predict_loop.trainer = self
    with torch.inference_mode():
        return self.predict_loop.run()

I think this would be the only change needed.
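One practical wrinkle is that torch.inference_mode does not exist on PyTorch releases older than 1.9. This matters only if the Trainer must still support those releases, which is an assumption about the project's support matrix, not something stated above. Below is a minimal sketch of a guarded selection. _predict_context and run_predict are hypothetical names standing in for the Trainer internals:

```python
import torch


def _predict_context():
    # Hypothetical helper: prefer inference_mode (added in PyTorch 1.9),
    # otherwise fall back to no_grad on older releases.
    if hasattr(torch, "inference_mode"):
        return torch.inference_mode()
    return torch.no_grad()


def run_predict(model, batches):
    # Hypothetical stand-in for Trainer._run_predict: the whole prediction
    # loop runs inside a single context, as in the proposed change.
    with _predict_context():
        return [model(batch) for batch in batches]


model = torch.nn.Linear(4, 2)  # toy model for illustration
batches = [torch.randn(8, 4) for _ in range(3)]
preds = run_predict(model, batches)
print(len(preds), tuple(preds[0].shape))  # 3 (8, 2)
```

Wrapping the whole loop once, rather than each forward call, matches the shape of the proposed _run_predict change.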

cc @Borda

Metadata

Assignees: No one assigned
Labels: feature (Is an improvement or enhancement)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
