20 | 20 | import torch
21 | 21 |
22 | 22 | from monai import transforms
   | 23 | +from monai.data.meta_obj import MetaObj, get_track_meta
23 | 24 | from monai.data.meta_tensor import MetaTensor
24 |    | -from monai.transforms.transform import Transform
25 |    | -from monai.utils.enums import TraceKeys
   | 25 | +from monai.data.utils import to_affine_nd
   | 26 | +from monai.transforms.transform import LazyTransform, Transform
   | 27 | +from monai.utils import LazyAttr, MetaKeys, TraceKeys, convert_to_dst_type, convert_to_numpy, convert_to_tensor
26 | 28 |
27 | 29 | __all__ = ["TraceableTransform", "InvertibleTransform"]
28 | 30 |
@@ -72,76 +74,160 @@ def trace_key(key: Hashable = None):
 72 |  74 |             return f"{TraceKeys.KEY_SUFFIX}"
 73 |  75 |         return f"{key}{TraceKeys.KEY_SUFFIX}"
 74 |  76 |
 75 |     | -    def get_transform_info(
 76 |     | -        self, data, key: Hashable = None, extra_info: dict | None = None, orig_size: tuple | None = None
 77 |     | -    ) -> dict:
    |  77 | +    @staticmethod
    |  78 | +    def transform_info_keys():
    |  79 | +        """The keys to store necessary info of an applied transform."""
    |  80 | +        return (
    |  81 | +            TraceKeys.CLASS_NAME,
    |  82 | +            TraceKeys.ID,
    |  83 | +            TraceKeys.TRACING,
    |  84 | +            TraceKeys.LAZY_EVALUATION,
    |  85 | +            TraceKeys.DO_TRANSFORM,
    |  86 | +        )
    |  87 | +
    |  88 | +    def get_transform_info(self) -> dict:
 78 |  89 |         """
 79 |  90 |         Return a dictionary with the relevant information pertaining to an applied transform.
    |  91 | +        """
    |  92 | +        vals = (
    |  93 | +            self.__class__.__name__,
    |  94 | +            id(self),
    |  95 | +            self.tracing,
    |  96 | +            self.lazy_evaluation if isinstance(self, LazyTransform) else False,
    |  97 | +            self._do_transform if hasattr(self, "_do_transform") else True,
    |  98 | +        )
    |  99 | +        return dict(zip(self.transform_info_keys(), vals))
 80 | 100 |
 81 |     | -        Args:
 82 |     | -            data: input data. Can be dictionary or MetaTensor. We can use `shape` to
 83 |     | -                determine the original size of the object (unless that has been given
 84 |     | -                explicitly, see `orig_size`).
 85 |     | -            key: if data is a dictionary, data[key] will be modified.
 86 |     | -            extra_info: if desired, any extra information pertaining to the applied
 87 |     | -                transform can be stored in this dictionary. These are often needed for
 88 |     | -                computing the inverse transformation.
 89 |     | -            orig_size: sometimes during the inverse it is useful to know what the size
 90 |     | -                of the original image was, in which case it can be supplied here.
    | 101 | +    def push_transform(self, data, *args, **kwargs):
    | 102 | +        """
    | 103 | +        Push to a stack of applied transforms of ``data``.
 91 | 104 |
 92 |     | -        Returns:
 93 |     | -            Dictionary of data pertaining to the applied transformation.
    | 105 | +        Args:
    | 106 | +            data: dictionary of data or `MetaTensor`.
    | 107 | +            args: additional positional arguments to track_transform_meta.
    | 108 | +            kwargs: additional keyword arguments to track_transform_meta,
    | 109 | +                set ``replace=True`` (default False) to rewrite the last transform info in
    | 110 | +                applied_operations/pending_operations based on ``self.get_transform_info()``.
 94 | 111 |         """
 95 |     | -        info = {TraceKeys.CLASS_NAME: self.__class__.__name__, TraceKeys.ID: id(self)}
 96 |     | -        if orig_size is not None:
 97 |     | -            info[TraceKeys.ORIG_SIZE] = orig_size
 98 |     | -        elif isinstance(data, Mapping) and key in data and hasattr(data[key], "shape"):
 99 |     | -            info[TraceKeys.ORIG_SIZE] = data[key].shape[1:]
100 |     | -        elif hasattr(data, "shape"):
101 |     | -            info[TraceKeys.ORIG_SIZE] = data.shape[1:]
102 |     | -        if extra_info is not None:
103 |     | -            info[TraceKeys.EXTRA_INFO] = extra_info
104 |     | -        # If class is randomizable transform, store whether the transform was actually performed (based on `prob`)
105 |     | -        if hasattr(self, "_do_transform"):  # RandomizableTransform
106 |     | -            info[TraceKeys.DO_TRANSFORM] = self._do_transform
107 |     | -        return info
108 |     | -
109 |     | -    def push_transform(
110 |     | -        self, data, key: Hashable = None, extra_info: dict | None = None, orig_size: tuple | None = None
111 |     | -    ) -> None:
    | 112 | +        transform_info = self.get_transform_info()
    | 113 | +        lazy_eval = transform_info.get(TraceKeys.LAZY_EVALUATION, False)
    | 114 | +        do_transform = transform_info.get(TraceKeys.DO_TRANSFORM, True)
    | 115 | +        kwargs = kwargs or {}
    | 116 | +        replace = kwargs.pop("replace", False)  # whether to rewrite the most recently pushed transform info
    | 117 | +        if replace and get_track_meta() and isinstance(data, MetaTensor):
    | 118 | +            if not lazy_eval:
    | 119 | +                xform = self.pop_transform(data, check=False) if do_transform else {}
    | 120 | +                meta_obj = self.push_transform(data, orig_size=xform.get(TraceKeys.ORIG_SIZE), extra_info=xform)
    | 121 | +                return data.copy_meta_from(meta_obj)
    | 122 | +            if do_transform:
    | 123 | +                xform = data.pending_operations.pop()  # type: ignore
    | 124 | +                xform.update(transform_info)
    | 125 | +                meta_obj = self.push_transform(data, transform_info=xform, lazy_evaluation=lazy_eval)
    | 126 | +                return data.copy_meta_from(meta_obj)
    | 127 | +            return data
    | 128 | +        kwargs["lazy_evaluation"] = lazy_eval
    | 129 | +        if "transform_info" in kwargs and isinstance(kwargs["transform_info"], dict):
    | 130 | +            kwargs["transform_info"].update(transform_info)
    | 131 | +        else:
    | 132 | +            kwargs["transform_info"] = transform_info
    | 133 | +        meta_obj = TraceableTransform.track_transform_meta(data, *args, **kwargs)
    | 134 | +        return data.copy_meta_from(meta_obj) if isinstance(data, MetaTensor) else data
    | 135 | +
    | 136 | +    @classmethod
    | 137 | +    def track_transform_meta(
    | 138 | +        cls,
    | 139 | +        data,
    | 140 | +        key: Hashable = None,
    | 141 | +        sp_size=None,
    | 142 | +        affine=None,
    | 143 | +        extra_info: dict | None = None,
    | 144 | +        orig_size: tuple | None = None,
    | 145 | +        transform_info=None,
    | 146 | +        lazy_evaluation=False,
    | 147 | +    ):
112 | 148 |         """
113 |     | -        Push to a stack of applied transforms.
    | 149 | +        Update a stack of applied/pending transforms metadata of ``data``.
114 | 150 |
115 | 151 |         Args:
116 | 152 |             data: dictionary of data or `MetaTensor`.
117 | 153 |             key: if data is a dictionary, data[key] will be modified.
    | 154 | +            sp_size: the expected output spatial size when the transform is applied.
    | 155 | +                It can be a tensor or a numpy array, but it will be converted to a list of integers.
    | 156 | +            affine: the affine representation of the (spatial) transform in the image space.
    | 157 | +                When the transform is applied, ``meta_tensor.affine`` will be updated to ``meta_tensor.affine @ affine``.
118 | 158 |             extra_info: if desired, any extra information pertaining to the applied
119 | 159 |                 transform can be stored in this dictionary. These are often needed for
120 | 160 |                 computing the inverse transformation.
121 | 161 |             orig_size: sometimes during the inverse it is useful to know what the size
122 | 162 |                 of the original image was, in which case it can be supplied here.
    | 163 | +            transform_info: info from self.get_transform_info().
    | 164 | +            lazy_evaluation: whether to push the transform to pending_operations or applied_operations.
123 | 165 |
124 | 166 |         Returns:
125 |     | -            None, but data has been updated to store the applied transformation.
    | 167 | +
    | 168 | +            For backward compatibility, if ``data`` is a dictionary, it returns the dictionary with
    | 169 | +            updated ``data[key]``. Otherwise, this function returns a MetaObj with updated transform metadata.
126 | 170 |         """
127 |     | -        if not self.tracing:
128 |     | -            return
129 |     | -        info = self.get_transform_info(data, key, extra_info, orig_size)
    | 171 | +        data_t = data[key] if key is not None else data  # compatible with the dict data representation
    | 172 | +        out_obj = MetaObj()
    | 173 | +        # after deprecating metadict, we should always convert data_t to metatensor here
    | 174 | +        if isinstance(data_t, MetaTensor):
    | 175 | +            out_obj.copy_meta_from(data_t, keys=out_obj.__dict__.keys())
    | 176 | +
    | 177 | +        if not lazy_evaluation and affine is not None and isinstance(data_t, MetaTensor):
    | 178 | +            # not lazy evaluation, directly update the metatensor affine (don't push to the stack)
    | 179 | +            orig_affine = data_t.peek_pending_affine()
    | 180 | +            orig_affine = convert_to_dst_type(orig_affine, affine)[0]
    | 181 | +            affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=affine.dtype)
    | 182 | +            out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"))
    | 183 | +
    | 184 | +        if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)):
    | 185 | +            if isinstance(data, Mapping):
    | 186 | +                if not isinstance(data, dict):
    | 187 | +                    data = dict(data)
    | 188 | +                data[key] = data_t.copy_meta_from(out_obj) if isinstance(data_t, MetaTensor) else data_t
    | 189 | +                return data
    | 190 | +            return out_obj  # return with data_t as tensor if get_track_meta() is False
    | 191 | +
    | 192 | +        info = transform_info
    | 193 | +        # track the current spatial shape
    | 194 | +        if orig_size is not None:
    | 195 | +            info[TraceKeys.ORIG_SIZE] = orig_size
    | 196 | +        elif isinstance(data_t, MetaTensor):
    | 197 | +            info[TraceKeys.ORIG_SIZE] = data_t.peek_pending_shape()
    | 198 | +        elif hasattr(data_t, "shape"):
    | 199 | +            info[TraceKeys.ORIG_SIZE] = data_t.shape[1:]
    | 200 | +        # include extra_info
    | 201 | +        if extra_info is not None:
    | 202 | +            info[TraceKeys.EXTRA_INFO] = extra_info
130 | 203 |
131 |     | -        if isinstance(data, MetaTensor):
132 |     | -            data.push_applied_operation(info)
133 |     | -        elif isinstance(data, Mapping):
134 |     | -            if key in data and isinstance(data[key], MetaTensor):
135 |     | -                data[key].push_applied_operation(info)
    | 204 | +        # push the transform info to the applied_operation or pending_operation stack
    | 205 | +        if lazy_evaluation:
    | 206 | +            if sp_size is None:
    | 207 | +                if LazyAttr.SHAPE not in info:
    | 208 | +                    warnings.warn("spatial size is None in push transform.")
    | 209 | +            else:
    | 210 | +                info[LazyAttr.SHAPE] = tuple(convert_to_numpy(sp_size, wrap_sequence=True).tolist())
    | 211 | +            if affine is None:
    | 212 | +                if LazyAttr.AFFINE not in info:
    | 213 | +                    warnings.warn("affine is None in push transform.")
136 | 214 |             else:
137 |     | -                # If this is the first, create list
138 |     | -                if self.trace_key(key) not in data:
139 |     | -                    if not isinstance(data, dict):
140 |     | -                        data = dict(data)
141 |     | -                    data[self.trace_key(key)] = []
142 |     | -                data[self.trace_key(key)].append(info)
    | 215 | +                info[LazyAttr.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"))
    | 216 | +            out_obj.push_pending_operation(info)
143 | 217 |         else:
144 |     | -            warnings.warn(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}. {info} not tracked.")
    | 218 | +            out_obj.push_applied_operation(info)
    | 219 | +        if isinstance(data, Mapping):
    | 220 | +            if not isinstance(data, dict):
    | 221 | +                data = dict(data)
    | 222 | +            if isinstance(data_t, MetaTensor):
    | 223 | +                data[key] = data_t.copy_meta_from(out_obj)
    | 224 | +            else:
    | 225 | +                x_k = TraceableTransform.trace_key(key)
    | 226 | +                if x_k not in data:
    | 227 | +                    data[x_k] = []  # If this is the first, create list
    | 228 | +                data[x_k].append(info)
    | 229 | +            return data
    | 230 | +        return out_obj
145 | 231 |
146 | 232 |     def check_transforms_match(self, transform: Mapping) -> None:
147 | 233 |         """Check transforms are of same instance."""
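
To illustrate how the pieces above fit together, here is a minimal, hedged usage sketch of the non-lazy path. `FlipLastAxis` and its `extra_info` payload are hypothetical, and the module path `monai.transforms.inverse` is inferred from the `__all__` shown in the diff; the sketch simply assumes `push_transform`/`track_transform_meta` behave as written above.

```python
import torch

from monai.data.meta_tensor import MetaTensor
from monai.transforms.inverse import TraceableTransform  # path assumed from the diff
from monai.utils import TraceKeys


class FlipLastAxis(TraceableTransform):
    """Hypothetical non-lazy transform: flip the last axis and record the operation."""

    def __call__(self, data: MetaTensor) -> MetaTensor:
        out = torch.flip(data, dims=[-1])
        # push_transform collects self.get_transform_info() and forwards the rest to
        # track_transform_meta; extra_info is kept for computing the inverse later
        return self.push_transform(out, extra_info={"dims": [-1]})


img = MetaTensor(torch.rand(1, 4, 4))
out = FlipLastAxis()(img)
# the non-lazy path pushes onto the applied_operations stack of the MetaTensor
print(out.applied_operations[-1][TraceKeys.CLASS_NAME])  # FlipLastAxis
print(out.applied_operations[-1][TraceKeys.EXTRA_INFO])  # {'dims': [-1]}
```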
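The lazy path can be sketched in isolation as well: with `lazy_evaluation=True`, `track_transform_meta` queues the expected output shape and affine on `pending_operations` instead of touching the data. The transform name, shapes, and hand-built `transform_info` below are illustrative only; a real transform would pass `transform_info=self.get_transform_info()` and let `push_transform` copy the returned `MetaObj` back onto the tensor.

```python
import torch

from monai.data.meta_tensor import MetaTensor
from monai.transforms.inverse import TraceableTransform  # path assumed from the diff
from monai.utils import LazyAttr, TraceKeys

img = MetaTensor(torch.rand(1, 4, 4))
# a pure metadata update: nothing is resampled, the op is queued as pending
meta_obj = TraceableTransform.track_transform_meta(
    img,
    sp_size=(8, 8),  # expected output spatial size (illustrative)
    affine=torch.eye(3),  # spatial transform in image space (illustrative)
    transform_info={TraceKeys.CLASS_NAME: "PendingZoom", TraceKeys.TRACING: True},
    lazy_evaluation=True,
)
# push_transform would normally write this MetaObj back via data.copy_meta_from(meta_obj)
pending = meta_obj.pending_operations[-1]
print(pending[LazyAttr.SHAPE])  # (8, 8)
print(pending[LazyAttr.AFFINE].shape)  # torch.Size([3, 3])
```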