diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py
index 010ecdf17cd0..b68dece68d31 100644
--- a/ignite/handlers/checkpoint.py
+++ b/ignite/handlers/checkpoint.py
@@ -327,6 +327,18 @@ def __init__(
         self.include_self = include_self
         self.greater_or_equal = greater_or_equal
 
+    def _get_filename_pattern(self, global_step: Optional[int]) -> str:
+        if self.filename_pattern is None:
+            filename_pattern = self.setup_filename_pattern(
+                with_prefix=len(self.filename_prefix) > 0,
+                with_score=self.score_function is not None,
+                with_score_name=self.score_name is not None,
+                with_global_step=global_step is not None,
+            )
+        else:
+            filename_pattern = self.filename_pattern
+        return filename_pattern
+
     def reset(self) -> None:
         """Method to reset saved checkpoint names.
 
@@ -402,15 +414,7 @@ def __call__(self, engine: Engine) -> None:
                     name = k
                 checkpoint = checkpoint[name]
 
-        if self.filename_pattern is None:
-            filename_pattern = self.setup_filename_pattern(
-                with_prefix=len(self.filename_prefix) > 0,
-                with_score=self.score_function is not None,
-                with_score_name=self.score_name is not None,
-                with_global_step=global_step is not None,
-            )
-        else:
-            filename_pattern = self.filename_pattern
+        filename_pattern = self._get_filename_pattern(global_step)
 
         filename_dict = {
             "filename_prefix": self.filename_prefix,
@@ -519,41 +523,51 @@ def _check_objects(objs: Mapping, attr: str) -> None:
                 raise TypeError(f"Object {type(obj)} should have `{attr}` method")
 
     @staticmethod
-    def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: Any) -> None:
+    def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping, Path], **kwargs: Any) -> None:
         """Helper method to apply ``load_state_dict`` on the objects from ``to_load`` using states from ``checkpoint``.
 
         Args:
             to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
-            checkpoint: a string filepath or a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
-                "optimizer": opt_state_dict}`. If `to_load` contains a single key, then checkpoint can contain
-                directly corresponding state_dict.
+            checkpoint: a path, a string filepath or a dictionary with state_dicts to load, e.g.
+                `{"model": model_state_dict, "optimizer": opt_state_dict}`. If `to_load` contains a single key,
+                then checkpoint can contain directly corresponding state_dict.
             kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
                 the user to load part of the pretrained model (useful for example, in Transfer Learning)
 
         Examples:
             .. code-block:: python
 
+                import tempfile
+                from pathlib import Path
+
                 import torch
+
                 from ignite.engine import Engine, Events
                 from ignite.handlers import ModelCheckpoint, Checkpoint
+
                 trainer = Engine(lambda engine, batch: None)
 
-                handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
-                model = torch.nn.Linear(3, 3)
-                optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
-                to_save = {"weights": model, "optimizer": optimizer}
-                trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
-                trainer.run(torch.randn(10, 1), 5)
-                to_load = to_save
-                checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
-                checkpoint = torch.load(checkpoint_fp)
-                Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    handler = ModelCheckpoint(tmpdirname, 'myprefix', n_saved=None, create_dir=True)
+
+                    model = torch.nn.Linear(3, 3)
+                    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+
+                    to_save = {"weights": model, "optimizer": optimizer}
+
+                    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
+                    trainer.run(torch.randn(10, 1), 5)
+
+                    to_load = to_save
+                    checkpoint_fp = Path(tmpdirname) / 'myprefix_checkpoint_40.pt'
+                    checkpoint = torch.load(checkpoint_fp)
+                    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
 
-                # or using a string for checkpoint filepath
+                    # or using a string for checkpoint filepath
 
-                to_load = to_save
-                checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
-                Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)
+                    to_load = to_save
+                    checkpoint_fp = Path(tmpdirname) / 'myprefix_checkpoint_40.pt'
+                    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)
 
         Note:
             If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
@@ -564,13 +578,13 @@ def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: An
 
         .. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
         """
-        if isinstance(checkpoint, str):
+        if isinstance(checkpoint, (str, Path)):
             checkpoint_obj = torch.load(checkpoint)
         else:
             checkpoint_obj = checkpoint
 
         Checkpoint._check_objects(to_load, "load_state_dict")
-        if not isinstance(checkpoint, (collections.Mapping, str)):
+        if not isinstance(checkpoint, (collections.Mapping, str, Path)):
             raise TypeError(f"Argument checkpoint should be a string or a dictionary, but given {type(checkpoint)}")
 
         if len(kwargs) > 1 or any(k for k in kwargs if k not in ["strict"]):
@@ -599,6 +613,82 @@ def _load_object(obj: Any, chkpt_obj: Any) -> None:
                     raise ValueError(f"Object labeled by '{k}' from `to_load` is not found in the checkpoint")
                 _load_object(obj, checkpoint_obj[k])
 
+    def reload_objects(self, to_load: Mapping, load_kwargs: Optional[Dict] = None, **filename_components: Any) -> None:
+        """Helper method to apply ``load_state_dict`` on the objects from ``to_load``. Filename components such as
+        name, score and global step can be configured.
+
+        Args:
+            to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
+            load_kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
+                the user to load part of the pretrained model (useful for example, in Transfer Learning)
+            filename_components: Filename components used to define the checkpoint file path.
+                Keyword arguments accepted are `name`, `score` and `global_step`.
+
+        Examples:
+            .. code-block:: python
+
+                import tempfile
+
+                import torch
+
+                from ignite.engine import Engine, Events
+                from ignite.handlers import ModelCheckpoint, Checkpoint
+
+                trainer = Engine(lambda engine, batch: None)
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    checkpoint = ModelCheckpoint(tmpdirname, 'myprefix', n_saved=None, create_dir=True)
+
+                    model = torch.nn.Linear(3, 3)
+                    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+
+                    to_save = {"weights": model, "optimizer": optimizer}
+
+                    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), checkpoint, to_save)
+                    trainer.run(torch.randn(10, 1), 5)
+
+                    to_load = to_save
+                    # load checkpoint myprefix_checkpoint_40.pt
+                    checkpoint.reload_objects(to_load=to_load, global_step=40)
+
+        Note:
+            If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
+            `DataParallel`_, method ``load_state_dict`` will be applied to their internal wrapped model (``obj.module``).
+
+        .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
+            torch.nn.parallel.DistributedDataParallel.html
+        .. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
+        """
+
+        global_step = filename_components.get("global_step", None)
+
+        filename_pattern = self._get_filename_pattern(global_step)
+
+        checkpoint = self._setup_checkpoint()
+        name = "checkpoint"
+        if len(checkpoint) == 1:
+            for k in checkpoint:
+                name = k
+        name = filename_components.get("name", name)
+        score = filename_components.get("score", None)
+
+        filename_dict = {
+            "filename_prefix": self.filename_prefix,
+            "ext": self.ext,
+            "name": name,
+            "score_name": self.score_name,
+            "score": score,
+            "global_step": global_step,
+        }
+
+        checkpoint_fp = filename_pattern.format(**filename_dict)
+
+        path = self.save_handler.dirname / checkpoint_fp
+
+        load_kwargs = {} if load_kwargs is None else load_kwargs
+
+        Checkpoint.load_objects(to_load=to_load, checkpoint=path, **load_kwargs)
+
     def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
         """Method returns state dict with saved items: list of ``(priority, filename)`` pairs.
 
         Can be used to save internal state of the class.
diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py
index fa2570e605c2..9a751ecdab3e 100644
--- a/tests/ignite/handlers/test_checkpoint.py
+++ b/tests/ignite/handlers/test_checkpoint.py
@@ -576,6 +576,9 @@ def test_model_checkpoint_simple_recovery(dirname):
     assert fname.exists()
     loaded_objects = torch.load(fname)
     assert loaded_objects == model.state_dict()
+    to_load = {"model": DummyModel()}
+    h.reload_objects(to_load=to_load, global_step=1)
+    assert to_load["model"].state_dict() == model.state_dict()
 
 
 def test_model_checkpoint_simple_recovery_from_existing_non_empty(dirname):
@@ -600,6 +603,9 @@ def _test(ext, require_empty):
         assert previous_fname.exists()
         loaded_objects = torch.load(fname)
         assert loaded_objects == model.state_dict()
+        to_load = {"model": DummyModel()}
+        h.reload_objects(to_load=to_load, global_step=1)
+        assert to_load["model"].state_dict() == model.state_dict()
         fname.unlink()
 
     _test(".txt", require_empty=True)
@@ -1118,6 +1124,7 @@ def _get_multiple_objs_to_save():
     assert str(dirname / _PREFIX) in str(fname)
     assert fname.exists()
     Checkpoint.load_objects(to_save, str(fname))
+    Checkpoint.load_objects(to_save, fname)
     fname.unlink()
 
     # case: multiple objects
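Usage notes for reviewers, not part of the patch. First, a minimal sketch of what the widened `checkpoint: Union[str, Mapping, Path]` signature of `Checkpoint.load_objects` allows; the file name `myprefix_model_20.pt` is an assumption of this sketch, following from the default filename pattern with a single-key `to_save` and 2 epochs of 10 iterations:

```python
import tempfile
from pathlib import Path

import torch

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, ModelCheckpoint

trainer = Engine(lambda engine, batch: None)

with tempfile.TemporaryDirectory() as tmpdirname:
    handler = ModelCheckpoint(tmpdirname, "myprefix", n_saved=None, create_dir=True)
    model = torch.nn.Linear(3, 3)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {"model": model})
    trainer.run(torch.randn(10, 1), max_epochs=2)

    # A pathlib.Path is now accepted directly and forwarded to torch.load(),
    # exactly like a string filepath (assumed file name: iteration 20 = epoch 2).
    checkpoint_fp = Path(tmpdirname) / "myprefix_model_20.pt"
    Checkpoint.load_objects(to_load={"model": torch.nn.Linear(3, 3)}, checkpoint=checkpoint_fp)
```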
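Second, a sketch of the `reload_objects` round trip combined with its `load_kwargs` passthrough; the fresh `torch.nn.Linear` target and the `strict=False` setting are illustrative assumptions, not taken from the patch:

```python
import tempfile

import torch

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

trainer = Engine(lambda engine, batch: None)

with tempfile.TemporaryDirectory() as tmpdirname:
    handler = ModelCheckpoint(tmpdirname, "myprefix", n_saved=None, create_dir=True)
    model = torch.nn.Linear(3, 3)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {"model": model})
    trainer.run(torch.randn(10, 1), max_epochs=2)

    # reload_objects rebuilds the file name from the handler's own filename
    # pattern, so only the varying components (here global_step) are passed.
    # load_kwargs is forwarded to load_state_dict(); strict=False tolerates
    # partially matching state dicts (e.g. for transfer learning).
    to_load = {"model": torch.nn.Linear(3, 3)}
    handler.reload_objects(to_load=to_load, load_kwargs={"strict": False}, global_step=20)

    assert to_load["model"].state_dict().keys() == model.state_dict().keys()
```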