Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions ignite/handlers/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,8 +850,10 @@ def remove(self, filename: str) -> None:


class ModelCheckpoint(Checkpoint):
"""ModelCheckpoint handler can be used to periodically save objects to disk only. If needed to store checkpoints to
"""ModelCheckpoint handler is a :class:`~ignite.handlers.checkpoint.Checkpoint` handler that can be used
to periodically save objects to disk only. If needed to store checkpoints to
another storage type, please consider :class:`~ignite.handlers.checkpoint.Checkpoint`.
It also provides the `last_checkpoint` attribute to show the last saved checkpoint.

This handler expects two arguments:

Expand Down Expand Up @@ -879,7 +881,7 @@ class ModelCheckpoint(Checkpoint):
:class:`~ignite.engine.engine.Engine` object, and return a score (`float`). Objects with highest scores
will be retained.
score_name: if ``score_function`` not None, it is possible to store its value using
`score_name`. See Notes for more details.
`score_name`. See Examples of :class:`~ignite.handlers.checkpoint.Checkpoint` for more details.
n_saved: Number of objects that should be kept on disk. Older files will be removed. If set to
`None`, all objects are kept.
atomic: If True, objects are serialized to a temporary file, and then moved to final
Expand Down Expand Up @@ -909,21 +911,24 @@ class ModelCheckpoint(Checkpoint):
with :class:`~ignite.handlers.checkpoint.Checkpoint`

Examples:
.. code-block:: python
.. testcode:: python

import os
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint
from torch import nn
trainer = Engine(lambda engine, batch: None)
handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True)
handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True, require_empty=False)
model = nn.Linear(3, 3)
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})
trainer.run([0, 1, 2, 3, 4], max_epochs=6)
os.listdir('/tmp/models')
# ['myprefix_mymodel_20.pt', 'myprefix_mymodel_30.pt']
handler.last_checkpoint
# ['/tmp/models/myprefix_mymodel_30.pt']
print(sorted(os.listdir('/tmp/models')))
print(handler.last_checkpoint)

.. testoutput:: python

['myprefix_mymodel_20.pt', 'myprefix_mymodel_30.pt']
/tmp/models/myprefix_mymodel_30.pt
"""

def __init__(
Expand Down