Skip to content

Commit 2cb1efc

Browse files
authored
2432 Add RandCropByLabelClasses transform (#2557)
* [DLMED] add map_classes_to_indices Signed-off-by: Nic Ma <[email protected]> * [DLMED] add tests for empty classes Signed-off-by: Nic Ma <[email protected]> * [DLMED] add all the transforms Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix typo Signed-off-by: Nic Ma <[email protected]> * [DLMED] add tests for ClassesToIndices transforms Signed-off-by: Nic Ma <[email protected]> * [DLMED] add unit tests and documents Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix flake8 and add inverse tests Signed-off-by: Nic Ma <[email protected]> * [DLMED] update according to comments Signed-off-by: Nic Ma <[email protected]> * [DLMED] optimize random choice Signed-off-by: Nic Ma <[email protected]> * [DLMED] add check for num_samples Signed-off-by: Nic Ma <[email protected]> * [DLMED] add default value to ratios Signed-off-by: Nic Ma <[email protected]>
1 parent cbc8941 commit 2cb1efc

14 files changed

+1101
-37
lines changed

docs/source/transforms.rst

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@ Crop and Pad
120120
:members:
121121
:special-members: __call__
122122

123+
`RandCropByLabelClasses`
124+
""""""""""""""""""""""""
125+
.. autoclass:: RandCropByLabelClasses
126+
:members:
127+
:special-members: __call__
128+
123129
`ResizeWithPadOrCrop`
124130
"""""""""""""""""""""
125131
.. autoclass:: ResizeWithPadOrCrop
@@ -604,6 +610,12 @@ Utility
604610
:members:
605611
:special-members: __call__
606612

613+
`ClassesToIndices`
614+
""""""""""""""""""
615+
.. autoclass:: ClassesToIndices
616+
:members:
617+
:special-members: __call__
618+
607619
`ConvertToMultiChannelBasedOnBratsClasses`
608620
""""""""""""""""""""""""""""""""""""""""""
609621
.. autoclass:: ConvertToMultiChannelBasedOnBratsClasses
@@ -700,6 +712,12 @@ Crop and Pad (Dict)
700712
:members:
701713
:special-members: __call__
702714

715+
`RandCropByLabelClassesd`
716+
"""""""""""""""""""""""""
717+
.. autoclass:: RandCropByLabelClassesd
718+
:members:
719+
:special-members: __call__
720+
703721
`ResizeWithPadOrCropd`
704722
""""""""""""""""""""""
705723
.. autoclass:: ResizeWithPadOrCropd
@@ -1183,6 +1201,12 @@ Utility (Dict)
11831201
:members:
11841202
:special-members: __call__
11851203

1204+
`ClassesToIndicesd`
1205+
"""""""""""""""""""
1206+
.. autoclass:: ClassesToIndicesd
1207+
:members:
1208+
:special-members: __call__
1209+
11861210
`ConvertToMultiChannelBasedOnBratsClassesd`
11871211
"""""""""""""""""""""""""""""""""""""""""""
11881212
.. autoclass:: ConvertToMultiChannelBasedOnBratsClassesd

monai/transforms/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
CenterSpatialCrop,
1919
CropForeground,
2020
DivisiblePad,
21+
RandCropByLabelClasses,
2122
RandCropByPosNegLabel,
2223
RandScaleCrop,
2324
RandSpatialCrop,
@@ -48,6 +49,9 @@
4849
DivisiblePadD,
4950
DivisiblePadDict,
5051
NumpyPadModeSequence,
52+
RandCropByLabelClassesd,
53+
RandCropByLabelClassesD,
54+
RandCropByLabelClassesDict,
5155
RandCropByPosNegLabeld,
5256
RandCropByPosNegLabelD,
5357
RandCropByPosNegLabelDict,
@@ -305,6 +309,7 @@
305309
AsChannelFirst,
306310
AsChannelLast,
307311
CastToType,
312+
ClassesToIndices,
308313
ConvertToMultiChannelBasedOnBratsClasses,
309314
DataStats,
310315
EnsureChannelFirst,
@@ -342,6 +347,9 @@
342347
CastToTyped,
343348
CastToTypeD,
344349
CastToTypeDict,
350+
ClassesToIndicesd,
351+
ClassesToIndicesD,
352+
ClassesToIndicesDict,
345353
ConcatItemsd,
346354
ConcatItemsD,
347355
ConcatItemsDict,
@@ -435,6 +443,7 @@
435443
create_shear,
436444
create_translate,
437445
extreme_points_to_image,
446+
generate_label_classes_crop_centers,
438447
generate_pos_neg_label_crop_centers,
439448
generate_spatial_bounding_box,
440449
get_extreme_points,
@@ -444,6 +453,7 @@
444453
is_empty,
445454
is_positive,
446455
map_binary_to_indices,
456+
map_classes_to_indices,
447457
map_spatial_axes,
448458
rand_choice,
449459
rescale_array,

monai/transforms/croppad/array.py

Lines changed: 142 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525
from monai.transforms.transform import Randomizable, Transform
2626
from monai.transforms.utils import (
2727
compute_divisible_spatial_size,
28+
generate_label_classes_crop_centers,
2829
generate_pos_neg_label_crop_centers,
2930
generate_spatial_bounding_box,
3031
is_positive,
3132
map_binary_to_indices,
33+
map_classes_to_indices,
3234
weighted_patch_samples,
3335
)
3436
from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple
@@ -46,6 +48,7 @@
4648
"CropForeground",
4749
"RandWeightedCrop",
4850
"RandCropByPosNegLabel",
51+
"RandCropByLabelClasses",
4952
"ResizeWithPadOrCrop",
5053
"BoundingRect",
5154
]
@@ -766,7 +769,11 @@ def randomize(
766769
) -> None:
767770
self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
768771
if fg_indices is None or bg_indices is None:
769-
fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold)
772+
if self.fg_indices is not None and self.bg_indices is not None:
773+
fg_indices_ = self.fg_indices
774+
bg_indices_ = self.bg_indices
775+
else:
776+
fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold)
770777
else:
771778
fg_indices_ = fg_indices
772779
bg_indices_ = bg_indices
@@ -802,12 +809,7 @@ def __call__(
802809
raise ValueError("label should be provided.")
803810
if image is None:
804811
image = self.image
805-
if fg_indices is None or bg_indices is None:
806-
if self.fg_indices is not None and self.bg_indices is not None:
807-
fg_indices = self.fg_indices
808-
bg_indices = self.bg_indices
809-
else:
810-
fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold)
812+
811813
self.randomize(label, fg_indices, bg_indices, image)
812814
results: List[np.ndarray] = []
813815
if self.centers is not None:
@@ -818,6 +820,139 @@ def __call__(
818820
return results
819821

820822

823+
class RandCropByLabelClasses(Randomizable, Transform):
    """
    Crop random fixed sized regions with the center being a class based on the specified ratios of every class.
    The label data can be a One-Hot format array or Argmax data. Returns a list of arrays for all the
    cropped images. For example, crop two (3 x 3) arrays from (5 x 5) array with `ratios=[1, 2, 3, 1]`::

        image = np.array([
            [[0.0, 0.3, 0.4, 0.2, 0.0],
            [0.0, 0.1, 0.2, 0.1, 0.4],
            [0.0, 0.3, 0.5, 0.2, 0.0],
            [0.1, 0.2, 0.1, 0.1, 0.0],
            [0.0, 0.1, 0.2, 0.1, 0.0]]
        ])
        label = np.array([
            [[0, 0, 0, 0, 0],
            [0, 1, 2, 1, 0],
            [0, 1, 3, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0]]
        ])
        cropper = RandCropByLabelClasses(
            spatial_size=[3, 3],
            ratios=[1, 2, 3, 1],
            num_classes=4,
            num_samples=2,
        )
        label_samples = cropper(img=label, label=label, image=image)

        The 2 randomly cropped samples of `label` can be:
        [[0, 1, 2],     [[0, 0, 0],
         [0, 1, 3],      [1, 2, 1],
         [0, 0, 0]]      [1, 3, 0]]

    If a dimension of the expected spatial size is bigger than the input image size,
    that dimension will not be cropped. So the cropped result may be smaller than the expected size, and the
    cropped results of several images may not have exactly the same shape.

    Args:
        spatial_size: the spatial size of the crop region e.g. [224, 224, 128].
            if a dimension of ROI size is bigger than image size, will not crop that dimension of the image.
            if its components have non-positive values, the corresponding size of `label` will be used.
            for example: if the spatial size of input data is [40, 40, 40] and `spatial_size=[32, 64, -1]`,
            the spatial size of output data will be [32, 40, 40].
            The configured value is kept as given; it is resolved against the current `label` on every call.
        ratios: specified ratios of every class in the label to generate crop centers, including background class.
            if None, every class will have the same ratio to generate crop centers.
        label: the label image that is used for finding every class, if None, must be set at `self.__call__`.
        num_classes: number of classes for argmax label, not necessary for One-Hot label.
        num_samples: number of samples (crop regions) to take in each list.
        image: if image is not None, only return the indices of every class that are within the valid
            region of the image (``image > image_threshold``).
        image_threshold: if enabled `image`, use ``image > image_threshold`` to
            determine the valid image content area and select class indices only in this area.
        indices: if provided pre-computed indices of every class, will ignore above `image` and
            `image_threshold`, and randomly select crop centers based on them, expect to be 1 dim array
            of spatial indices after flattening. a typical usage is to call `ClassesToIndices` transform first
            and cache the results for better performance.

    """

    def __init__(
        self,
        spatial_size: Union[Sequence[int], int],
        ratios: Optional[List[Union[float, int]]] = None,
        label: Optional[np.ndarray] = None,
        num_classes: Optional[int] = None,
        num_samples: int = 1,
        image: Optional[np.ndarray] = None,
        image_threshold: float = 0.0,
        indices: Optional[List[np.ndarray]] = None,
    ) -> None:
        # Keep the user-configured size untouched so that non-positive components
        # can be resolved against each new label in `randomize` / `__call__`.
        self.spatial_size = ensure_tuple(spatial_size)
        self.ratios = ratios
        self.label = label
        self.num_classes = num_classes
        self.num_samples = num_samples
        self.image = image
        self.image_threshold = image_threshold
        # Crop centers produced by the most recent `randomize` call.
        self.centers: Optional[List[List[np.ndarray]]] = None
        self.indices = indices

    def randomize(
        self,
        label: np.ndarray,
        indices: Optional[List[np.ndarray]] = None,
        image: Optional[np.ndarray] = None,
    ) -> None:
        """
        Generate the crop centers for `label` and store them in `self.centers`.

        Args:
            label: label array, its spatial shape is taken from ``label.shape[1:]``.
            indices: per-class indices to sample centers from. If None, `self.indices` is used when set,
                otherwise the indices are computed from `label` (and `image`) with `map_classes_to_indices`.
            image: optional image passed to `map_classes_to_indices` together with `self.image_threshold`.

        """
        # Resolve into a local instead of overwriting `self.spatial_size`: writing the
        # resolved size back would freeze non-positive components to the first label's
        # shape and silently reuse them for later inputs of a different size.
        roi_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])

        indices_: List[np.ndarray]
        if indices is not None:
            indices_ = indices
        elif self.indices is not None:
            indices_ = self.indices
        else:
            indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold)

        self.centers = generate_label_classes_crop_centers(
            roi_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R
        )

    def __call__(
        self,
        img: np.ndarray,
        label: Optional[np.ndarray] = None,
        image: Optional[np.ndarray] = None,
        indices: Optional[List[np.ndarray]] = None,
    ) -> List[np.ndarray]:
        """
        Args:
            img: input data to crop samples from based on the ratios of every class, assumes `img` is a
                channel-first array.
            label: the label image that is used for finding indices of every class, if None, use `self.label`.
            image: optional image data to help select valid area, can be same as `img` or another image array.
                use ``image > image_threshold`` to select the centers only in valid region. if None, use `self.image`.
            indices: list of indices for every class in the image, used to randomly select crop centers.

        Returns:
            A list with one cropped array per generated center.

        Raises:
            ValueError: When neither `label` nor `self.label` is provided.

        """
        if label is None:
            label = self.label
        if label is None:
            raise ValueError("label should be provided.")
        if image is None:
            image = self.image

        self.randomize(label, indices, image)
        # Same per-call resolution as in `randomize`, so the crop size always matches
        # the size the centers were generated for.
        roi_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])

        results: List[np.ndarray] = []
        if self.centers is not None:
            for center in self.centers:
                cropper = SpatialCrop(roi_center=tuple(center), roi_size=roi_size)  # type: ignore
                results.append(cropper(img))

        return results
955+
821956
class ResizeWithPadOrCrop(Transform):
822957
"""
823958
Resize an image to a target spatial size by either centrally cropping the image or

0 commit comments

Comments
 (0)