fabric.logger.log_graph() seems to do nothing #17844

@shihaoyin

Bug description

fabric.logger.log_graph() seems to do nothing.
My code is shown below, but no graph is recorded in the TensorBoard log.

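from lightning.fabric import Fabric
from lightning.fabric.loggers import TensorBoardLogger
# model, optimizer, and val_loader are defined elsewhere
fabric = Fabric(loggers=TensorBoardLogger(root_dir="./logs", name=None))
model, optimizer = fabric.setup(model, optimizer)
fabric.logger.log_graph(model=model, input_array=next(iter(val_loader))[0])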

I looked at the source code of fabric.logger.log_graph():

    @rank_zero_only
    def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
        model_example_input = getattr(model, "example_input_array", None)
        input_array = model_example_input if input_array is None else input_array

        if input_array is None:
            rank_zero_warn(
                "Could not log computational graph to TensorBoard: The `model.example_input_array` attribute"
                " is not set or `input_array` was not given."
            )
        elif not isinstance(input_array, (Tensor, tuple)):
            rank_zero_warn(
                "Could not log computational graph to TensorBoard: The `input_array` or `model.example_input_array`"
                f" has type {type(input_array)} which can't be traced by TensorBoard. Make the input array a tuple"
                f" representing the positional arguments to the model's `forward()` implementation."
            )
        elif callable(getattr(model, "_on_before_batch_transfer", None)) and callable(
            getattr(model, "_apply_batch_transfer_handler", None)
        ):
            # this is probably is a LightningModule
            input_array = model._on_before_batch_transfer(input_array)  # type: ignore[operator]
            input_array = model._apply_batch_transfer_handler(input_array)  # type: ignore[operator]
            self.experiment.add_graph(model, input_array)

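It seems that self.experiment.add_graph(model, input_array) should be dedented by one level. As written, it sits inside the last elif branch, so it only runs when the model has the LightningModule batch-transfer hooks; for a plain module set up through Fabric, that line is never executed.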
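To illustrate the control flow I have in mind, here is a minimal sketch that runs without PyTorch or Lightning. It is not the library's code: `FakeTensor`, `RecordingWriter`, and `PlainModel` are hypothetical stand-ins I made up for `torch.Tensor`, the TensorBoard writer, and a plain module, and `log_graph` is written as a free function for brevity. The point is only that the hooks are applied when present, while `add_graph` is reached for any valid input.

```python
import warnings
from typing import Any, List, Optional, Tuple


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without PyTorch."""


class RecordingWriter:
    """Hypothetical stand-in for the TensorBoard writer; it only records add_graph calls."""

    def __init__(self) -> None:
        self.graphs: List[Tuple[Any, Any]] = []

    def add_graph(self, model: Any, input_array: Any) -> None:
        self.graphs.append((model, input_array))


def log_graph(writer: RecordingWriter, model: Any, input_array: Optional[Any] = None) -> None:
    model_example_input = getattr(model, "example_input_array", None)
    input_array = model_example_input if input_array is None else input_array

    if input_array is None:
        warnings.warn("Could not log graph: no input array was given.")
        return
    if not isinstance(input_array, (FakeTensor, tuple)):
        warnings.warn(f"Could not log graph: unsupported input type {type(input_array)}.")
        return

    before = getattr(model, "_on_before_batch_transfer", None)
    apply = getattr(model, "_apply_batch_transfer_handler", None)
    if callable(before) and callable(apply):
        # Probably a LightningModule: run its batch-transfer hooks first.
        input_array = before(input_array)
        input_array = apply(input_array)

    # Reached for every model with a valid input, not only LightningModules.
    writer.add_graph(model, input_array)


class PlainModel:
    """Hypothetical plain model without the LightningModule hooks."""


writer = RecordingWriter()
log_graph(writer, PlainModel(), input_array=(FakeTensor(),))
print(len(writer.graphs))  # prints 1
```

With the nesting from the quoted source, the same call would record nothing, which matches what I see in TensorBoard.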

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version: 2.0.3
#- PyTorch Version: 2.0.1
#- Python version: 3.10

cc @awaelchli @Borda @Blaizzy @carmocca @justusschock
