13
13
14
14
import collections .abc
15
15
import math
16
- import pickle
17
16
import shutil
18
17
import sys
19
18
import tempfile
22
21
import warnings
23
22
from collections .abc import Callable , Sequence
24
23
from copy import copy , deepcopy
24
+ from io import BytesIO
25
25
from multiprocessing .managers import ListProxy
26
26
from multiprocessing .pool import ThreadPool
27
27
from pathlib import Path
28
+ from pickle import UnpicklingError
28
29
from typing import IO , TYPE_CHECKING , Any , cast
29
30
30
31
import numpy as np
@@ -207,6 +208,11 @@ class PersistentDataset(Dataset):
207
208
not guaranteed, so caution should be used when modifying transforms to avoid unexpected
208
209
errors. If in doubt, it is advisable to clear the cache directory.
209
210
211
+ Cached data is expected to be tensors, primitives, or dictionaries keying to these values. Numpy arrays will
212
+ be converted to tensors, however any other object type returned by transforms will not be loadable since
213
+ `torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects.
214
+ Legacy cache files may not be loadable and may need to be recomputed.
215
+
210
216
Lazy Resampling:
211
217
If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to
212
218
its documentation to familiarize yourself with the interaction between `PersistentDataset` and
@@ -248,8 +254,8 @@ def __init__(
248
254
this arg is used by `torch.save`, for more details, please check:
249
255
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
250
256
and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
251
- pickle_protocol: can be specified to override the default protocol, default to `2 `.
252
- this arg is used by ` torch.save`, for more details, please check:
257
+ pickle_protocol: specifies the pickle protocol used when saving with `torch.save`.
258
+ Defaults to `torch.serialization.DEFAULT_PROTOCOL`. For more details, please check:
253
259
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
254
260
hash_transform: a callable to compute hash from the transform information when caching.
255
261
This may reduce errors due to transforms changing during experiments. Default to None (no hash).
@@ -371,12 +377,12 @@ def _cachecheck(self, item_transformed):
371
377
372
378
if hashfile is not None and hashfile .is_file (): # cache hit
373
379
try :
374
- return torch .load (hashfile , weights_only = False )
380
+ return torch .load (hashfile , weights_only = True )
375
381
except PermissionError as e :
376
382
if sys .platform != "win32" :
377
383
raise e
378
- except RuntimeError as e :
379
- if "Invalid magic number; corrupt file" in str (e ):
384
+ except ( UnpicklingError , RuntimeError ) as e : # corrupt or unloadable cached files are recomputed
385
+ if "Invalid magic number; corrupt file" in str (e ) or isinstance ( e , UnpicklingError ) :
380
386
warnings .warn (f"Corrupt cache file detected: { hashfile } . Deleting and recomputing." )
381
387
hashfile .unlink ()
382
388
else :
@@ -392,7 +398,7 @@ def _cachecheck(self, item_transformed):
392
398
with tempfile .TemporaryDirectory () as tmpdirname :
393
399
temp_hash_file = Path (tmpdirname ) / hashfile .name
394
400
torch .save (
395
- obj = _item_transformed ,
401
+ obj = convert_to_tensor ( _item_transformed , convert_numeric = False ) ,
396
402
f = temp_hash_file ,
397
403
pickle_module = look_up_option (self .pickle_module , SUPPORTED_PICKLE_MOD ),
398
404
pickle_protocol = self .pickle_protocol ,
@@ -455,8 +461,8 @@ def __init__(
455
461
this arg is used by `torch.save`, for more details, please check:
456
462
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
457
463
and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
458
- pickle_protocol: can be specified to override the default protocol, default to `2 `.
459
- this arg is used by ` torch.save`, for more details, please check:
464
+ pickle_protocol: specifies the pickle protocol used when saving with `torch.save`.
465
+ Defaults to `torch.serialization.DEFAULT_PROTOCOL`. For more details, please check:
460
466
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
461
467
hash_transform: a callable to compute hash from the transform information when caching.
462
468
This may reduce errors due to transforms changing during experiments. Default to None (no hash).
@@ -531,7 +537,7 @@ def __init__(
531
537
hash_func : Callable [..., bytes ] = pickle_hashing ,
532
538
db_name : str = "monai_cache" ,
533
539
progress : bool = True ,
534
- pickle_protocol = pickle . HIGHEST_PROTOCOL ,
540
+ pickle_protocol = DEFAULT_PROTOCOL ,
535
541
hash_transform : Callable [..., bytes ] | None = None ,
536
542
reset_ops_id : bool = True ,
537
543
lmdb_kwargs : dict | None = None ,
@@ -551,8 +557,9 @@ def __init__(
551
557
defaults to `monai.data.utils.pickle_hashing`.
552
558
db_name: lmdb database file name. Defaults to "monai_cache".
553
559
progress: whether to display a progress bar.
554
- pickle_protocol: pickle protocol version. Defaults to pickle.HIGHEST_PROTOCOL.
555
- https://docs.python.org/3/library/pickle.html#pickle-protocols
560
+ pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
561
+ Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
562
+ https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
556
563
hash_transform: a callable to compute hash from the transform information when caching.
557
564
This may reduce errors due to transforms changing during experiments. Default to None (no hash).
558
565
Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
@@ -594,6 +601,15 @@ def set_data(self, data: Sequence):
594
601
super ().set_data (data = data )
595
602
self ._read_env = self ._fill_cache_start_reader (show_progress = self .progress )
596
603
604
def _safe_serialize(self, val):
    """
    Serialize `val` to bytes with `torch.save` for storage in the LMDB cache.

    The value is first passed through `convert_to_tensor` so that the stored
    payload contains only tensors/primitives and can later be reloaded with
    `torch.load(..., weights_only=True)` by `_safe_deserialize`.

    Args:
        val: transformed item to cache (tensors, primitives, or dicts of these).

    Returns:
        bytes: the serialized payload.
    """
    out = BytesIO()
    torch.save(convert_to_tensor(val), out, pickle_protocol=self.pickle_protocol)
    # getvalue() returns the entire buffer contents irrespective of the stream
    # position, so the explicit seek(0)/read() pair is unnecessary.
    return out.getvalue()
610
+ def _safe_deserialize (self , val ):
611
+ return torch .load (BytesIO (val ), map_location = "cpu" , weights_only = True )
612
+
597
613
def _fill_cache_start_reader (self , show_progress = True ):
598
614
"""
599
615
Check the LMDB cache and write the cache if needed. py-lmdb doesn't have a good support for concurrent write.
@@ -619,7 +635,8 @@ def _fill_cache_start_reader(self, show_progress=True):
619
635
continue
620
636
if val is None :
621
637
val = self ._pre_transform (deepcopy (item )) # keep the original hashed
622
- val = pickle .dumps (val , protocol = self .pickle_protocol )
638
+ # val = pickle.dumps(val, protocol=self.pickle_protocol)
639
+ val = self ._safe_serialize (val )
623
640
with env .begin (write = True ) as txn :
624
641
txn .put (key , val )
625
642
done = True
@@ -664,7 +681,8 @@ def _cachecheck(self, item_transformed):
664
681
warnings .warn ("LMDBDataset: cache key not found, running fallback caching." )
665
682
return super ()._cachecheck (item_transformed )
666
683
try :
667
- return pickle .loads (data )
684
+ # return pickle.loads(data)
685
+ return self ._safe_deserialize (data )
668
686
except Exception as err :
669
687
raise RuntimeError ("Invalid cache value, corrupted lmdb file?" ) from err
670
688
@@ -1650,7 +1668,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
1650
1668
meta_hash_file = self .cache_dir / meta_hash_file_name
1651
1669
temp_hash_file = Path (tmpdirname ) / meta_hash_file_name
1652
1670
torch .save (
1653
- obj = self ._meta_cache [meta_hash_file_name ],
1671
+ obj = convert_to_tensor ( self ._meta_cache [meta_hash_file_name ], convert_numeric = False ) ,
1654
1672
f = temp_hash_file ,
1655
1673
pickle_module = look_up_option (self .pickle_module , SUPPORTED_PICKLE_MOD ),
1656
1674
pickle_protocol = self .pickle_protocol ,
@@ -1670,4 +1688,4 @@ def _load_meta_cache(self, meta_hash_file_name):
1670
1688
if meta_hash_file_name in self ._meta_cache :
1671
1689
return self ._meta_cache [meta_hash_file_name ]
1672
1690
else :
1673
- return torch .load (self .cache_dir / meta_hash_file_name , weights_only = False )
1691
+ return torch .load (self .cache_dir / meta_hash_file_name , weights_only = True )
0 commit comments