@@ -132,26 +132,40 @@ def test_validate_image_valid_single_channel(self) -> None:
132132 def test_validate_gt_label_valid (self ) -> None :
133133 """Test validation of valid ground truth labels."""
134134 labels = np .array ([0 , 1 ])
135- validated_labels = self .validator .validate_gt_label (labels )
135+ batch_size = 2
136+ validated_labels = self .validator .validate_gt_label (labels , batch_size = batch_size )
136137 assert isinstance (validated_labels , np .ndarray )
137138 assert validated_labels .dtype == bool
138139 assert np .array_equal (validated_labels , np .array ([False , True ]))
139140
140141 def test_validate_gt_label_none (self ) -> None :
141142 """Test validation of None ground truth labels."""
142- assert self .validator .validate_gt_label (None ) is None
143+ assert self .validator .validate_gt_label (None , batch_size = 2 ) is None
143144
144145 def test_validate_gt_label_invalid_type (self ) -> None :
145146 """Test validation of ground truth labels with invalid type."""
146- with pytest .raises (TypeError , match = "Ground truth label must be an integer or a numpy.ndarray" ):
147+ # Test with batch_size provided
148+ # This test case no longer raises an error
149+ validated_labels = self .validator .validate_gt_label (["0" , "1" ], batch_size = 2 )
150+ assert validated_labels is not None
151+ assert isinstance (validated_labels , np .ndarray )
152+ assert validated_labels .dtype == bool
153+ assert np .array_equal (validated_labels , np .array ([False , True ]))
154+
155+ # Test without batch_size
156+ with pytest .raises (TypeError ):
147157 self .validator .validate_gt_label (["0" , "1" ])
148158
149159 def test_validate_gt_label_invalid_dimensions (self ) -> None :
150160 """Test validation of ground truth labels with invalid dimensions."""
151- with pytest .raises (ValueError , match = "Ground truth label must be 1-dimensional" ):
152- self .validator .validate_gt_label (np .array ([[0 , 1 ], [1 , 0 ]]))
161+ with pytest .raises (ValueError , match = "Ground truth label batch must be 1-dimensional, got shape \\ (2, 2 \\ ) " ):
162+ self .validator .validate_gt_label (np .array ([[0 , 1 ], [1 , 0 ]]), batch_size = 2 )
153163
154164 def test_validate_gt_label_invalid_dtype (self ) -> None :
155165 """Test validation of ground truth labels with invalid dtype."""
156- with pytest .raises (TypeError , match = "Ground truth label must be boolean or integer" ):
157- self .validator .validate_gt_label (np .array ([0.5 , 1.5 ]))
166+ # Test that float labels are converted to boolean
167+ labels = np .array ([0.5 , 1.5 ])
168+ validated_labels = self .validator .validate_gt_label (labels , batch_size = 2 )
169+ assert isinstance (validated_labels , np .ndarray )
170+ assert validated_labels .dtype == bool
171+ assert np .array_equal (validated_labels , np .array ([True , True ]))
0 commit comments