
Fix main progress bar counter when val_check_interval=int and check_val_every_n_epoch=None #12832

Merged · merged 4 commits · Jul 20, 2022

10 changes: 7 additions & 3 deletions docs/source-pytorch/common/trainer.rst
@@ -1479,7 +1479,8 @@ How often within one training epoch to check the validation set.
Can specify as float or int.

- pass a ``float`` in the range [0.0, 1.0] to check after a fraction of the training epoch.
- pass an ``int`` to check after a fixed number of training batches.
- pass an ``int`` to check after a fixed number of training batches. An ``int`` value can only be higher than the number of training
  batches when ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches across epochs or during iteration-based training.

.. testcode::

@@ -1489,10 +1490,13 @@ Can specify as float or int.
# check validation set 4 times during a training epoch
trainer = Trainer(val_check_interval=0.25)

# check validation set every 1000 training batches
# check validation set every 1000 training batches in the current epoch
trainer = Trainer(val_check_interval=1000)

# check validation set every 1000 training batches across complete epochs or during iteration-based training
# use this when using an IterableDataset and your dataset has no length
# (i.e. production cases with streaming data)
trainer = Trainer(val_check_interval=1000)
trainer = Trainer(val_check_interval=1000, check_val_every_n_epoch=None)


.. code-block:: python
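
A minimal end-to-end sketch of the documented combination (validation triggered purely by a batch counter, no epoch-based gating) is shown below. `StreamingDataset` is a made-up stand-in for any `IterableDataset` without `__len__`; only `val_check_interval` and `check_val_every_n_epoch=None` come from the docs above, everything else is illustrative.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel


class StreamingDataset(IterableDataset):
    """Endless stream of random 32-dim samples; deliberately defines no __len__."""

    def __iter__(self):
        while True:
            yield torch.randn(32)


model = BoringModel()
train_loader = DataLoader(StreamingDataset(), batch_size=8)

# No epoch boundary exists, so validation runs every 1000 training batches
# and training is bounded by max_steps instead of max_epochs.
trainer = Trainer(
    val_check_interval=1000,
    check_val_every_n_epoch=None,
    max_steps=10_000,
)
trainer.fit(model, train_dataloaders=train_loader)
```
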
5 changes: 4 additions & 1 deletion src/pytorch_lightning/CHANGELOG.md
@@ -144,7 +144,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enabled using any Sampler in distributed environment in Lite ([#13646](https://github.com/PyTorchLightning/pytorch-lightning/pull/13646))


-
- Updated `val_check_interval`(int) to consider total train batches processed instead of `_batches_that_stepped` for validation check during training ([#12832](https://github.com/Lightning-AI/lightning/pull/12832))


### Deprecated
@@ -345,6 +345,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `Trainer.predict(return_predictions=False)` to track prediction's batch_indices ([#13629](https://github.com/Lightning-AI/lightning/pull/13629))


- Fixed main progress bar counter when `val_check_interval=int` and `check_val_every_n_epoch=None` ([#12832](https://github.com/Lightning-AI/lightning/pull/12832))


## [1.6.5] - 2022-07-13

### Fixed
21 changes: 21 additions & 0 deletions src/pytorch_lightning/callbacks/progress/base.py
@@ -172,6 +172,27 @@ def total_val_batches(self) -> Union[int, float]:
assert self._trainer is not None
return sum(self.trainer.num_val_batches) if self._trainer.fit_loop.epoch_loop._should_check_val_epoch() else 0

@property
def total_batches_current_epoch(self) -> Union[int, float]:
total_train_batches = self.total_train_batches
total_val_batches = self.total_val_batches
assert self._trainer is not None

if total_train_batches != float("inf") and total_val_batches != float("inf"):
# val can be checked multiple times per epoch
val_check_batch = self.trainer.val_check_batch
if self.trainer.check_val_every_n_epoch is None:
train_batches_processed = self.trainer.fit_loop.total_batch_idx + 1
val_checks_per_epoch = ((train_batches_processed + total_train_batches) // val_check_batch) - (
train_batches_processed // val_check_batch
)
else:
val_checks_per_epoch = total_train_batches // val_check_batch

total_val_batches = total_val_batches * val_checks_per_epoch

return total_train_batches + total_val_batches

def has_dataloader_changed(self, dataloader_idx: int) -> bool:
old_dataloader_idx = self._current_eval_dataloader_idx
self._current_eval_dataloader_idx = dataloader_idx
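
The arithmetic in `total_batches_current_epoch` can be sanity-checked outside the library. The helper below is not Lightning code, just the same integer arithmetic written as a free function and evaluated for the scenario used by the new `test_main_progress_bar_with_val_check_interval_int` test further down (5 train batches per epoch, 10 val batches, `val_check_interval=3`).

```python
def val_checks_in_next_epoch(batches_processed: int, train_batches: int, val_check_batch: int) -> int:
    """How many times the coming epoch crosses a multiple of val_check_batch
    when the counter keeps running across epochs (check_val_every_n_epoch=None)."""
    end = batches_processed + train_batches
    return end // val_check_batch - batches_processed // val_check_batch


for epoch, processed in enumerate([0, 5, 10, 15]):
    checks = val_checks_in_next_epoch(processed, train_batches=5, val_check_batch=3)
    print(f"epoch {epoch}: {5 + 10 * checks} total batches")  # 15, 25, 25, 15
```
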
12 changes: 3 additions & 9 deletions src/pytorch_lightning/callbacks/progress/rich_progress.py
@@ -324,16 +324,9 @@ def on_sanity_check_end(self, trainer, pl_module):
self.refresh()

def on_train_epoch_start(self, trainer, pl_module):
total_train_batches = self.total_train_batches
total_val_batches = self.total_val_batches
if total_train_batches != float("inf"):
# val can be checked multiple times per epoch
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
total_val_batches = total_val_batches * val_checks_per_epoch

total_batches = total_train_batches + total_val_batches

total_batches = self.total_batches_current_epoch
train_description = self._get_train_description(trainer.current_epoch)

if self.main_progress_bar_id is not None and self._leave:
self._stop_progress()
self._init_progress(trainer)
@@ -343,6 +336,7 @@ def on_train_epoch_start(self, trainer, pl_module):
self.progress.reset(
self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
)

self.refresh()

def on_validation_batch_start(
8 changes: 1 addition & 7 deletions src/pytorch_lightning/callbacks/progress/tqdm_progress.py
@@ -252,13 +252,7 @@ def on_train_start(self, *_: Any) -> None:
self.main_progress_bar = self.init_train_tqdm()

def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
total_train_batches = self.total_train_batches
total_val_batches = self.total_val_batches
if total_train_batches != float("inf") and total_val_batches != float("inf"):
# val can be checked multiple times per epoch
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
total_val_batches = total_val_batches * val_checks_per_epoch
total_batches = total_train_batches + total_val_batches
total_batches = self.total_batches_current_epoch
self.main_progress_bar.reset(convert_inf(total_batches))
self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

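
Both built-in bars now consume the same base-class property, so a custom progress bar gets the corrected per-epoch total for free. A toy subclass, written here only as an illustration (the printing logic is invented; `total_batches_current_epoch` is the property added above):

```python
from pytorch_lightning.callbacks.progress.base import ProgressBarBase


class PrintingProgressBar(ProgressBarBase):
    """Minimal bar that prints counts instead of rendering anything."""

    def __init__(self):
        super().__init__()
        self._enabled = True

    def enable(self):
        self._enabled = True

    def disable(self):
        self._enabled = False

    def on_train_epoch_start(self, trainer, pl_module):
        if self._enabled:
            print(f"epoch {trainer.current_epoch}: expecting {self.total_batches_current_epoch} batches")


# Passing the instance via Trainer(callbacks=[PrintingProgressBar()]) replaces the default bar.
```
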
13 changes: 6 additions & 7 deletions src/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -163,7 +163,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[ov
Raises:
StopIteration: When the epoch is canceled by the user returning -1
"""
if self.restarting and self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch):
if self.restarting and self._should_check_val_fx():
# skip training and run validation in `on_advance_end`
return
# we are going to train first so the val loop does not need to restart
@@ -235,7 +235,7 @@ def on_advance_end(self) -> None:
# -----------------------------------------
# VALIDATE IF NEEDED
# -----------------------------------------
should_check_val = self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch)
should_check_val = self._should_check_val_fx()
if should_check_val:
self.trainer.validating = True
self._run_validation()
@@ -496,13 +496,14 @@ def _should_check_val_epoch(self) -> bool:
or (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
)

def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
def _should_check_val_fx(self) -> bool:
"""Decide if we should run validation."""
if not self._should_check_val_epoch():
return False

# val_check_batch is inf for iterable datasets with no length defined
is_infinite_dataset = self.trainer.val_check_batch == float("inf")
is_last_batch = self.batch_progress.is_last_batch
if is_last_batch and is_infinite_dataset:
return True

@@ -512,13 +513,11 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
# TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
is_val_check_batch = is_last_batch
if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
is_val_check_batch = (self.batch_idx + 1) % self.trainer.limit_train_batches == 0
elif self.trainer.val_check_batch != float("inf"):
# if `check_val_every_n_epoch` is `None`, run a validation loop every n training batches
# else condition it based on the batch_idx of the current epoch
current_iteration = (
self._batches_that_stepped if self.trainer.check_val_every_n_epoch is None else batch_idx
)
current_iteration = self.total_batch_idx if self.trainer.check_val_every_n_epoch is None else self.batch_idx
is_val_check_batch = (current_iteration + 1) % self.trainer.val_check_batch == 0

return is_val_check_batch
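
The behavioural change in `_should_check_val_fx` comes down to which counter feeds the modulo check: the within-epoch `batch_idx` when `check_val_every_n_epoch` is set, and the cross-epoch `total_batch_idx` when it is `None` (previously `_batches_that_stepped`, which counts optimizer steps and so diverges from the batch count under gradient accumulation). A standalone sketch of just that branch, not the loop code itself:

```python
def should_check_val(total_batch_idx: int, batch_idx: int, val_check_batch: int,
                     check_val_every_n_epoch) -> bool:
    """Plain re-statement of the modulo branch above."""
    current = total_batch_idx if check_val_every_n_epoch is None else batch_idx
    return (current + 1) % val_check_batch == 0


# 4 train batches per epoch, val_check_interval=3, check_val_every_n_epoch=None:
# validation fires after global batches 3, 6, 9 and 12, matching the updated test below.
hits = [i + 1 for i in range(12) if should_check_val(i, i % 4, 3, None)]
print(hits)  # [3, 6, 9, 12]
```
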
3 changes: 2 additions & 1 deletion src/pytorch_lightning/trainer/trainer.py
@@ -394,7 +394,8 @@ def __init__(
val_check_interval: How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
batches. An ``int`` value can only be higher than the number of training batches when
``check_val_every_n_epoch=None``.
``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
across epochs or during iteration-based training.
Default: ``1.0``.

enable_model_summary: Whether to enable model summarization by default.
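
As a usage note for this docstring change: with the default `check_val_every_n_epoch=1`, an `int` interval larger than the epoch length is rejected when `fit` starts, while `check_val_every_n_epoch=None` lets the same interval count batches across epochs. A minimal sketch (the exact exception text belongs to the library and is not reproduced here):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()

# Allowed: the interval counts batches across epoch boundaries (8 > 5 batches per epoch).
trainer = Trainer(limit_train_batches=5, val_check_interval=8, check_val_every_n_epoch=None, max_epochs=4)
trainer.fit(model)

# Rejected at fit time with the default epoch-based gating, because a single epoch
# never reaches an 8th batch.
trainer = Trainer(limit_train_batches=5, val_check_interval=8, max_epochs=4)
# trainer.fit(model)  # raises a ValueError about val_check_interval
```
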
34 changes: 34 additions & 0 deletions tests/tests_pytorch/callbacks/progress/test_base_progress.py
@@ -0,0 +1,34 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.trainer.trainer import Trainer


def test_main_progress_bar_with_val_check_interval_int():
"""Test the main progress bar count when val_check_interval=int and check_val_every_n_epoch=None."""
train_batches = 5
trainer = Trainer(
limit_train_batches=train_batches, limit_val_batches=10, val_check_interval=3, check_val_every_n_epoch=None
)
model = BoringModel()
trainer.progress_bar_callback.setup(trainer, model)
trainer.strategy.connect(model)
trainer._data_connector.attach_data(model)
trainer.reset_train_dataloader()
trainer.reset_val_dataloader()
expected = [15, 25, 25, 15]

for count in expected:
assert trainer.progress_bar_callback.total_batches_current_epoch == count
trainer.fit_loop.epoch_loop.batch_progress.total.ready += train_batches
11 changes: 7 additions & 4 deletions tests/tests_pytorch/trainer/flags/test_val_check_interval.py
@@ -63,18 +63,20 @@ def test_val_check_interval_info_message(caplog, value):


@pytest.mark.parametrize("use_infinite_dataset", [True, False])
def test_validation_check_interval_exceed_data_length_correct(tmpdir, use_infinite_dataset):
@pytest.mark.parametrize("accumulate_grad_batches", [1, 2])
def test_validation_check_interval_exceed_data_length_correct(tmpdir, use_infinite_dataset, accumulate_grad_batches):
data_samples_train = 4
max_epochs = 3
max_steps = data_samples_train * max_epochs
max_opt_steps = max_steps // accumulate_grad_batches

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.validation_called_at_step = set()

def validation_step(self, *args):
self.validation_called_at_step.add(self.global_step)
self.validation_called_at_step.add(self.trainer.fit_loop.total_batch_idx + 1)
return super().validation_step(*args)

def train_dataloader(self):
@@ -89,16 +91,17 @@ def train_dataloader(self):
trainer = Trainer(
default_root_dir=tmpdir,
limit_val_batches=1,
max_steps=max_steps,
max_steps=max_opt_steps,
val_check_interval=3,
check_val_every_n_epoch=None,
num_sanity_val_steps=0,
accumulate_grad_batches=accumulate_grad_batches,
)

trainer.fit(model)

assert trainer.current_epoch == (1 if use_infinite_dataset else max_epochs)
assert trainer.global_step == max_steps
assert trainer.global_step == max_opt_steps
assert sorted(list(model.validation_called_at_step)) == [3, 6, 9, 12]


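
One more note on the test change: validation calls are now recorded against `total_batch_idx + 1` rather than `global_step`, because with gradient accumulation the optimizer steps less often than batches arrive, so the two counters diverge. Quick arithmetic for the parametrized case:

```python
data_samples_train = 4
max_epochs = 3
accumulate_grad_batches = 2

total_batches = data_samples_train * max_epochs               # 12 batches seen overall
optimizer_steps = total_batches // accumulate_grad_batches    # only 6 optimizer steps
validation_points = [b for b in range(1, total_batches + 1) if b % 3 == 0]
print(total_batches, optimizer_steps, validation_points)      # 12 6 [3, 6, 9, 12]
```
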