🚀 Feature
Create a ModelCheckpointBase callback, and have the existing checkpoint callback extend it
Motivation
The model checkpoint callback is growing in complexity. Features that have recently been added or will soon be proposed:
- Checkpoint every n train steps: [feat] Support iteration-based checkpointing in model checkpoint callback #6146
- Checkpointing using a time-based interval during training: Support time-based checkpointing trigger #6286
- Checkpointing at the end of training, to fix this hack:
https://github.com/PyTorchLightning/pytorch-lightning/blob/680e83adab38c2d680b138bdc39d48fc35c0cb58/pytorch_lightning/trainer/training_loop.py#L152-L163
The decision was made in #6146 to keep these triggers mutually exclusive, at least by the phase they run in. Why? Because it is very hard to get the state management right. For instance, the monitor might be set to a metric that is only available during validation; if the checkpoint callback is also configured to run during training, it crashes when it tries to look up the monitor key in the available metrics. Tracking the top-K models and scores is another huge pain, and supporting multiple monitor metrics on top of this is another beast entirely.
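A minimal configuration that illustrates the conflict, assuming the train-step trigger from #6146 lands as an `every_n_train_steps` argument (the argument name is an assumption here; everything else is the existing `ModelCheckpoint` API):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# "val_loss" is only logged from validation_step, so it only appears in
# trainer.callback_metrics after a validation epoch has run. With a
# train-phase trigger, the callback fires at a point where that key does not
# exist yet, and the top-K bookkeeping has nothing to compare against.
checkpoint = ModelCheckpoint(
    monitor="val_loss",        # produced only during validation
    every_n_train_steps=100,   # but saving is triggered during training
    save_top_k=3,
)
```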
cc @Borda @carmocca @awaelchli @ninginthecloud @jjenniferdai @rohitgr7
Pitch
Move the existing logic for the following into a base class:
- Core saving functionality
- Top-K model management
- Formatting checkpoint names
- Validation (though subclasses should override this)
Then have thin wrappers on top that extend this class and implement the callback hook(s) that decide when to save a checkpoint.
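A rough sketch of how the split could look. The class names, constructor arguments, and helper methods below are hypothetical placeholders; only the `Callback` hooks, `trainer.callback_metrics`, and `trainer.save_checkpoint` come from the existing Lightning API.

```python
import os
from typing import Dict, Optional

from pytorch_lightning.callbacks import Callback


class ModelCheckpointBase(Callback):
    """Phase-agnostic pieces: core saving, top-K bookkeeping, name formatting."""

    def __init__(self, dirpath: str, monitor: Optional[str] = None, save_top_k: int = 1):
        self.dirpath = dirpath
        self.monitor = monitor
        self.save_top_k = save_top_k
        self.best_k_models: Dict[str, float] = {}  # filepath -> monitored score

    def format_checkpoint_name(self, trainer) -> str:
        # Simplified; the real logic would keep the existing format arguments.
        return os.path.join(self.dirpath, f"step={trainer.global_step}.ckpt")

    def save(self, trainer) -> None:
        # Core saving + top-K management, shared by every trigger.
        filepath = self.format_checkpoint_name(trainer)
        trainer.save_checkpoint(filepath)
        if self.monitor is not None:
            score = trainer.callback_metrics.get(self.monitor)
            self._update_top_k(filepath, score)

    def _update_top_k(self, filepath: str, score) -> None:
        ...  # track best_k_models, drop the worst checkpoint once len > save_top_k


class TrainStepCheckpoint(ModelCheckpointBase):
    """Thin wrapper: decides *when* to save; the base decides *how*."""

    def __init__(self, every_n_train_steps: int, **kwargs):
        super().__init__(**kwargs)
        self.every_n_train_steps = every_n_train_steps

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        if trainer.global_step % self.every_n_train_steps == 0:
            self.save(trainer)


class ValidationEndCheckpoint(ModelCheckpointBase):
    def on_validation_end(self, trainer, pl_module):
        self.save(trainer)
```

Each trigger then only owns its scheduling logic and its own validation of constructor arguments, while saving and top-K state live in one place.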
Alternatives
Keep everything in the single checkpoint callback, which gets bigger and bigger as we add more features to it.