Commit 4cc05b2

Support optimizer step progress tracking with manual optimization (#11848)
1 parent 963adc7 commit 4cc05b2

11 files changed: +68 -83 lines changed

11 files changed

+68
-83
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -78,6 +78,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247))
 
 
+- Added support for optimizer step progress tracking with manual optimization ([#11848](https://github.com/PyTorchLightning/pytorch-lightning/pull/11848))
+
+
 - Return the output of the `optimizer.step`. This can be useful for `LightningLite` users, manual optimization users, or users overriding `LightningModule.optimizer_step` ([#11711](https://github.com/PyTorchLightning/pytorch-lightning/pull/11711))

pytorch_lightning/core/optimizer.py

Lines changed: 10 additions & 1 deletion

@@ -53,6 +53,9 @@ def __init__(self, optimizer: Optimizer):
         self._optimizer = optimizer
         self._strategy: Optional[pl.strategies.Strategy] = None
         self._optimizer_idx = 0
+        # to inject logic around the optimizer step, particularly useful with manual optimization
+        self._on_before_step = do_nothing_closure
+        self._on_after_step = do_nothing_closure
 
     @property
     def optimizer(self) -> Optimizer:
@@ -154,6 +157,8 @@ def closure_dis():
                 with opt_dis.toggle_model(sync_grad=accumulated_grad_batches):
                     opt_dis.step(closure=closure_dis)
         """
+        self._on_before_step()
+
         if closure is None:
             closure = do_nothing_closure
             profiler_action = "optimizer_step_without_closure"
@@ -166,7 +171,11 @@ def closure_dis():
         assert self._strategy is not None
         assert self._strategy.lightning_module is not None
         with self._strategy.lightning_module.trainer.profiler.profile(profiler_action):
-            return self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
+            step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
+
+        self._on_after_step()
+
+        return step_output
 
 
 def _init_optimizers_and_lr_schedulers(
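The two hooks default to `do_nothing_closure` and only do work once a loop assigns them (the manual optimization loop below does exactly that). A minimal standalone sketch of the same before/after pattern; the `TrackedOptimizer` wrapper and the counter are illustrative, not part of the commit:

import torch

def do_nothing_closure() -> None:
    return None

class TrackedOptimizer:
    # Illustrative wrapper mirroring the hook pattern added to LightningOptimizer:
    # callables run right before and right after the wrapped step, so an outside
    # observer can count optimizer steps without subclassing the optimizer.
    def __init__(self, optimizer: torch.optim.Optimizer) -> None:
        self._optimizer = optimizer
        self._on_before_step = do_nothing_closure
        self._on_after_step = do_nothing_closure

    def step(self, closure=None, **kwargs):
        self._on_before_step()
        step_output = self._optimizer.step(closure=closure, **kwargs)
        self._on_after_step()
        return step_output

model = torch.nn.Linear(2, 1)
optimizer = TrackedOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
counts = {"ready": 0, "completed": 0}
optimizer._on_before_step = lambda: counts.__setitem__("ready", counts["ready"] + 1)
optimizer._on_after_step = lambda: counts.__setitem__("completed", counts["completed"] + 1)

model(torch.randn(4, 2)).sum().backward()
optimizer.step()
assert counts == {"ready": 1, "completed": 1}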

pytorch_lightning/loops/fit_loop.py

Lines changed: 5 additions & 2 deletions

@@ -71,7 +71,10 @@ def __init__(
     @property
     def global_step(self) -> int:
         """Returns the global step."""
-        return self.epoch_loop.global_step
+        lightning_module = self.trainer.lightning_module
+        if lightning_module is None or lightning_module.automatic_optimization:
+            return self.epoch_loop.global_step
+        return self.epoch_loop.batch_loop.manual_loop.optim_step_progress.total.completed
 
     @global_step.setter
     def global_step(self, value: int) -> None:
@@ -96,7 +99,7 @@ def split_idx(self) -> int:
     @property
     def min_steps(self) -> Optional[int]:
         # TODO(@justusschock): Why aren't we using the attribute in this class?
-        """Returns the minimum numnber of steps to run."""
+        """Returns the minimum number of steps to run."""
         return self.epoch_loop.min_steps
 
     @min_steps.setter
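The attribute read above, `optim_step_progress.total.completed`, comes from the `Progress`/`ReadyCompletedTracker` pair created in `manual_loop.py` below. A minimal sketch of how that counter behaves, using only calls that appear in this diff:

from pytorch_lightning.trainer.progress import Progress, ReadyCompletedTracker

progress = Progress.from_defaults(ReadyCompletedTracker)
for _ in range(3):
    progress.increment_ready()      # mirrors _on_before_step below
    progress.increment_completed()  # mirrors _on_after_step below

# `total.completed` is what FitLoop.global_step now reports under manual optimization
assert progress.total.completed == 3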

pytorch_lightning/loops/optimization/manual_loop.py

Lines changed: 22 additions & 0 deletions

@@ -16,9 +16,11 @@
 
 from torch import Tensor
 
+from pytorch_lightning.core.optimizer import do_nothing_closure
 from pytorch_lightning.loops import Loop
 from pytorch_lightning.loops.optimization.closure import OutputResult
 from pytorch_lightning.loops.utilities import _build_training_step_kwargs, _extract_hiddens
+from pytorch_lightning.trainer.progress import Progress, ReadyCompletedTracker
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 
@@ -74,6 +76,10 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
 
     def __init__(self) -> None:
         super().__init__()
+        # since manual optimization does not track lr scheduler or optimizer frequencies, we use a simpler progress than
+        # `OptimizationProgress`
+        self.optim_step_progress = Progress.from_defaults(ReadyCompletedTracker)
+
         self._done: bool = False
         self._hiddens: Optional[Any] = None
         self._output: _OUTPUTS_TYPE = {}
@@ -85,6 +91,12 @@ def done(self) -> bool:
     def reset(self) -> None:
         self._done = False
 
+    def on_run_start(self, *_: Any, **__: Any) -> None:
+        # inject logic around the optimizer step
+        for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
+            lightning_optimizer._on_before_step = self._on_before_step
+            lightning_optimizer._on_after_step = self._on_after_step
+
     def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
         """Performs the training step for manual optimization.
 
@@ -126,4 +138,14 @@ def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
     def on_run_end(self) -> _OUTPUTS_TYPE:
         """Returns the result of this loop, i.e., the post-processed outputs from the training step."""
         output, self._output = self._output, {}  # free memory
+        # reset logic around the optimizer step
+        for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
+            lightning_optimizer._on_before_step = do_nothing_closure
+            lightning_optimizer._on_after_step = do_nothing_closure
         return output
+
+    def _on_before_step(self) -> None:
+        self.optim_step_progress.increment_ready()
+
+    def _on_after_step(self) -> None:
+        self.optim_step_progress.increment_completed()
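Taken together, every `optimizer.step()` issued from a manual-optimization `training_step` now increments this counter, and `FitLoop.global_step` reports it. A sketch of the user-facing effect; the module, data, and Trainer flags here are illustrative, not part of the commit:

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer

class ManualModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # opt into manual optimization
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch[0]).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()  # triggers _on_before_step/_on_after_step, advancing optim_step_progress

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

train_data = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)
trainer = Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
trainer.fit(ManualModule(), train_data)
# 8 samples / batch size 2 -> 4 optimizer steps; with this change, global_step
# under manual optimization reflects the completed optimizer steps
assert trainer.global_step == 4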

pytorch_lightning/loops/optimization/optimizer_loop.py

Lines changed: 2 additions & 2 deletions

@@ -317,7 +317,7 @@ def backward_fn(loss: Tensor) -> None:
         return backward_fn
 
     def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:
-        """Toggles the optimizer to ensure the correct one is used and prevend dangling grads.
+        """Toggles the optimizer to ensure the correct one is used and prevent dangling grads.
 
         Args:
             opt_idx: the index of the optimizer to use
@@ -348,7 +348,7 @@ def _optimizer_step(
             opt_idx: the index of the current :param:`optimizer`
             batch_idx: the index of the current batch
             train_step_and_backward_closure: the closure function performing the train step and computing the
-                gradients. By default called by the optimizer (if possible)
+                gradients. By default, called by the optimizer (if possible)
         """
         is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 1 addition & 1 deletion

@@ -348,7 +348,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         checkpoint = {
             # the epoch is saved for compatibility but it's not relevant for restoration
             "epoch": self.trainer.current_epoch,
-            "global_step": self.trainer.global_step + 1,
+            "global_step": self.trainer.global_step + model.automatic_optimization,
             "pytorch-lightning_version": pl.__version__,
             "state_dict": self._get_lightning_module_state_dict(),
             "loops": self._get_loops_state_dict(),

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 0 deletions

@@ -2010,6 +2010,7 @@ def data_parallel_device_ids(self) -> Optional[List[int]]:
 
     @property
     def lightning_module(self) -> "pl.LightningModule":
+        # TODO: this is actually an optional return
         return self.strategy.lightning_module
 
     @property

tests/core/test_lightning_optimizer.py

Lines changed: 2 additions & 1 deletion

@@ -152,7 +152,8 @@ def test_state(tmpdir):
     lightning_dict = {
         k: v
         for k, v in lightning_optimizer.__dict__.items()
-        if k not in {"_optimizer", "_optimizer_idx", "_strategy", "_lightning_module"}
+        if k
+        not in {"_optimizer", "_optimizer_idx", "_strategy", "_lightning_module", "_on_before_step", "_on_after_step"}
     }
 
     assert lightning_dict == optimizer.__dict__

tests/loops/test_loop_state_dict.py

Lines changed: 4 additions & 0 deletions

@@ -59,6 +59,10 @@ def test_loops_state_dict_structure():
             },
             "epoch_loop.batch_loop.state_dict": {},
             "epoch_loop.batch_loop.manual_loop.state_dict": {},
+            "epoch_loop.batch_loop.manual_loop.optim_step_progress": {
+                "total": {"ready": 0, "completed": 0},
+                "current": {"ready": 0, "completed": 0},
+            },
             "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
             "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
                 "optimizer": {

tests/loops/test_loops.py

Lines changed: 8 additions & 0 deletions

@@ -512,6 +512,10 @@ def configure_optimizers_multiple(self):
             },
             "epoch_loop.batch_loop.state_dict": ANY,
             "epoch_loop.batch_loop.manual_loop.state_dict": ANY,
+            "epoch_loop.batch_loop.manual_loop.optim_step_progress": {
+                "total": {"ready": 0, "completed": 0},
+                "current": {"ready": 0, "completed": 0},
+            },
             "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
             "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
                 "optimizer_position": stop_optimizer,
@@ -681,6 +685,10 @@ def train_dataloader(self):
             },
             "epoch_loop.batch_loop.state_dict": ANY,
             "epoch_loop.batch_loop.manual_loop.state_dict": ANY,
+            "epoch_loop.batch_loop.manual_loop.optim_step_progress": {
+                "total": {"ready": 0, "completed": 0},
+                "current": {"ready": 0, "completed": 0},
+            },
             "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
             "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
                 "optimizer_position": n_optimizers,
