25
25
from monai .transforms .transform import Randomizable , Transform
26
26
from monai .transforms .utils import (
27
27
compute_divisible_spatial_size ,
28
+ generate_label_classes_crop_centers ,
28
29
generate_pos_neg_label_crop_centers ,
29
30
generate_spatial_bounding_box ,
30
31
is_positive ,
31
32
map_binary_to_indices ,
33
+ map_classes_to_indices ,
32
34
weighted_patch_samples ,
33
35
)
34
36
from monai .utils import Method , NumpyPadMode , ensure_tuple , ensure_tuple_rep , fall_back_tuple
46
48
"CropForeground" ,
47
49
"RandWeightedCrop" ,
48
50
"RandCropByPosNegLabel" ,
51
+ "RandCropByLabelClasses" ,
49
52
"ResizeWithPadOrCrop" ,
50
53
"BoundingRect" ,
51
54
]
@@ -766,7 +769,11 @@ def randomize(
766
769
) -> None :
767
770
self .spatial_size = fall_back_tuple (self .spatial_size , default = label .shape [1 :])
768
771
if fg_indices is None or bg_indices is None :
769
- fg_indices_ , bg_indices_ = map_binary_to_indices (label , image , self .image_threshold )
772
+ if self .fg_indices is not None and self .bg_indices is not None :
773
+ fg_indices_ = self .fg_indices
774
+ bg_indices_ = self .bg_indices
775
+ else :
776
+ fg_indices_ , bg_indices_ = map_binary_to_indices (label , image , self .image_threshold )
770
777
else :
771
778
fg_indices_ = fg_indices
772
779
bg_indices_ = bg_indices
@@ -802,12 +809,7 @@ def __call__(
802
809
raise ValueError ("label should be provided." )
803
810
if image is None :
804
811
image = self .image
805
- if fg_indices is None or bg_indices is None :
806
- if self .fg_indices is not None and self .bg_indices is not None :
807
- fg_indices = self .fg_indices
808
- bg_indices = self .bg_indices
809
- else :
810
- fg_indices , bg_indices = map_binary_to_indices (label , image , self .image_threshold )
812
+
811
813
self .randomize (label , fg_indices , bg_indices , image )
812
814
results : List [np .ndarray ] = []
813
815
if self .centers is not None :
@@ -818,6 +820,139 @@ def __call__(
818
820
return results
819
821
820
822
823
+ class RandCropByLabelClasses (Randomizable , Transform ):
824
+ """
825
+ Crop random fixed sized regions with the center being a class based on the specified ratios of every class.
826
+ The label data can be One-Hot format array or Argmax data. And will return a list of arrays for all the
827
+ cropped images. For example, crop two (3 x 3) arrays from (5 x 5) array with `ratios=[1, 2, 3, 1]`::
828
+
829
+ image = np.array([
830
+ [[0.0, 0.3, 0.4, 0.2, 0.0],
831
+ [0.0, 0.1, 0.2, 0.1, 0.4],
832
+ [0.0, 0.3, 0.5, 0.2, 0.0],
833
+ [0.1, 0.2, 0.1, 0.1, 0.0],
834
+ [0.0, 0.1, 0.2, 0.1, 0.0]]
835
+ ])
836
+ label = np.array([
837
+ [[0, 0, 0, 0, 0],
838
+ [0, 1, 2, 1, 0],
839
+ [0, 1, 3, 0, 0],
840
+ [0, 0, 0, 0, 0],
841
+ [0, 0, 0, 0, 0]]
842
+ ])
843
+ cropper = RandCropByLabelClasses(
844
+ spatial_size=[3, 3],
845
+ ratios=[1, 2, 3, 1],
846
+ num_classes=4,
847
+ num_samples=2,
848
+ )
849
+ label_samples = cropper(img=label, label=label, image=image)
850
+
851
+ The 2 randomly cropped samples of `label` can be:
852
+ [[0, 1, 2], [[0, 0, 0],
853
+ [0, 1, 3], [1, 2, 1],
854
+ [0, 0, 0]] [1, 3, 0]]
855
+
856
+ If a dimension of the expected spatial size is bigger than the input image size,
857
+ will not crop that dimension. So the cropped result may be smaller than expected size, and the cropped
858
+ results of several images may not have exactly same shape.
859
+
860
+ Args:
861
+ spatial_size: the spatial size of the crop region e.g. [224, 224, 128].
862
+ if a dimension of ROI size is bigger than image size, will not crop that dimension of the image.
863
+ if its components have non-positive values, the corresponding size of `label` will be used.
864
+ for example: if the spatial size of input data is [40, 40, 40] and `spatial_size=[32, 64, -1]`,
865
+ the spatial size of output data will be [32, 40, 40].
866
+ ratios: specified ratios of every class in the label to generate crop centers, including background class.
867
+ if None, every class will have the same ratio to generate crop centers.
868
+ label: the label image that is used for finding every classes, if None, must set at `self.__call__`.
869
+ num_classes: number of classes for argmax label, not necessary for One-Hot label.
870
+ num_samples: number of samples (crop regions) to take in each list.
871
+ image: if image is not None, only return the indices of every class that are within the valid
872
+ region of the image (``image > image_threshold``).
873
+ image_threshold: if enabled `image`, use ``image > image_threshold`` to
874
+ determine the valid image content area and select class indices only in this area.
875
+ indices: if provided pre-computed indices of every class, will ignore above `image` and
876
+ `image_threshold`, and randomly select crop centers based on them, expect to be 1 dim array
877
+ of spatial indices after flattening. a typical usage is to call `ClassesToIndices` transform first
878
+ and cache the results for better performance.
879
+
880
+ """
881
+
882
+ def __init__ (
883
+ self ,
884
+ spatial_size : Union [Sequence [int ], int ],
885
+ ratios : Optional [List [Union [float , int ]]] = None ,
886
+ label : Optional [np .ndarray ] = None ,
887
+ num_classes : Optional [int ] = None ,
888
+ num_samples : int = 1 ,
889
+ image : Optional [np .ndarray ] = None ,
890
+ image_threshold : float = 0.0 ,
891
+ indices : Optional [List [np .ndarray ]] = None ,
892
+ ) -> None :
893
+ self .spatial_size = ensure_tuple (spatial_size )
894
+ self .ratios = ratios
895
+ self .label = label
896
+ self .num_classes = num_classes
897
+ self .num_samples = num_samples
898
+ self .image = image
899
+ self .image_threshold = image_threshold
900
+ self .centers : Optional [List [List [np .ndarray ]]] = None
901
+ self .indices = indices
902
+
903
+ def randomize (
904
+ self ,
905
+ label : np .ndarray ,
906
+ indices : Optional [List [np .ndarray ]] = None ,
907
+ image : Optional [np .ndarray ] = None ,
908
+ ) -> None :
909
+ self .spatial_size = fall_back_tuple (self .spatial_size , default = label .shape [1 :])
910
+ indices_ : List [np .ndarray ]
911
+ if indices is None :
912
+ if self .indices is not None :
913
+ indices_ = self .indices
914
+ else :
915
+ indices_ = map_classes_to_indices (label , self .num_classes , image , self .image_threshold )
916
+ else :
917
+ indices_ = indices
918
+ self .centers = generate_label_classes_crop_centers (
919
+ self .spatial_size , self .num_samples , label .shape [1 :], indices_ , self .ratios , self .R
920
+ )
921
+
922
+ def __call__ (
923
+ self ,
924
+ img : np .ndarray ,
925
+ label : Optional [np .ndarray ] = None ,
926
+ image : Optional [np .ndarray ] = None ,
927
+ indices : Optional [List [np .ndarray ]] = None ,
928
+ ) -> List [np .ndarray ]:
929
+ """
930
+ Args:
931
+ img: input data to crop samples from based on the ratios of every class, assumes `img` is a
932
+ channel-first array.
933
+ label: the label image that is used for finding indices of every class, if None, use `self.label`.
934
+ image: optional image data to help select valid area, can be same as `img` or another image array.
935
+ use ``image > image_threshold`` to select the centers only in valid region. if None, use `self.image`.
936
+ indices: list of indices for every class in the image, used to randomly select crop centers.
937
+
938
+ """
939
+ if label is None :
940
+ label = self .label
941
+ if label is None :
942
+ raise ValueError ("label should be provided." )
943
+ if image is None :
944
+ image = self .image
945
+
946
+ self .randomize (label , indices , image )
947
+ results : List [np .ndarray ] = []
948
+ if self .centers is not None :
949
+ for center in self .centers :
950
+ cropper = SpatialCrop (roi_center = tuple (center ), roi_size = self .spatial_size ) # type: ignore
951
+ results .append (cropper (img ))
952
+
953
+ return results
954
+
955
+
821
956
class ResizeWithPadOrCrop (Transform ):
822
957
"""
823
958
Resize an image to a target spatial size by either centrally cropping the image or
0 commit comments