
Commit 948fbb7

Authored by ericspod with pre-commit-ci[bot] and KumoLiu
Torch and Pickle Safe Load Fixes (#8566)
Fixes #8355.

### Description

This modifies the use of `torch.load` to load weights-only everywhere. This changes some behaviour in that data needs to be converted to tensors where possible, i.e. converting Numpy arrays to tensors. This issue was reported by @h3rrr in GHSA-6vm5-6jv9-rjpj.

This also replaces uses of `pickle` for saving and loading with `torch.save`/`torch.load`. This likewise requires that Numpy arrays be converted, but otherwise relies on the weights-only restriction to enforce safe loading. This issue was reported by @h3rrr in GHSA-p8cm-mm2v-gwjm.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Eric Kerfoot <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <[email protected]>
1 parent 401ea4a · commit 948fbb7
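The pattern applied throughout the diff, shown as a minimal standalone sketch (the file name is illustrative, and `torch.as_tensor` stands in for the `monai.utils.convert_to_tensor` helper the MONAI code itself uses):

```python
import numpy as np
import torch

item = {"image": np.zeros((4, 4), dtype=np.float32), "label": 1}

# Numpy arrays cannot be reconstructed by the weights-only unpickler,
# so convert them to tensors before saving.
safe_item = {k: torch.as_tensor(v) if isinstance(v, np.ndarray) else v for k, v in item.items()}
torch.save(safe_item, "cache_item.pt")

# weights_only=True restricts unpickling to tensors, primitives, and
# containers of these, preventing arbitrary code execution on load.
restored = torch.load("cache_item.pt", weights_only=True)
```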

File tree

12 files changed: +95 −77 lines changed

monai/apps/nnunet/nnunet_bundle.py

Lines changed: 17 additions & 9 deletions
@@ -133,7 +133,7 @@ def get_nnunet_trainer(
     cudnn.benchmark = True
 
     if pretrained_model is not None:
-        state_dict = torch.load(pretrained_model)
+        state_dict = torch.load(pretrained_model, weights_only=True)
         if "network_weights" in state_dict:
             nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"])
     return nnunet_trainer
@@ -182,7 +182,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
         parameters = []
 
         checkpoint = torch.load(
-            join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
+            join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"),
+            map_location=torch.device("cpu"),
+            weights_only=True,
         )
         trainer_name = checkpoint["trainer_name"]
         configuration_name = checkpoint["init_args"]["configuration"]
@@ -192,7 +194,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
             else None
         )
         if Path(model_training_output_dir).joinpath(model_name).is_file():
-            monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu"))
+            monai_checkpoint = torch.load(
+                join(model_training_output_dir, model_name), map_location=torch.device("cpu"), weights_only=True
+            )
             if "network_weights" in monai_checkpoint.keys():
                 parameters.append(monai_checkpoint["network_weights"])
             else:
@@ -383,8 +387,12 @@ def convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str,
         dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}"
     )
 
-    nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"))
-    nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"))
+    nnunet_checkpoint_final = torch.load(
+        Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"), weights_only=True
+    )
+    nnunet_checkpoint_best = torch.load(
+        Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"), weights_only=True
+    )
 
     nnunet_checkpoint = {}
     nnunet_checkpoint["inference_allowed_mirroring_axes"] = nnunet_checkpoint_final["inference_allowed_mirroring_axes"]
@@ -470,7 +478,7 @@ def get_network_from_nnunet_plans(
     if model_ckpt is None:
         return network
     else:
-        state_dict = torch.load(model_ckpt)
+        state_dict = torch.load(model_ckpt, weights_only=True)
         network.load_state_dict(state_dict[model_key_in_ckpt])
         return network
@@ -534,7 +542,7 @@ def subfiles(
 
     Path(nnunet_model_folder).joinpath(f"fold_{fold}").mkdir(parents=True, exist_ok=True)
 
-    nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth")
+    nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth", weights_only=True)
     latest_checkpoints: list[str] = subfiles(
         Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True
    )
@@ -545,7 +553,7 @@ def subfiles(
     epochs.sort()
     final_epoch: int = epochs[-1]
     monai_last_checkpoint: dict = torch.load(
-        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt"
+        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt", weights_only=True
     )
 
     best_checkpoints: list[str] = subfiles(
@@ -558,7 +566,7 @@ def subfiles(
     key_metrics.sort()
     best_key_metric: str = key_metrics[-1]
     monai_best_checkpoint: dict = torch.load(
-        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt"
+        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt", weights_only=True
     )
 
     nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"]

monai/data/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -78,7 +78,6 @@
 from .thread_buffer import ThreadBuffer, ThreadDataLoader
 from .torchscript_utils import load_net_with_metadata, save_net_with_metadata
 from .utils import (
-    PICKLE_KEY_SUFFIX,
     affine_to_spacing,
     compute_importance_map,
     compute_shape_offset,

monai/data/dataset.py

Lines changed: 34 additions & 16 deletions
@@ -13,7 +13,6 @@
 
 import collections.abc
 import math
-import pickle
 import shutil
 import sys
 import tempfile
@@ -22,9 +21,11 @@
 import warnings
 from collections.abc import Callable, Sequence
 from copy import copy, deepcopy
+from io import BytesIO
 from multiprocessing.managers import ListProxy
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
+from pickle import UnpicklingError
 from typing import IO, TYPE_CHECKING, Any, cast
 
 import numpy as np
@@ -207,6 +208,11 @@ class PersistentDataset(Dataset):
         not guaranteed, so caution should be used when modifying transforms to avoid unexpected
         errors. If in doubt, it is advisable to clear the cache directory.
 
+    Cached data is expected to be tensors, primitives, or dictionaries keying to these values. Numpy arrays will
+    be converted to tensors, however any other object type returned by transforms will not be loadable since
+    `torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects.
+    Legacy cache files may not be loadable and may need to be recomputed.
+
     Lazy Resampling:
         If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to
         its documentation to familiarize yourself with the interaction between `PersistentDataset` and
@@ -248,8 +254,8 @@ def __init__(
                 this arg is used by `torch.save`, for more details, please check:
                 https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
                 and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
-            pickle_protocol: can be specified to override the default protocol, default to `2`.
-                this arg is used by `torch.save`, for more details, please check:
+            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
+                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
                 https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
             hash_transform: a callable to compute hash from the transform information when caching.
                 This may reduce errors due to transforms changing during experiments. Default to None (no hash).
@@ -371,12 +377,12 @@ def _cachecheck(self, item_transformed):
 
         if hashfile is not None and hashfile.is_file():  # cache hit
             try:
-                return torch.load(hashfile, weights_only=False)
+                return torch.load(hashfile, weights_only=True)
             except PermissionError as e:
                 if sys.platform != "win32":
                     raise e
-            except RuntimeError as e:
-                if "Invalid magic number; corrupt file" in str(e):
+            except (UnpicklingError, RuntimeError) as e:  # corrupt or unloadable cached files are recomputed
+                if "Invalid magic number; corrupt file" in str(e) or isinstance(e, UnpicklingError):
                     warnings.warn(f"Corrupt cache file detected: {hashfile}. Deleting and recomputing.")
                     hashfile.unlink()
                 else:
@@ -392,7 +398,7 @@ def _cachecheck(self, item_transformed):
         with tempfile.TemporaryDirectory() as tmpdirname:
             temp_hash_file = Path(tmpdirname) / hashfile.name
             torch.save(
-                obj=_item_transformed,
+                obj=convert_to_tensor(_item_transformed, convert_numeric=False),
                 f=temp_hash_file,
                 pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                 pickle_protocol=self.pickle_protocol,
@@ -455,8 +461,8 @@ def __init__(
                 this arg is used by `torch.save`, for more details, please check:
                 https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
                 and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
-            pickle_protocol: can be specified to override the default protocol, default to `2`.
-                this arg is used by `torch.save`, for more details, please check:
+            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
+                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
                 https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
             hash_transform: a callable to compute hash from the transform information when caching.
                 This may reduce errors due to transforms changing during experiments. Default to None (no hash).
@@ -531,7 +537,7 @@ def __init__(
         hash_func: Callable[..., bytes] = pickle_hashing,
         db_name: str = "monai_cache",
         progress: bool = True,
-        pickle_protocol=pickle.HIGHEST_PROTOCOL,
+        pickle_protocol=DEFAULT_PROTOCOL,
         hash_transform: Callable[..., bytes] | None = None,
         reset_ops_id: bool = True,
         lmdb_kwargs: dict | None = None,
@@ -551,8 +557,9 @@ def __init__(
                 defaults to `monai.data.utils.pickle_hashing`.
             db_name: lmdb database file name. Defaults to "monai_cache".
             progress: whether to display a progress bar.
-            pickle_protocol: pickle protocol version. Defaults to pickle.HIGHEST_PROTOCOL.
-                https://docs.python.org/3/library/pickle.html#pickle-protocols
+            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
+                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
+                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
             hash_transform: a callable to compute hash from the transform information when caching.
                 This may reduce errors due to transforms changing during experiments. Default to None (no hash).
                 Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
@@ -594,6 +601,15 @@ def set_data(self, data: Sequence):
         super().set_data(data=data)
         self._read_env = self._fill_cache_start_reader(show_progress=self.progress)
 
+    def _safe_serialize(self, val):
+        out = BytesIO()
+        torch.save(convert_to_tensor(val), out, pickle_protocol=self.pickle_protocol)
+        out.seek(0)
+        return out.read()
+
+    def _safe_deserialize(self, val):
+        return torch.load(BytesIO(val), map_location="cpu", weights_only=True)
+
     def _fill_cache_start_reader(self, show_progress=True):
         """
         Check the LMDB cache and write the cache if needed. py-lmdb doesn't have a good support for concurrent write.
@@ -619,7 +635,8 @@ def _fill_cache_start_reader(self, show_progress=True):
                     continue
                 if val is None:
                     val = self._pre_transform(deepcopy(item))  # keep the original hashed
-                    val = pickle.dumps(val, protocol=self.pickle_protocol)
+                    # val = pickle.dumps(val, protocol=self.pickle_protocol)
+                    val = self._safe_serialize(val)
                 with env.begin(write=True) as txn:
                     txn.put(key, val)
                 done = True
@@ -664,7 +681,8 @@ def _cachecheck(self, item_transformed):
             warnings.warn("LMDBDataset: cache key not found, running fallback caching.")
             return super()._cachecheck(item_transformed)
         try:
-            return pickle.loads(data)
+            # return pickle.loads(data)
+            return self._safe_deserialize(data)
         except Exception as err:
             raise RuntimeError("Invalid cache value, corrupted lmdb file?") from err
@@ -1650,7 +1668,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
             meta_hash_file = self.cache_dir / meta_hash_file_name
             temp_hash_file = Path(tmpdirname) / meta_hash_file_name
             torch.save(
-                obj=self._meta_cache[meta_hash_file_name],
+                obj=convert_to_tensor(self._meta_cache[meta_hash_file_name], convert_numeric=False),
                 f=temp_hash_file,
                 pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                 pickle_protocol=self.pickle_protocol,
@@ -1670,4 +1688,4 @@ def _load_meta_cache(self, meta_hash_file_name):
         if meta_hash_file_name in self._meta_cache:
             return self._meta_cache[meta_hash_file_name]
         else:
-            return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
+            return torch.load(self.cache_dir / meta_hash_file_name, weights_only=True)
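A standalone sketch of the serialise/deserialise round trip that `LMDBDataset` now performs via the `_safe_serialize`/`_safe_deserialize` methods above, with `torch.as_tensor` standing in for MONAI's `convert_to_tensor`:

```python
from io import BytesIO

import numpy as np
import torch

def safe_serialize(val) -> bytes:
    # Mirrors _safe_serialize: tensors survive the weights-only
    # unpickler, raw Numpy arrays would not.
    buf = BytesIO()
    torch.save(torch.as_tensor(val), buf)
    return buf.getvalue()

def safe_deserialize(raw: bytes):
    # Mirrors _safe_deserialize: refuses anything but tensors/primitives.
    return torch.load(BytesIO(raw), map_location="cpu", weights_only=True)

restored = safe_deserialize(safe_serialize(np.arange(6).reshape(2, 3)))
assert torch.equal(restored, torch.arange(6).reshape(2, 3))
```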

monai/data/meta_tensor.py

Lines changed: 1 addition & 1 deletion
@@ -611,4 +611,4 @@ def print_verbose(self) -> None:
 
 # needed in later versions of Pytorch to indicate the class is safe for serialisation
 if hasattr(torch.serialization, "add_safe_globals"):
-    torch.serialization.add_safe_globals([MetaTensor])
+    torch.serialization.add_safe_globals([MetaObj, MetaTensor, MetaKeys, SpaceKeys])
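`add_safe_globals` allow-lists classes for the weights-only unpickler, which is why all the MONAI metadata types above must be registered for `MetaTensor` checkpoints to load under `weights_only=True`. A sketch of the same mechanism with a hypothetical class:

```python
import torch

class RunInfo:
    """Hypothetical metadata object stored alongside weights."""

    def __init__(self, epoch: int = 0):
        self.epoch = epoch

# Guarded as in meta_tensor.py, since add_safe_globals only exists in
# newer PyTorch releases; registered classes become reconstructible
# under weights_only=True.
if hasattr(torch.serialization, "add_safe_globals"):
    torch.serialization.add_safe_globals([RunInfo])
```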

monai/data/utils.py

Lines changed: 10 additions & 36 deletions
@@ -30,7 +30,6 @@
 import torch
 from torch.utils.data._utils.collate import default_collate
 
-from monai import config
 from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike
 from monai.data.meta_obj import MetaObj
 from monai.utils import (
@@ -93,7 +92,6 @@
     "remove_keys",
     "remove_extra_metadata",
     "get_extra_metadata_keys",
-    "PICKLE_KEY_SUFFIX",
     "is_no_channel",
 ]
@@ -418,32 +416,6 @@ def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"):
     return
 
 
-PICKLE_KEY_SUFFIX = TraceKeys.KEY_SUFFIX
-
-
-def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True):
-    """
-    Applied_operations are dictionaries with varying sizes, this method converts them to bytes so that we can (de-)collate.
-
-    Args:
-        data: a list or dictionary with substructures to be pickled/unpickled.
-        key: the key suffix for the target substructures, defaults to "_transforms" (`data.utils.PICKLE_KEY_SUFFIX`).
-        is_encode: whether it's encoding using pickle.dumps (True) or decoding using pickle.loads (False).
-    """
-    if isinstance(data, Mapping):
-        data = dict(data)
-        for k in data:
-            if f"{k}".endswith(key):
-                if is_encode and not isinstance(data[k], bytes):
-                    data[k] = pickle.dumps(data[k], 0)
-                if not is_encode and isinstance(data[k], bytes):
-                    data[k] = pickle.loads(data[k])
-        return {k: pickle_operations(v, key=key, is_encode=is_encode) for k, v in data.items()}
-    elif isinstance(data, (list, tuple)):
-        return [pickle_operations(item, key=key, is_encode=is_encode) for item in data]
-    return data
-
-
 def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
     """
     Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor`
@@ -500,8 +472,8 @@ def list_data_collate(batch: Sequence):
     key = None
     collate_fn = default_collate
     try:
-        if config.USE_META_DICT:
-            data = pickle_operations(data)  # bc 0.9.0
+        # if config.USE_META_DICT:
+        #     data = pickle_operations(data)  # bc 0.9.0
         if isinstance(elem, Mapping):
             ret = {}
             for k in elem:
@@ -654,15 +626,17 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):
     if isinstance(deco, Mapping):
         _gen = zip_longest(*deco.values(), fillvalue=fill_value) if pad else zip(*deco.values())
         ret = [dict(zip(deco, item)) for item in _gen]
-        if not config.USE_META_DICT:
-            return ret
-        return pickle_operations(ret, is_encode=False)  # bc 0.9.0
+        # if not config.USE_META_DICT:
+        #     return ret
+        # return pickle_operations(ret, is_encode=False)  # bc 0.9.0
+        return ret
     if isinstance(deco, Iterable):
         _gen = zip_longest(*deco, fillvalue=fill_value) if pad else zip(*deco)
         ret_list = [list(item) for item in _gen]
-        if not config.USE_META_DICT:
-            return ret_list
-        return pickle_operations(ret_list, is_encode=False)  # bc 0.9.0
+        # if not config.USE_META_DICT:
+        #     return ret_list
+        # return pickle_operations(ret_list, is_encode=False)  # bc 0.9.0
+        return ret_list
     raise NotImplementedError(f"Unable to de-collate: {batch}, type: {type(batch)}.")
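With `pickle_operations` removed, `applied_operations` metadata travels through collation as ordinary dicts rather than pickled bytes. A small round-trip sketch (shapes and keys are illustrative):

```python
import torch
from monai.data import MetaTensor, decollate_batch, list_data_collate

items = [{"img": MetaTensor(torch.rand(1, 4, 4))} for _ in range(2)]

batch = list_data_collate(items)  # "img" is a single batched MetaTensor
singles = decollate_batch(batch)  # back to a list of per-sample dicts
assert len(singles) == 2
```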

monai/handlers/checkpoint_loader.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ def __call__(self, engine: Engine) -> None:
         Args:
             engine: Ignite Engine, it can be a trainer, validator or evaluator.
         """
-        checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=False)
+        checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=True)
 
         k, _ = list(self.load_dict.items())[0]
         # single object and checkpoint is directly a state_dict
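Because the handler now loads weights-only, the checkpoints it consumes must contain only tensors and primitives; saving `state_dict()` contents, as sketched below, keeps them loadable (the file name and keys are illustrative):

```python
import torch
from torch import nn

net = nn.Linear(4, 2)

# state_dict() is a mapping of tensors, which the weights-only
# unpickler accepts; pickling the module object itself would not load.
torch.save({"net": net.state_dict()}, "checkpoint.pt")

ckpt = torch.load("checkpoint.pt", map_location="cpu", weights_only=True)
net.load_state_dict(ckpt["net"])
```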

monai/utils/state_cacher.py

Lines changed: 3 additions & 3 deletions
@@ -64,8 +64,8 @@ def __init__(
             pickle_module: module used for pickling metadata and objects, default to `pickle`.
                 this arg is used by `torch.save`, for more details, please check:
                 https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
-            pickle_protocol: can be specified to override the default protocol, default to `2`.
-                this arg is used by `torch.save`, for more details, please check:
+            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
+                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
                 https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
 
         """
@@ -124,7 +124,7 @@ def retrieve(self, key: Hashable) -> Any:
         fn = self.cached[key]["obj"]  # pytype: disable=attribute-error
         if not os.path.exists(fn):  # pytype: disable=wrong-arg-types
             raise RuntimeError(f"Failed to load state in {fn}. File doesn't exist anymore.")
-        data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=False)
+        data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=True)
         # copy back to device if necessary
         if "device" in self.cached[key]:
             data_obj = data_obj.to(self.cached[key]["device"])
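A sketch of typical `StateCacher` usage under the new behaviour, assuming the `store`/`retrieve` interface shown above; cached objects must now be tensor-based (e.g. `state_dict()` output) to survive the weights-only load in `retrieve`:

```python
import tempfile

import torch
from torch import nn
from monai.utils import StateCacher

net = nn.Linear(3, 1)

with tempfile.TemporaryDirectory() as tmp:
    # Disk-backed caching; the stored state_dict is tensors only, so it
    # round-trips through the weights-only loader in retrieve().
    cacher = StateCacher(in_memory=False, cache_dir=tmp)
    cacher.store("net", net.state_dict())
    net.load_state_dict(cacher.retrieve("net"))
```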
