Skip to content

Commit 6114fef

Browse files
author
Yerdos Ordabayev
committed
no-op _add_instantiators
1 parent e8163ee commit 6114fef

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

cellarium/ml/cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,11 @@ def __init__(self, args: ArgsType = None) -> None:
275275
args=args,
276276
)
277277

278+
def _add_instantiators(self) -> None:
279+
# disable breaking dependency injection support change introduced in PyTorch Lightning 2.3
280+
# https://github.com/Lightning-AI/pytorch-lightning/pull/18105
281+
pass
282+
278283
def instantiate_classes(self) -> None:
279284
with torch.device("meta"):
280285
# skip the initialization of model parameters

cellarium/ml/core/module.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ def __init__(
6161
self.save_hyperparameters(logger=False)
6262
self.pipeline: CellariumPipeline | None = None
6363

64+
if optim_fn is None:
65+
# Starting from PyTorch Lightning 2.3, automatic optimization doesn't allow to return None
66+
# from the training_step during distributed training. https://github.com/Lightning-AI/pytorch-lightning/pull/19918
67+
# Thus, we need to use manual optimization for the No Optimizer case.
68+
self.automatic_optimization = False
69+
6470
def configure_model(self) -> None:
6571
"""
6672
.. note::
@@ -156,6 +162,14 @@ def training_step( # type: ignore[override]
156162
if loss is not None:
157163
# Logging to TensorBoard by default
158164
self.log("train_loss", loss)
165+
166+
if not self.automatic_optimization:
167+
# Note, that running .step() is necessary for incrementing the global step even though no backpropagation
168+
# is performed.
169+
no_optimizer = self.optimizers()
170+
assert isinstance(no_optimizer, pl.core.optimizer.LightningOptimizer)
171+
no_optimizer.step()
172+
159173
return loss
160174

161175
def forward(self, batch: dict[str, np.ndarray | torch.Tensor]) -> dict[str, np.ndarray | torch.Tensor]:

0 commit comments

Comments
 (0)