
Hook to transform the model before loading the weights #14447

@awaelchli

Description


🚀 Feature

Motivation

This is a request from a user on Slack.
In their use case, they need to transform the model early in the trainer execution: after the checkpoint file has been loaded, but before its weights get copied into the model (when running trainer.fit(model, ckpt_path=...)). Currently, this is only possible as a special case inside the QuantizationAwareTraining callback.

Pitch

Provide a hook that runs after the checkpoint has been loaded from the file but before its state is restored into the model, i.e., in the sequence below the hook should run roughly where self._checkpoint_connector._restore_quantization_callbacks() runs.

https://github.com/Lightning-AI/lightning/blob/291267c3bff8054ec438960857c9f2fec1d54899/src/pytorch_lightning/trainer/trainer.py#L1071-L1079
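For context, a simplified sketch of that restore sequence with the proposed hook call marked. The surrounding calls are paraphrased from the linked source; the placement of the new hook is the proposal, not existing code:

# inside Trainer._restore_modules_and_callbacks (simplified paraphrase)
self._checkpoint_connector.resume_start(checkpoint_path)  # checkpoint dict is read from the file here
# <-- proposed: invoke on_resume_start(lightning_module, trainer, loaded_checkpoint) here
self._checkpoint_connector._restore_quantization_callbacks()
self._checkpoint_connector.restore_model()  # weights are copied into the module here
self._checkpoint_connector.restore_datamodule()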

The hook should take the checkpoint dict as input so that users can read their metadata from it.

def on_resume_start(self, lightning_module, trainer, checkpoint):
    """Do something with the model before restoration of the trainer/model state"""

Alternatives

One could make the QuantizationAwareTraining._load_before_model hook public, but that would remain limited to the quantization callback rather than being a general-purpose hook.

Additional context

Slack conversation



cc @Borda @tchaton @justusschock @awaelchli @carmocca @ananthsub @ninginthecloud @jjenniferdai @rohitgr7 @akihironitta
