
Commit b592164

4855 5860 update the pending transform utilities (#5916)
One of the sub-PRs from #5860; also fixes #5509 by reviewing the MetaTensor copying in LoadImage.

### Description

This PR mainly enhances the `apply_transforms`, `resample`, and `push_transform` APIs to prepare for lazy resampling.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Wenqi Li <[email protected]>
1 parent 968878a commit b592164

File tree

13 files changed: +251 -80 lines changed


monai/data/image_reader.py

Lines changed: 1 addition & 1 deletion
@@ -1031,7 +1031,7 @@ def _get_array_data(self, img):
             img: a Nibabel image object loaded from an image file.
 
         """
-        return np.asanyarray(img.dataobj)
+        return np.asanyarray(img.dataobj, order="C")
 
 
 class NumpyReader(ImageReader):
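As context for this one-line change (not part of the diff), requesting C order up front yields a C-contiguous array at load time, so the later torch conversion does not need an extra contiguity copy. A minimal sketch of the effect, using an illustrative Fortran-ordered array:

```python
import numpy as np
import torch

# Nibabel typically returns NIfTI voxel data in Fortran (column-major) order.
fortran_arr = np.asfortranarray(np.random.rand(4, 4, 4))
print(fortran_arr.flags["C_CONTIGUOUS"])  # False

# `order="C"` makes the copy once here, instead of deferring it to the
# ascontiguousarray/torch conversion later in the loading pipeline.
c_arr = np.asanyarray(fortran_arr, order="C")
print(c_arr.flags["C_CONTIGUOUS"])  # True
tensor = torch.as_tensor(c_arr)  # no additional contiguity copy needed here
```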

monai/data/meta_obj.py

Lines changed: 12 additions & 5 deletions
@@ -19,8 +19,7 @@
 import numpy as np
 import torch
 
-from monai.utils.enums import TraceKeys
-from monai.utils.misc import first
+from monai.utils import TraceKeys, first, is_immutable
 
 _TRACK_META = True
 
@@ -107,27 +106,35 @@ def flatten_meta_objs(*args: Iterable):
     @staticmethod
     def copy_items(data):
         """returns a copy of the data. list and dict are shallow copied for efficiency purposes."""
+        if is_immutable(data):
+            return data
         if isinstance(data, (list, dict, np.ndarray)):
             return data.copy()
         if isinstance(data, torch.Tensor):
             return data.detach().clone()
         return deepcopy(data)
 
-    def copy_meta_from(self, input_objs, copy_attr=True) -> None:
+    def copy_meta_from(self, input_objs, copy_attr=True, keys=None):
         """
         Copy metadata from a `MetaObj` or an iterable of `MetaObj` instances.
 
         Args:
             input_objs: list of `MetaObj` to copy data from.
             copy_attr: whether to copy each attribute with `MetaObj.copy_item`.
                 note that if the attribute is a nested list or dict, only a shallow copy will be done.
+            keys: the keys of attributes to copy from the ``input_objs``.
+                If None, all keys from the input_objs will be copied.
         """
         first_meta = input_objs if isinstance(input_objs, MetaObj) else first(input_objs, default=self)
+        if not hasattr(first_meta, "__dict__"):
+            return self
         first_meta = first_meta.__dict__
+        keys = first_meta.keys() if keys is None else keys
         if not copy_attr:
-            self.__dict__ = first_meta.copy()  # shallow copy for performance
+            self.__dict__ = {a: first_meta[a] for a in keys if a in first_meta}  # shallow copy for performance
         else:
-            self.__dict__.update({a: MetaObj.copy_items(first_meta[a]) for a in first_meta})
+            self.__dict__.update({a: MetaObj.copy_items(first_meta[a]) for a in keys if a in first_meta})
+        return self
 
     @staticmethod
     def get_default_meta() -> dict:
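A small usage sketch (not part of the diff) of the updated `copy_meta_from`: the new `keys` argument limits which attributes are copied, and the method now returns `self`, so the result can be assigned or chained. The tensors and the metadata key below are made up for illustration:

```python
import torch
from monai.data import MetaTensor

src = MetaTensor(torch.zeros(1, 4, 4), meta={"modality": "CT"})
dst = MetaTensor(torch.ones(1, 4, 4))

# copy only the `meta` attribute from `src`; the call returns the updated object
dst = dst.copy_meta_from(src, keys=("meta",))
print(dst.meta["modality"])  # CT
```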

monai/data/meta_tensor.py

Lines changed: 17 additions & 7 deletions
@@ -503,15 +503,15 @@ def clone(self):
 
     @staticmethod
     def ensure_torch_and_prune_meta(
-        im: NdarrayTensor, meta: dict, simple_keys: bool = False, pattern: str | None = None, sep: str = "."
+        im: NdarrayTensor, meta: dict | None, simple_keys: bool = False, pattern: str | None = None, sep: str = "."
     ):
         """
-        Convert the image to `torch.Tensor`. If `affine` is in the `meta` dictionary,
+        Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary,
         convert that to `torch.Tensor`, too. Remove any superfluous metadata.
 
         Args:
             im: Input image (`np.ndarray` or `torch.Tensor`)
-            meta: Metadata dictionary.
+            meta: Metadata dictionary. When it's None, the metadata is not tracked, this method returns a torch.Tensor.
             simple_keys: whether to keep only a simple subset of metadata keys.
             pattern: combined with `sep`, a regular expression used to match and prune keys
                 in the metadata (nested dictionary), default to None, no key deletion.
@@ -521,14 +521,17 @@ def ensure_torch_and_prune_meta(
 
         Returns:
             By default, a `MetaTensor` is returned.
-            However, if `get_track_meta()` is `False`, a `torch.Tensor` is returned.
+            However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
         """
-        img = convert_to_tensor(im)  # potentially ascontiguousarray
+        img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None)  # potentially ascontiguousarray
 
         # if not tracking metadata, return `torch.Tensor`
-        if not get_track_meta() or meta is None:
+        if not isinstance(img, MetaTensor):
             return img
 
+        if meta is None:
+            meta = {}
+
         # remove any superfluous metadata.
         if simple_keys:
             # ensure affine is of type `torch.Tensor`
@@ -540,7 +543,14 @@ def ensure_torch_and_prune_meta(
             meta = monai.transforms.DeleteItemsd(keys=pattern, sep=sep, use_re=True)(meta)
 
         # return the `MetaTensor`
-        return MetaTensor(img, meta=meta)
+        if meta is None:
+            meta = {}
+        img.meta = meta
+        if MetaKeys.AFFINE in meta:
+            img.affine = meta[MetaKeys.AFFINE]  # this uses the affine property setter
+        else:
+            img.affine = MetaTensor.get_default_affine()
+        return img
 
     def __repr__(self):
         """

monai/data/utils.py

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,7 @@
     ensure_tuple_size,
     fall_back_tuple,
     first,
+    get_equivalent_dtype,
     issequenceiterable,
     look_up_option,
     optional_import,
@@ -924,6 +925,7 @@ def to_affine_nd(r: np.ndarray | int, affine: NdarrayTensor, dtype=np.float64) -
         an (r+1) x (r+1) matrix (tensor or ndarray depends on the input ``affine`` data type)
 
     """
+    dtype = get_equivalent_dtype(dtype, np.ndarray)
     affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True)[0]
    affine_np = affine_np.copy()
    if affine_np.ndim != 2:
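For illustration (not part of the diff), normalizing `dtype` with `get_equivalent_dtype` maps a torch dtype to its numpy equivalent before the internal numpy conversion, which should let callers pass either kind of dtype to `to_affine_nd`. A minimal sketch with an illustrative affine:

```python
import torch
from monai.data.utils import to_affine_nd

affine_2d = torch.eye(3, dtype=torch.float64)  # a 2D (3x3) affine

# the dtype given as a torch dtype is converted to the numpy equivalent internally
affine_3d = to_affine_nd(3, affine_2d, dtype=torch.float32)
print(tuple(affine_3d.shape))  # (4, 4), promoted to a 3D affine
```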

monai/transforms/inverse.py

Lines changed: 137 additions & 51 deletions
@@ -20,9 +20,11 @@
 import torch
 
 from monai import transforms
+from monai.data.meta_obj import MetaObj, get_track_meta
 from monai.data.meta_tensor import MetaTensor
-from monai.transforms.transform import Transform
-from monai.utils.enums import TraceKeys
+from monai.data.utils import to_affine_nd
+from monai.transforms.transform import LazyTransform, Transform
+from monai.utils import LazyAttr, MetaKeys, TraceKeys, convert_to_dst_type, convert_to_numpy, convert_to_tensor
 
 __all__ = ["TraceableTransform", "InvertibleTransform"]
 
@@ -72,76 +74,160 @@ def trace_key(key: Hashable = None):
             return f"{TraceKeys.KEY_SUFFIX}"
         return f"{key}{TraceKeys.KEY_SUFFIX}"
 
-    def get_transform_info(
-        self, data, key: Hashable = None, extra_info: dict | None = None, orig_size: tuple | None = None
-    ) -> dict:
+    @staticmethod
+    def transform_info_keys():
+        """The keys to store necessary info of an applied transform."""
+        return (
+            TraceKeys.CLASS_NAME,
+            TraceKeys.ID,
+            TraceKeys.TRACING,
+            TraceKeys.LAZY_EVALUATION,
+            TraceKeys.DO_TRANSFORM,
+        )
+
+    def get_transform_info(self) -> dict:
         """
         Return a dictionary with the relevant information pertaining to an applied transform.
+        """
+        vals = (
+            self.__class__.__name__,
+            id(self),
+            self.tracing,
+            self.lazy_evaluation if isinstance(self, LazyTransform) else False,
+            self._do_transform if hasattr(self, "_do_transform") else True,
+        )
+        return dict(zip(self.transform_info_keys(), vals))
 
-        Args:
-            data: input data. Can be dictionary or MetaTensor. We can use `shape` to
-                determine the original size of the object (unless that has been given
-                explicitly, see `orig_size`).
-            key: if data is a dictionary, data[key] will be modified.
-            extra_info: if desired, any extra information pertaining to the applied
-                transform can be stored in this dictionary. These are often needed for
-                computing the inverse transformation.
-            orig_size: sometimes during the inverse it is useful to know what the size
-                of the original image was, in which case it can be supplied here.
+    def push_transform(self, data, *args, **kwargs):
+        """
+        Push to a stack of applied transforms of ``data``.
 
-        Returns:
-            Dictionary of data pertaining to the applied transformation.
+        Args:
+            data: dictionary of data or `MetaTensor`.
+            args: additional positional arguments to track_transform_meta.
+            kwargs: additional keyword arguments to track_transform_meta,
+                set ``replace=True`` (default False) to rewrite the last transform infor in
+                applied_operation/pending_operation based on ``self.get_transform_info()``.
         """
-        info = {TraceKeys.CLASS_NAME: self.__class__.__name__, TraceKeys.ID: id(self)}
-        if orig_size is not None:
-            info[TraceKeys.ORIG_SIZE] = orig_size
-        elif isinstance(data, Mapping) and key in data and hasattr(data[key], "shape"):
-            info[TraceKeys.ORIG_SIZE] = data[key].shape[1:]
-        elif hasattr(data, "shape"):
-            info[TraceKeys.ORIG_SIZE] = data.shape[1:]
-        if extra_info is not None:
-            info[TraceKeys.EXTRA_INFO] = extra_info
-        # If class is randomizable transform, store whether the transform was actually performed (based on `prob`)
-        if hasattr(self, "_do_transform"):  # RandomizableTransform
-            info[TraceKeys.DO_TRANSFORM] = self._do_transform
-        return info
-
-    def push_transform(
-        self, data, key: Hashable = None, extra_info: dict | None = None, orig_size: tuple | None = None
-    ) -> None:
+        transform_info = self.get_transform_info()
+        lazy_eval = transform_info.get(TraceKeys.LAZY_EVALUATION, False)
+        do_transform = transform_info.get(TraceKeys.DO_TRANSFORM, True)
+        kwargs = kwargs or {}
+        replace = kwargs.pop("replace", False)  # whether to rewrite the most recently pushed transform info
+        if replace and get_track_meta() and isinstance(data, MetaTensor):
+            if not lazy_eval:
+                xform = self.pop_transform(data, check=False) if do_transform else {}
+                meta_obj = self.push_transform(data, orig_size=xform.get(TraceKeys.ORIG_SIZE), extra_info=xform)
+                return data.copy_meta_from(meta_obj)
+            if do_transform:
+                xform = data.pending_operations.pop()  # type: ignore
+                xform.update(transform_info)
+                meta_obj = self.push_transform(data, transform_info=xform, lazy_evaluation=lazy_eval)
+                return data.copy_meta_from(meta_obj)
+            return data
+        kwargs["lazy_evaluation"] = lazy_eval
+        if "transform_info" in kwargs and isinstance(kwargs["transform_info"], dict):
+            kwargs["transform_info"].update(transform_info)
+        else:
+            kwargs["transform_info"] = transform_info
+        meta_obj = TraceableTransform.track_transform_meta(data, *args, **kwargs)
+        return data.copy_meta_from(meta_obj) if isinstance(data, MetaTensor) else data
+
+    @classmethod
+    def track_transform_meta(
+        cls,
+        data,
+        key: Hashable = None,
+        sp_size=None,
+        affine=None,
+        extra_info: dict | None = None,
+        orig_size: tuple | None = None,
+        transform_info=None,
+        lazy_evaluation=False,
+    ):
         """
-        Push to a stack of applied transforms.
+        Update a stack of applied/pending transforms metadata of ``data``.
 
         Args:
             data: dictionary of data or `MetaTensor`.
             key: if data is a dictionary, data[key] will be modified.
+            sp_size: the expected output spatial size when the transform is applied.
+                it can be tensor or numpy, but will be converted to a list of integers.
+            affine: the affine representation of the (spatial) transform in the image space.
+                When the transform is applied, meta_tensor.affine will be updated to ``meta_tensor.affine @ affine``.
             extra_info: if desired, any extra information pertaining to the applied
                 transform can be stored in this dictionary. These are often needed for
                 computing the inverse transformation.
             orig_size: sometimes during the inverse it is useful to know what the size
                 of the original image was, in which case it can be supplied here.
+            transform_info: info from self.get_transform_info().
+            lazy_evaluation: whether to push the transform to pending_operations or applied_operations.
 
         Returns:
-            None, but data has been updated to store the applied transformation.
+
+            For backward compatibility, if ``data`` is a dictionary, it returns the dictionary with
+            updated ``data[key]``. Otherwise, this function returns a MetaObj with updated transform metadata.
         """
-        if not self.tracing:
-            return
-        info = self.get_transform_info(data, key, extra_info, orig_size)
+        data_t = data[key] if key is not None else data  # compatible with the dict data representation
+        out_obj = MetaObj()
+        # after deprecating metadict, we should always convert data_t to metatensor here
+        if isinstance(data_t, MetaTensor):
+            out_obj.copy_meta_from(data_t, keys=out_obj.__dict__.keys())
+
+        if not lazy_evaluation and affine is not None and isinstance(data_t, MetaTensor):
+            # not lazy evaluation, directly update the metatensor affine (don't push to the stack)
+            orig_affine = data_t.peek_pending_affine()
+            orig_affine = convert_to_dst_type(orig_affine, affine)[0]
+            affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=affine.dtype)
+            out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"))
+
+        if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)):
+            if isinstance(data, Mapping):
+                if not isinstance(data, dict):
+                    data = dict(data)
+                data[key] = data_t.copy_meta_from(out_obj) if isinstance(data_t, MetaTensor) else data_t
+                return data
+            return out_obj  # return with data_t as tensor if get_track_meta() is False
+
+        info = transform_info
+        # track the current spatial shape
+        if orig_size is not None:
+            info[TraceKeys.ORIG_SIZE] = orig_size
+        elif isinstance(data_t, MetaTensor):
+            info[TraceKeys.ORIG_SIZE] = data_t.peek_pending_shape()
+        elif hasattr(data_t, "shape"):
+            info[TraceKeys.ORIG_SIZE] = data_t.shape[1:]
+        # include extra_info
+        if extra_info is not None:
+            info[TraceKeys.EXTRA_INFO] = extra_info
 
-        if isinstance(data, MetaTensor):
-            data.push_applied_operation(info)
-        elif isinstance(data, Mapping):
-            if key in data and isinstance(data[key], MetaTensor):
-                data[key].push_applied_operation(info)
+        # push the transform info to the applied_operation or pending_operation stack
+        if lazy_evaluation:
+            if sp_size is None:
+                if LazyAttr.SHAPE not in info:
+                    warnings.warn("spatial size is None in push transform.")
+            else:
+                info[LazyAttr.SHAPE] = tuple(convert_to_numpy(sp_size, wrap_sequence=True).tolist())
+            if affine is None:
+                if LazyAttr.AFFINE not in info:
+                    warnings.warn("affine is None in push transform.")
             else:
-                # If this is the first, create list
-                if self.trace_key(key) not in data:
-                    if not isinstance(data, dict):
-                        data = dict(data)
-                    data[self.trace_key(key)] = []
-                data[self.trace_key(key)].append(info)
+                info[LazyAttr.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"))
+            out_obj.push_pending_operation(info)
         else:
-            warnings.warn(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}. {info} not tracked.")
+            out_obj.push_applied_operation(info)
+        if isinstance(data, Mapping):
+            if not isinstance(data, dict):
+                data = dict(data)
+            if isinstance(data_t, MetaTensor):
+                data[key] = data_t.copy_meta_from(out_obj)
+            else:
+                x_k = TraceableTransform.trace_key(key)
+                if x_k not in data:
+                    data[x_k] = []  # If this is the first, create list
+                data[x_k].append(info)
            return data
+        return out_obj
 
     def check_transforms_match(self, transform: Mapping) -> None:
         """Check transforms are of same instance."""
