diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 2511ce2219..f07699594e 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -210,8 +210,12 @@ class PersistentDataset(Dataset): Cached data is expected to be tensors, primitives, or dictionaries keying to these values. Numpy arrays will be converted to tensors, however any other object type returned by transforms will not be loadable since - `torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects. - Legacy cache files may not be loadable and may need to be recomputed. + `torch.load` will be used with `weights_only=True` by default to prevent loading of potentially malicious + objects. Legacy cache files may not be loadable and may need to be recomputed. MetaTensor objects can be saved + and loaded with their metadata preserved if `track_meta` is True, however the objects stored in the metadata + must either be loadable by `torch.load` by default or have been white-listed with + `torch.serialization.add_safe_globals`. Any other object type may be stored but will fail to load and force + a cache recompute. Lazy Resampling: If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to @@ -245,8 +249,8 @@ def __init__( may share a common cache dir provided that the transforms pre-processing is consistent. If `cache_dir` doesn't exist, will automatically create it. If `cache_dir` is `None`, there is effectively no caching. - hash_func: a callable to compute hash from data items to be cached. - defaults to `monai.data.utils.pickle_hashing`. + hash_func: a callable to compute hash from data items to be cached, defaults to + `monai.data.utils.pickle_hashing` which uses sha256 (previously md5 so old caches will not work). pickle_module: string representing the module used for pickling metadata and objects, default to `"pickle"`. 
due to the pickle limitation in multi-processing of Dataloader, we can't use `pickle` as arg directly, so here we use a string name instead. @@ -266,17 +270,12 @@ def __init__( When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors. This is useful for skipping the transform instance checks when inverting applied operations using the cached content and with re-created transform instances. - track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`. - default to `False`. Cannot be used with `weights_only=True`. + track_meta: whether to track the meta information, defaults to False. If `True`, converts to `MetaTensor`. weights_only: keyword argument passed to `torch.load` when reading cached files. - default to `True`. When set to `True`, `torch.load` restricts loading to tensors and - other safe objects. Setting this to `False` is required for loading `MetaTensor` - objects saved with `track_meta=True`, however this creates the possibility of remote - code execution through `torch.load` so be aware of the security implications of doing so. - - Raises: - ValueError: When both `track_meta=True` and `weights_only=True`, since this combination - prevents cached MetaTensors from being reloaded and causes perpetual cache regeneration. + defaults to `True`. When `True`, `torch.load` restricts loading to tensors and other safe objects. + Setting to `False` should only be done if it's absolutely necessary to load unsafe pickled data, + e.g. MetaTensor objects with unsafe objects in their metadata. Users must verify the safety of the data + they intend to load before doing so. 
""" super().__init__(data=data, transform=transform) self.cache_dir = Path(cache_dir) if cache_dir is not None else None @@ -292,11 +291,6 @@ def __init__( if hash_transform is not None: self.set_transform_hash(hash_transform) self.reset_ops_id = reset_ops_id - if track_meta and weights_only: - raise ValueError( - "Invalid argument combination: `track_meta=True` cannot be used with `weights_only=True`. " - "To cache and reload MetaTensors, set `track_meta=True` and `weights_only=False`." - ) self.track_meta = track_meta self.weights_only = weights_only @@ -390,9 +384,9 @@ def _cachecheck(self, item_transformed): """ hashfile = None if self.cache_dir is not None: - data_item_md5 = self.hash_func(item_transformed).decode("utf-8") - data_item_md5 += self.transform_hash - hashfile = self.cache_dir / f"{data_item_md5}.pt" + data_item_hash = self.hash_func(item_transformed).decode("utf-8") + data_item_hash += self.transform_hash + hashfile = self.cache_dir / f"{data_item_hash}.pt" if hashfile is not None and hashfile.is_file(): # cache hit try: @@ -1624,9 +1618,9 @@ def _cachecheck(self, item_transformed): hashfile = None # compute a cache id if self.cache_dir is not None: - data_item_md5 = self.hash_func(item_transformed).decode("utf-8") - data_item_md5 += self.transform_hash - hashfile = self.cache_dir / f"{data_item_md5}.pt" + data_item_hash = self.hash_func(item_transformed).decode("utf-8") + data_item_hash += self.transform_hash + hashfile = self.cache_dir / f"{data_item_hash}.pt" if hashfile is not None and hashfile.is_file(): # cache hit with cp.cuda.Device(self.device): diff --git a/monai/data/utils.py b/monai/data/utils.py index b504ba9b60..64bd79c712 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -17,7 +17,6 @@ import math import os import pickle -import sys from collections import abc, defaultdict from collections.abc import Generator, Iterable, Mapping, Sequence, Sized from copy import deepcopy @@ -1370,13 +1369,8 @@ def json_hashing(item) -> 
bytes: """ # TODO: Find way to hash transforms content as part of the cache - cache_key = "" - if sys.version_info.minor < 9: - cache_key = hashlib.md5(json.dumps(item, sort_keys=True).encode("utf-8")).hexdigest() - else: - cache_key = hashlib.md5( - json.dumps(item, sort_keys=True).encode("utf-8"), usedforsecurity=False # type: ignore - ).hexdigest() + dump = json.dumps(item, sort_keys=True).encode("utf-8") + cache_key = hashlib.sha256(dump, usedforsecurity=False).hexdigest() # type: ignore return f"{cache_key}".encode() @@ -1391,13 +1385,8 @@ def pickle_hashing(item, protocol=pickle.HIGHEST_PROTOCOL) -> bytes: Returns: the corresponding hash key """ - cache_key = "" - if sys.version_info.minor < 9: - cache_key = hashlib.md5(pickle.dumps(sorted_dict(item), protocol=protocol)).hexdigest() - else: - cache_key = hashlib.md5( - pickle.dumps(sorted_dict(item), protocol=protocol), usedforsecurity=False # type: ignore - ).hexdigest() + dump = pickle.dumps(sorted_dict(item), protocol=protocol) + cache_key = hashlib.sha256(dump, usedforsecurity=False).hexdigest() # type: ignore return f"{cache_key}".encode() diff --git a/tests/data/test_persistentdataset.py b/tests/data/test_persistentdataset.py index ca62cdb184..c70519d98e 100644 --- a/tests/data/test_persistentdataset.py +++ b/tests/data/test_persistentdataset.py @@ -13,8 +13,11 @@ import contextlib import os +import pickle import tempfile import unittest +from pathlib import Path +from unittest.mock import patch import nibabel as nib import numpy as np @@ -46,7 +49,7 @@ TEST_CASE_4 = [True, False, False, MetaTensor] -TEST_CASE_5 = [True, True, True, None] +TEST_CASE_5 = [True, True, False, MetaTensor] TEST_CASE_6 = [False, False, False, torch.Tensor] @@ -200,6 +203,133 @@ def test_track_meta_and_weights_only(self, track_meta, weights_only, expected_er im = test_dataset[0]["image"] self.assertIsInstance(im, expected_type) + def test_metatensor_loading(self): + """ + Thorough test of metadata loading correctly with 
MetaTensor. This will store a MetaTensor with safe object types + in its metadata dictionary, test the cache file exists and can be safely loaded with weights only, and that the + loaded object is another MetaTensor with the correct information + """ + meta = {"test_meta": 123, "foo": "bar", "test_tuple": (1, 2, 3)} + imt = MetaTensor(torch.rand(1, 128, 128, 128), meta=dict(meta), affine=torch.rand(4, 4)) + + with tempfile.TemporaryDirectory() as tempdir: + cache_dir = Path(tempdir, "cache", "data") + + test_data = [{"image": imt}] + + test_dataset = PersistentDataset( + data=test_data, + transform=Compose([Identity()]), + cache_dir=str(cache_dir), + track_meta=True, + weights_only=True, + ) + + im = test_dataset[0]["image"] + self.assertIsInstance(im, MetaTensor, "MetaTensor not stored in dataset.") + + for k, v in meta.items(): + self.assertIn(k, im.meta, f"Metadata key {k} missing from loaded object.") + self.assertEqual(im.meta[k], v, f"Metadata key {k} not equal ({im.meta[k]}!={v}).") + + torch.testing.assert_close(imt.affine, im.affine) + + cache_files = list(cache_dir.glob("*")) + self.assertEqual(len(cache_files), 1, "Cached file not present.") + + cache_im = torch.load(cache_files[0], weights_only=True)["image"] + + self.assertIsInstance(cache_im, MetaTensor, "MetaTensor not stored in dataset.") + + for k, v in meta.items(): + self.assertIn(k, cache_im.meta, f"Metadata key {k} missing from loaded object.") + self.assertEqual(cache_im.meta[k], v, f"Metadata key {k} not equal ({cache_im.meta[k]}!={v}).") + + # create a new dataset to be sure + test_dataset2 = PersistentDataset( + data=test_data, + transform=Compose([Identity()]), + cache_dir=str(cache_dir), + track_meta=True, + weights_only=True, + ) + + # Replace torch.load with a function returning the same thing wrapped in a tuple, this is used to indicate + # the dataset loaded the cached data rather than recomputed. 
+ old_load = torch.load + + def _mock_load(f, weights_only): + self.assertTrue(weights_only, f"torch.load called with {weights_only=}.") + return (old_load(f, weights_only=weights_only),) + + # check the returned object is a tuple containing the expected dict, if not then _mock_load wasn't called + with patch("torch.load", _mock_load): + im2_t = test_dataset2[0] + self.assertIsInstance(im2_t, tuple, "Special tuple not returned, so mock not used.") + self.assertIsInstance(im2_t[0]["image"], MetaTensor, "MetaTensor not stored in dataset.") + + def test_metatensor_badcache(self): + """ + Test attempting to save then load a MetaTensor with an unsafe metadata item raises an exception. This creates + a MetaTensor with an object in its metadata using unsafe code in __reduce__ which gets stored in the pickle. + When attempting to load this through torch.load, pickle.UnpicklingError should be raised to force a recompute + of the cached data rather than attempting to load something unsafe. + """ + with tempfile.TemporaryDirectory() as tempdir: + cache_dir = Path(tempdir) / "cache" / "data" + + class _BadType: + def __reduce__(self): + # something more insecure than this could be done with os.system + return (os.system, (f'echo "Code injected!" > {Path(tempdir)/"out.txt"!s}',)) + + meta = {"test_meta": 123, "foo": "bar", "bad_item": _BadType()} + imt = MetaTensor(torch.rand(1, 128, 128, 128), meta=dict(meta), affine=torch.rand(4, 4)) + test_data = [{"image": imt}] + + test_dataset = PersistentDataset( + data=test_data, + transform=Compose([Identity()]), + cache_dir=str(cache_dir), + track_meta=True, + weights_only=True, + ) + + # This will trigger the _BadType class code injection because deepcopy will use __reduce__, but will still + # write the cache file as needed for the test. The alternative was to write the cache file directly with a + # computed hash value, but computing that hash without using pickle_hashing isn't trivial. 
+ im = test_dataset[0]["image"] + + self.assertIsInstance(im, MetaTensor, "MetaTensor not stored in dataset.") + + cache_files = list(cache_dir.glob("*")) + self.assertEqual(len(cache_files), 1, "Cached file not present.") + + # loading the cache file directly will raise the pickle exception as expected + with self.assertRaises(pickle.UnpicklingError): + torch.load(cache_files[0], weights_only=True) + + # create a new dataset object just to be sure. When loading, a cache hit will occur but this will raise + # the pickle exception again and force a recompute of the cached data as well as a warning, this indicates + # the unsafe data was correctly rejected. + test_dataset2 = PersistentDataset( + data=test_data, + transform=Compose([Identity()]), + cache_dir=str(cache_dir), + track_meta=True, + weights_only=True, + ) + + # warning raised about recomputing the corrupted cache file which raised UnpicklingError + with self.assertWarns(UserWarning): + im = test_dataset2[0]["image"] + + self.assertIsInstance(im, MetaTensor, "MetaTensor not stored in dataset.") + + cache_files2 = list(cache_dir.glob("*")) + + self.assertEqual(cache_files[0], cache_files2[0], "Hashes for cached data differ.") + if __name__ == "__main__": unittest.main()