
BatchSizeFinder defining max validation batches for entire training loop #18394

@joncarter1


Bug description

When the BatchSizeFinder callback is used, its steps_per_trial parameter ends up capping the number of validation batches run for the rest of training, not just during the batch-size trials. This is similar to the issue observed with the LR Finder (#17412).
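
The symptom can be made concrete by inspecting the trainer after fit() returns, using the model and datamodule from the reproduction below. A minimal sketch, assuming the finder leaves trainer.limit_val_batches clamped to steps_per_trial rather than restoring the value configured on the Trainer (this is the suspected mechanism, not a confirmed root cause):

# Assumption: the finder sets trainer.limit_val_batches to steps_per_trial
# for its trials and never puts it back. If so, the attribute should read
# 1.0 (the default, i.e. the full validation set) once fit() returns, but
# with the callback attached it appears to stay at the trial length.
trainer = Trainer(callbacks=[BatchSizeFinder(steps_per_trial=13, max_trials=3)])
trainer.fit(model, datamodule=datamodule)
print(trainer.limit_val_batches)  # expected: 1.0, suspected actual: 13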

What version are you seeing the problem on?

v2.0

How to reproduce the bug

(Adapted from @blainehoak's example in #17412)

import time
import torch
from torch.utils.data import DataLoader, Dataset

import lightning
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import BatchSizeFinder


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class DummyDataModule(lightning.LightningDataModule):
    def __init__(
        self,
        length: int,
        size: int = 32,
        batch_size: int = 32,
    ):
        super().__init__()
        self.size = size
        self.length = length
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(
            RandomDataset(self.size, self.length),
            batch_size=self.batch_size,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            RandomDataset(self.size, self.length),
            batch_size=self.batch_size,
            shuffle=False,
        )


class BoringModel(LightningModule):
    def __init__(self, size=32, lr=0.1):
        super().__init__()
        self.model = torch.nn.Linear(size, 2)
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        time.sleep(0.01)  # Simulate some work per training step
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        time.sleep(0.5)  # Slow validation down so the number of val batches is visible in the progress bar
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(), lr=self.lr)


def run():
    STEPS = 13  # Bug: this ends up capping the number of validation batches for the whole run
    LENGTH = 10_000
    datamodule = DummyDataModule(length=LENGTH, batch_size=32)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir="/tmp/lightning_logs",
        max_epochs=10,
        enable_model_summary=False,
        callbacks=[BatchSizeFinder(steps_per_trial=STEPS, max_trials=3)],
    )
    trainer.fit(model, datamodule=datamodule)


if __name__ == "__main__":
    run()
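
Not part of the reproduction itself, but a possible stop-gap while this is open: subclass the callback and put limit_val_batches back once the search has run. This is an untested sketch; it assumes the search is triggered from on_fit_start and that trainer.limit_val_batches is the setting left behind, neither of which I have verified against the tuner internals.

from lightning.pytorch.callbacks import BatchSizeFinder


class RestoringBatchSizeFinder(BatchSizeFinder):
    """Hypothetical workaround: restore limit_val_batches after the batch-size search."""

    def on_fit_start(self, trainer, pl_module):
        original = trainer.limit_val_batches  # value configured on the Trainer
        super().on_fit_start(trainer, pl_module)  # runs the batch-size search
        trainer.limit_val_batches = original  # undo whatever the search left behind


# Drop-in replacement in the script above:
# callbacks=[RestoringBatchSizeFinder(steps_per_trial=STEPS, max_trials=3)]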

Error messages and logs

No response

Environment

Current environment
  • CUDA:
    - GPU:
      - NVIDIA GeForce GTX 1070
    - available: True
    - version: 11.8
  • Lightning:
    - lightning: 2.0.7
    - lightning-cloud: 0.5.37
    - lightning-utilities: 0.8.0
    - pytorch-lightning: 2.0.2
    - torch: 2.0.1
    - torchmetrics: 0.11.4
    - torchvision: 0.15.2

More info

No response
