🚀 Feature
Motivation
This is a request from a user on Slack.
In their use case, they need to transform the model early in the trainer execution: after the checkpoint has been loaded, but before the weights get copied into the model (when calling trainer.fit(model, ckpt_path=...)). This is currently only possible in the special case of the QuantizationAwareTraining callback.
Pitch
Provide a hook that runs before the model state gets restored, but after the checkpoint has been loaded from the file, i.e., in the trainer's restoration sequence the hook should run roughly where self._checkpoint_connector._restore_quantization_callbacks() runs.
The hook should take as input the checkpoint dict, so that the user can load their metadata.
def on_resume_start(self, lightning_module, trainer, checkpoint):
    """Do something with the model before the restoration of the trainer/model state."""
Alternatives
One could make the QuantizationAwareTraining._load_before_model hook public, but this would be limited to uses of the quantization callback only.
Additional context
If you enjoy Lightning, check out our other projects! ⚡
- Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
- Lite: enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.
- Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.
- Bolts: Pretrained SOTA Deep Learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.
- Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.
cc @Borda @tchaton @justusschock @awaelchli @carmocca @ananthsub @ninginthecloud @jjenniferdai @rohitgr7 @akihironitta