Description & Motivation
📝 Summary
I am trying to train a custom policy module using PyTorch Lightning, where my model components (policy, preprocessor, postprocessor) all inherit from huggingface_hub.PyTorchModelHubMixin.
This mixin provides:
`_save_pretrained()` / `_from_pretrained()`
which work similarly to HuggingFace's `save_pretrained()` and `from_pretrained()`, and are very convenient for packaging model weights + config.
However, I am not sure how to properly integrate these HF-style save/load utilities into Lightning's standard training flow — especially Lightning's ModelCheckpoint callback.
📦 Minimal example
```python
class PolicyModule(LightningModule):
    def __init__(
        self,
        policy: BasePolicy,
        preprocessor: DataProcessorPipeline,
        postprocessor: DataProcessorPipeline,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters(
            logger=False, ignore=["policy", "preprocessor", "postprocessor"]
        )
        self.policy = policy
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor

    def training_step(self, batch, batch_idx):
        batch = self.preprocessor(batch)
        loss, loss_dict = self.policy(batch)
        self.log_dict(
            {"train/_loss": loss},
            on_step=True, on_epoch=True, prog_bar=True, sync_dist=True,
        )
        self.log_dict(
            {f"train/{k}": v for k, v in loss_dict.items()},
            on_step=True, on_epoch=True, sync_dist=True,
        )
        return loss
```
❗ The problem
How can I make Lightning call `save_pretrained()` during checkpointing, so that HF-style artifacts are written alongside the usual `ModelCheckpoint` output?
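One pattern I've been sketching is a small custom callback that hooks `on_save_checkpoint` and calls `save_pretrained()` on each component, writing into a directory next to the Trainer's checkpoint output. The sketch below uses stand-in classes so it runs without `pytorch_lightning` or `huggingface_hub` installed; the hook signature, the `hub_export` directory name, and the attribute names are my assumptions, not a confirmed best practice:

```python
# Sketch of the idea: a callback that exports HF-style artifacts on every
# checkpoint save. FakeComponent/FakeTrainer/FakeModule are stand-ins so
# the sketch is self-contained; in real code HubExportCallback would
# subclass lightning.pytorch.Callback and the components would inherit
# PyTorchModelHubMixin.
import json
import tempfile
from pathlib import Path


class FakeComponent:
    """Stand-in for a module inheriting PyTorchModelHubMixin."""

    def __init__(self, name: str):
        self.name = name

    def save_pretrained(self, save_directory):
        # The real mixin writes weights + config; here we only write a config.
        path = Path(save_directory)
        path.mkdir(parents=True, exist_ok=True)
        (path / "config.json").write_text(json.dumps({"name": self.name}))


class HubExportCallback:
    """Would subclass lightning.pytorch.Callback in real code."""

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # Save HF-style artifacts next to the Trainer's checkpoint dir.
        export_root = Path(trainer.default_root_dir) / "hub_export"
        for attr in ("policy", "preprocessor", "postprocessor"):
            getattr(pl_module, attr).save_pretrained(export_root / attr)


class FakeTrainer:
    def __init__(self, root):
        self.default_root_dir = root


class FakeModule:
    def __init__(self):
        self.policy = FakeComponent("policy")
        self.preprocessor = FakeComponent("pre")
        self.postprocessor = FakeComponent("post")


root = tempfile.mkdtemp()
HubExportCallback().on_save_checkpoint(FakeTrainer(root), FakeModule(), {})
print(sorted(p.name for p in (Path(root) / "hub_export").iterdir()))
# prints ['policy', 'postprocessor', 'preprocessor']
```

I'm unsure whether this is the intended extension point (vs. overriding `LightningModule.on_save_checkpoint` or subclassing `ModelCheckpoint`), which is exactly what I'd like guidance on.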
🙏 Additional context
PyTorch Lightning provides many excellent callback tools—especially ModelCheckpoint—which greatly simplify training workflows. Ideally, I would like to remain fully within the standard Lightning Trainer + callbacks framework, while still taking advantage of HuggingFace-style save_pretrained() / from_pretrained() for model components.
Thanks in advance for any guidance or best practices!
Pitch
No response
Alternatives
No response
Additional context
No response
cc @lantiga