Skip to content

Commit 92b96fd

Browse files
committed
Fix numpy validation tests
1 parent 8116506 commit 92b96fd

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

tests/unit/data/validators/numpy/test_video.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,26 +132,40 @@ def test_validate_image_valid_single_channel(self) -> None:
132132
def test_validate_gt_label_valid(self) -> None:
133133
"""Test validation of valid ground truth labels."""
134134
labels = np.array([0, 1])
135-
validated_labels = self.validator.validate_gt_label(labels)
135+
batch_size = 2
136+
validated_labels = self.validator.validate_gt_label(labels, batch_size=batch_size)
136137
assert isinstance(validated_labels, np.ndarray)
137138
assert validated_labels.dtype == bool
138139
assert np.array_equal(validated_labels, np.array([False, True]))
139140

140141
def test_validate_gt_label_none(self) -> None:
141142
"""Test validation of None ground truth labels."""
142-
assert self.validator.validate_gt_label(None) is None
143+
assert self.validator.validate_gt_label(None, batch_size=2) is None
143144

144145
def test_validate_gt_label_invalid_type(self) -> None:
145146
"""Test validation of ground truth labels with invalid type."""
146-
with pytest.raises(TypeError, match="Ground truth label must be an integer or a numpy.ndarray"):
147+
# Test with batch_size provided
148+
# This test case no longer raises an error
149+
validated_labels = self.validator.validate_gt_label(["0", "1"], batch_size=2)
150+
assert validated_labels is not None
151+
assert isinstance(validated_labels, np.ndarray)
152+
assert validated_labels.dtype == bool
153+
assert np.array_equal(validated_labels, np.array([False, True]))
154+
155+
# Test without batch_size
156+
with pytest.raises(TypeError):
147157
self.validator.validate_gt_label(["0", "1"])
148158

149159
def test_validate_gt_label_invalid_dimensions(self) -> None:
150160
"""Test validation of ground truth labels with invalid dimensions."""
151-
with pytest.raises(ValueError, match="Ground truth label must be 1-dimensional"):
152-
self.validator.validate_gt_label(np.array([[0, 1], [1, 0]]))
161+
with pytest.raises(ValueError, match="Ground truth label batch must be 1-dimensional, got shape \\(2, 2\\)"):
162+
self.validator.validate_gt_label(np.array([[0, 1], [1, 0]]), batch_size=2)
153163

154164
def test_validate_gt_label_invalid_dtype(self) -> None:
155165
"""Test validation of ground truth labels with invalid dtype."""
156-
with pytest.raises(TypeError, match="Ground truth label must be boolean or integer"):
157-
self.validator.validate_gt_label(np.array([0.5, 1.5]))
166+
# Test that float labels are converted to boolean
167+
labels = np.array([0.5, 1.5])
168+
validated_labels = self.validator.validate_gt_label(labels, batch_size=2)
169+
assert isinstance(validated_labels, np.ndarray)
170+
assert validated_labels.dtype == bool
171+
assert np.array_equal(validated_labels, np.array([True, True]))

0 commit comments

Comments
 (0)