Can I nest LightningModules inside child modules? #20053

@jackdent

Bug description

Suppose I have a LightningModule (parent) that contains an nn.Module (child), which in turn contains another LightningModule (grandchild). Calling self.log inside the grandchild LightningModule produces the following warning:

You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`

The trainer is set only on the direct children of the parent LightningModule, not on all of its descendants, because the trainer setter iterates over self.children() rather than self.modules():

@trainer.setter
def trainer(self, trainer: Optional["pl.Trainer"]) -> None:
    for v in self.children():
        if isinstance(v, LightningModule):
            v.trainer = trainer  # type: ignore[assignment]
    self._trainer = trainer
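
To illustrate why the grandchild is skipped, here is a minimal sketch of the difference between children() and modules() on a three-level hierarchy. It uses plain nn.Module classes only (no Lightning), and the names Root, Middle, and Leaf are hypothetical stand-ins for the parent, child, and grandchild described above.

```python
from torch import nn


class Leaf(nn.Module):
    """Hypothetical stand-in for the nested grandchild module."""


class Middle(nn.Module):
    """Hypothetical stand-in for the plain nn.Module child."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)
        self.leaf = Leaf()


class Root(nn.Module):
    """Hypothetical stand-in for the parent module."""

    def __init__(self):
        super().__init__()
        self.middle = Middle()


root = Root()

# children() yields only the immediate submodules.
direct = [type(m).__name__ for m in root.children()]

# modules() yields the module itself and every descendant, recursively.
everything = [type(m).__name__ for m in root.modules()]

print(direct)      # ['Middle']
print(everything)  # ['Root', 'Middle', 'Linear', 'Leaf']
```

Since Leaf appears only in the modules() traversal, a loop over children() on the root never reaches it, which matches the behavior reported here.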

What version are you seeing the problem on?

master

How to reproduce the bug

# %%

import lightning as L
import torch
from torch import nn


class GrandChild(L.LightningModule):
    def dummy_log(self):
        self.log("foo", 1)


class Child(nn.Module):
    def __init__(self):
        super().__init__()
        self.module = nn.Linear(1, 1)
        self.grandchild = GrandChild()

    def forward(self):
        self.grandchild.dummy_log()
        return 1


class Parent(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.child = Child()

    def training_step(self, batch, batch_idx):
        return self.child()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(torch.randn(10, 1)), batch_size=1
        )


# model
parent = Parent()

# train model
trainer = L.Trainer()
trainer.fit(model=parent)
optimizer = parent.configure_optimizers()

loss = parent.training_step(batch=None, batch_idx=0)

Error messages and logs

You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`

Environment

Current environment
  • CUDA:
    - GPU:
        - NVIDIA A100-SXM4-80GB
    - available: True
    - version: 12.1
  • Lightning:
    - lightning: 2.2.1
    - lightning-utilities: 0.11.2
    - pytorch-lightning: 2.2.1
    - torch: 2.3.1
    - torchmetrics: 1.3.2
    - torchvision: 0.18.1
  • System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.11.9
    - release: 5.15.0-113-generic
    - version: #123-Ubuntu SMP Mon Jun 10 08:16:17 UTC 2024

More info

No response

cc @carmocca @justusschock @awaelchli @Borda
