diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py
index ad45dc63e7fb..2f7e6843149f 100644
--- a/ignite/handlers/checkpoint.py
+++ b/ignite/handlers/checkpoint.py
@@ -507,14 +507,14 @@ def _check_objects(objs: Mapping, attr: str) -> None:
                 raise TypeError(f"Object {type(obj)} should have `{attr}` method")
 
     @staticmethod
-    def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
+    def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: Any) -> None:
         """Helper method to apply ``load_state_dict`` on the objects from ``to_load`` using states from ``checkpoint``.
 
         Args:
             to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
-            checkpoint: a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
-                "optimizer": opt_state_dict}`. If `to_load` contains a single key, then checkpoint can contain directly
-                corresponding state_dict.
+            checkpoint: a string filepath or a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
+                "optimizer": opt_state_dict}`. If `to_load` contains a single key, then checkpoint can contain
+                directly corresponding state_dict.
             kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
                 the user to load part of the pretrained model (useful for example, in Transfer Learning)
 
@@ -537,6 +537,12 @@ def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
                 checkpoint = torch.load(checkpoint_fp)
                 Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
 
+                # or using a string for checkpoint filepath
+
+                to_load = to_save
+                checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
+                Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)
+
         Note:
             If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
             `DataParallel`_, method ``load_state_dict`` will applied to their internal wrapped model (``obj.module``).
@@ -544,11 +550,16 @@ def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
         .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
             torch.nn.parallel.DistributedDataParallel.html
         .. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
-        """
+        """
+        if isinstance(checkpoint, str):
+            checkpoint_obj = torch.load(checkpoint)
+        else:
+            checkpoint_obj = checkpoint
+
         Checkpoint._check_objects(to_load, "load_state_dict")
-        if not isinstance(checkpoint, collections.Mapping):
-            raise TypeError(f"Argument checkpoint should be a dictionary, but given {type(checkpoint)}")
+        if not isinstance(checkpoint, (collections.Mapping, str)):
+            raise TypeError(f"Argument checkpoint should be a string or a dictionary, but given {type(checkpoint)}")
 
         if len(kwargs) > 1 or any(k for k in kwargs if k not in ["strict"]):
             warnings.warn("kwargs contains keys other than strict and these will be ignored")
 
@@ -557,22 +568,22 @@ def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
         if len(to_load) == 1:
             # single object and checkpoint is directly a state_dict
             key, obj = list(to_load.items())[0]
-            if key not in checkpoint:
+            if key not in checkpoint_obj:
                 if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                     obj = obj.module
-                obj.load_state_dict(checkpoint, strict=is_state_dict_strict)
+                obj.load_state_dict(checkpoint_obj, strict=is_state_dict_strict)
                 return
 
         # multiple objects to load
         for k, obj in to_load.items():
-            if k not in checkpoint:
+            if k not in checkpoint_obj:
                 raise ValueError(f"Object labeled by '{k}' from `to_load` is not found in the checkpoint")
             if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                 obj = obj.module
             if isinstance(obj, torch.nn.Module):
-                obj.load_state_dict(checkpoint[k], strict=is_state_dict_strict)
+                obj.load_state_dict(checkpoint_obj[k], strict=is_state_dict_strict)
             else:
-                obj.load_state_dict(checkpoint[k])
+                obj.load_state_dict(checkpoint_obj[k])
 
     def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
         """Method returns state dict with saved items: list of ``(priority, filename)`` pairs.
diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py
index 39863255d8b8..7897da355f0b 100644
--- a/tests/ignite/handlers/test_checkpoint.py
+++ b/tests/ignite/handlers/test_checkpoint.py
@@ -1070,7 +1070,7 @@ def test_save_model_optimizer_lr_scheduler_with_validation(dirname):
 
 
 def test_checkpoint_load_objects():
-    with pytest.raises(TypeError, match=r"Argument checkpoint should be a dictionary"):
+    with pytest.raises(TypeError, match=r"Argument checkpoint should be a string or a dictionary"):
         Checkpoint.load_objects({}, [])
 
     with pytest.raises(TypeError, match=r"should have `load_state_dict` method"):
@@ -1107,6 +1107,17 @@ def _get_multiple_objs_to_save():
     trainer = Engine(lambda e, b: None)
     trainer.state = State(epoch=0, iteration=0)
 
+    # case: load from filepath
+    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
+    to_save = _get_multiple_objs_to_save()
+    handler(trainer, to_save)
+    fname = handler.last_checkpoint
+    assert isinstance(fname, str)
+    assert os.path.join(dirname, _PREFIX) in fname
+    assert os.path.exists(fname)
+    Checkpoint.load_objects(to_save, fname)
+    os.remove(fname)
+
     # case: multiple objects
     handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
     to_save = _get_multiple_objs_to_save()
@@ -1142,6 +1153,7 @@ def _get_multiple_objs_to_save():
     assert os.path.exists(fname)
     loaded_objects = torch.load(fname)
     Checkpoint.load_objects(to_save, loaded_objects)
+    os.remove(fname)
 
 
 def test_load_checkpoint_with_different_num_classes(dirname):
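For reviewers, a minimal end-to-end sketch of the usage this change enables. It is illustrative only: the model, optimizer, directory (`/tmp/models`), and prefix (`myprefix`) are made-up placeholders, and the snippet assumes only the public `Engine`/`ModelCheckpoint`/`Checkpoint` APIs exercised by the diff above.

```python
import torch
import torch.nn as nn

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, ModelCheckpoint

# Placeholder objects; anything exposing state_dict/load_state_dict works.
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
to_save = {"model": model, "optimizer": optimizer}

# The process function is a no-op; we only want a checkpoint file on disk.
trainer = Engine(lambda engine, batch: None)
handler = ModelCheckpoint("/tmp/models", "myprefix", n_saved=1, require_empty=False)
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, to_save)
trainer.run([0] * 8, max_epochs=5)

# Previously the caller had to deserialize first:
#   checkpoint = torch.load(handler.last_checkpoint)
#   Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)
# With this change, the filepath can be passed directly and torch.load
# is invoked inside load_objects.
Checkpoint.load_objects(to_load=to_save, checkpoint=handler.last_checkpoint)
```

Note the design choice visible in the diff: the original `checkpoint` argument is still what the `isinstance` type check inspects, while the loaded `checkpoint_obj` is what the state dicts are read from, so dict callers see exactly the old behavior.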