Skip to content

Commit 8116506

Browse files
committed
Add torch validator tests
1 parent 3f718d3 commit 8116506

File tree

3 files changed

+723
-3
lines changed

3 files changed

+723
-3
lines changed
Lines changed: 241 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,244 @@
1-
"""Test depth validators."""
1+
"""Test Torch Depth Validators."""
22

33
# Copyright (C) 2024 Intel Corporation
44
# SPDX-License-Identifier: Apache-2.0
5+
6+
import numpy as np
7+
import pytest
8+
import torch
9+
from torchvision.tv_tensors import Image, Mask
10+
11+
from anomalib.data.validators.torch.depth import DepthBatchValidator, DepthValidator
12+
13+
14+
class TestDepthValidator:
15+
"""Test DepthValidator."""
16+
17+
def setup_method(self) -> None:
18+
"""Set up the validator for each test method."""
19+
self.validator = DepthValidator()
20+
21+
def test_validate_image_valid(self) -> None:
22+
"""Test validation of a valid depth image."""
23+
image = torch.rand(3, 224, 224)
24+
validated_image = self.validator.validate_image(image)
25+
assert isinstance(validated_image, Image)
26+
assert validated_image.shape == (3, 224, 224)
27+
assert validated_image.dtype == torch.float32
28+
29+
def test_validate_image_invalid_type(self) -> None:
30+
"""Test validation of a depth image with invalid type."""
31+
with pytest.raises(TypeError, match="Image must be a torch.Tensor"):
32+
self.validator.validate_image(np.random.default_rng().random((3, 224, 224)))
33+
34+
def test_validate_image_invalid_dimensions(self) -> None:
35+
"""Test validation of a depth image with invalid dimensions."""
36+
with pytest.raises(ValueError, match="Image must have shape"):
37+
self.validator.validate_image(torch.rand(224, 224))
38+
39+
def test_validate_image_invalid_channels(self) -> None:
40+
"""Test validation of a depth image with invalid number of channels."""
41+
with pytest.raises(ValueError, match="Image must have 3 channels"):
42+
self.validator.validate_image(torch.rand(1, 224, 224))
43+
44+
def test_validate_gt_label_valid(self) -> None:
45+
"""Test validation of a valid ground truth label."""
46+
label = torch.tensor(1)
47+
validated_label = self.validator.validate_gt_label(label)
48+
assert isinstance(validated_label, torch.Tensor)
49+
assert validated_label.dtype == torch.bool
50+
assert validated_label.item() is True
51+
52+
def test_validate_gt_label_none(self) -> None:
53+
"""Test validation of a None ground truth label."""
54+
assert self.validator.validate_gt_label(None) is None
55+
56+
def test_validate_gt_label_invalid_type(self) -> None:
57+
"""Test validation of a ground truth label with invalid type."""
58+
with pytest.raises(TypeError, match="Ground truth label must be an integer or a torch.Tensor"):
59+
self.validator.validate_gt_label("1")
60+
61+
def test_validate_gt_mask_valid(self) -> None:
62+
"""Test validation of a valid ground truth mask."""
63+
mask = torch.randint(0, 2, (1, 224, 224))
64+
validated_mask = self.validator.validate_gt_mask(mask)
65+
assert isinstance(validated_mask, Mask)
66+
assert validated_mask.shape == (224, 224)
67+
assert validated_mask.dtype == torch.bool
68+
69+
def test_validate_gt_mask_none(self) -> None:
70+
"""Test validation of a None ground truth mask."""
71+
assert self.validator.validate_gt_mask(None) is None
72+
73+
def test_validate_gt_mask_invalid_type(self) -> None:
74+
"""Test validation of a ground truth mask with invalid type."""
75+
with pytest.raises(TypeError, match="Mask must be a torch.Tensor"):
76+
self.validator.validate_gt_mask(np.random.default_rng().integers(0, 2, (224, 224)))
77+
78+
def test_validate_gt_mask_invalid_shape(self) -> None:
79+
"""Test validation of a ground truth mask with invalid shape."""
80+
with pytest.raises(ValueError, match="Mask must have 1 channel, got 2."):
81+
self.validator.validate_gt_mask(torch.randint(0, 2, (2, 224, 224)))
82+
83+
def test_validate_anomaly_map_valid(self) -> None:
84+
"""Test validation of a valid anomaly map."""
85+
anomaly_map = torch.rand(1, 224, 224)
86+
validated_map = self.validator.validate_anomaly_map(anomaly_map)
87+
assert isinstance(validated_map, Mask)
88+
assert validated_map.shape == (224, 224)
89+
assert validated_map.dtype == torch.float32
90+
91+
def test_validate_anomaly_map_none(self) -> None:
92+
"""Test validation of a None anomaly map."""
93+
assert self.validator.validate_anomaly_map(None) is None
94+
95+
def test_validate_anomaly_map_invalid_type(self) -> None:
96+
"""Test validation of an anomaly map with invalid type."""
97+
with pytest.raises(TypeError, match="Anomaly map must be a torch.Tensor"):
98+
self.validator.validate_anomaly_map(np.random.default_rng().random((224, 224)))
99+
100+
def test_validate_anomaly_map_invalid_shape(self) -> None:
101+
"""Test validation of an anomaly map with invalid shape."""
102+
with pytest.raises(ValueError, match="Anomaly map with 3 dimensions must have 1 channel, got 2."):
103+
self.validator.validate_anomaly_map(torch.rand(2, 224, 224))
104+
105+
def test_validate_pred_score_valid(self) -> None:
106+
"""Test validation of a valid prediction score."""
107+
score = torch.tensor(0.8)
108+
validated_score = self.validator.validate_pred_score(score)
109+
assert isinstance(validated_score, torch.Tensor)
110+
assert validated_score.dtype == torch.float32
111+
assert validated_score.item() == pytest.approx(0.8)
112+
113+
def test_validate_pred_score_none(self) -> None:
114+
"""Test validation of a None prediction score."""
115+
assert self.validator.validate_pred_score(None) is None
116+
117+
def test_validate_pred_score_invalid_shape(self) -> None:
118+
"""Test validation of a prediction score with invalid shape."""
119+
with pytest.raises(ValueError, match="Predicted score must be a scalar"):
120+
self.validator.validate_pred_score(torch.tensor([0.8, 0.9]))
121+
122+
123+
class TestDepthBatchValidator: # noqa: PLR0904
124+
"""Test DepthBatchValidator."""
125+
126+
def setup_method(self) -> None:
127+
"""Set up the validator for each test method."""
128+
self.validator = DepthBatchValidator()
129+
130+
def test_validate_image_valid(self) -> None:
131+
"""Test validation of a valid depth image batch."""
132+
image_batch = torch.rand(32, 3, 224, 224)
133+
validated_batch = self.validator.validate_image(image_batch)
134+
assert isinstance(validated_batch, Image)
135+
assert validated_batch.shape == (32, 3, 224, 224)
136+
assert validated_batch.dtype == torch.float32
137+
138+
def test_validate_image_invalid_type(self) -> None:
139+
"""Test validation of a depth image batch with invalid type."""
140+
with pytest.raises(TypeError, match="Image must be a torch.Tensor"):
141+
self.validator.validate_image(np.random.default_rng().random((32, 3, 224, 224)))
142+
143+
def test_validate_image_invalid_dimensions(self) -> None:
144+
"""Test validation of a depth image batch with invalid dimensions."""
145+
with pytest.raises(ValueError, match="Image must have shape"):
146+
self.validator.validate_image(torch.rand(32, 224, 224))
147+
148+
def test_validate_image_invalid_channels(self) -> None:
149+
"""Test validation of a depth image batch with invalid number of channels."""
150+
with pytest.raises(ValueError, match="Image must have 3 channels"):
151+
self.validator.validate_image(torch.rand(32, 1, 224, 224))
152+
153+
def test_validate_gt_label_valid(self) -> None:
154+
"""Test validation of valid ground truth labels."""
155+
labels = torch.tensor([0, 1, 1, 0])
156+
validated_labels = self.validator.validate_gt_label(labels, batch_size=4)
157+
assert isinstance(validated_labels, torch.Tensor)
158+
assert validated_labels.dtype == torch.bool
159+
assert torch.equal(validated_labels, torch.tensor([False, True, True, False]))
160+
161+
def test_validate_gt_label_none(self) -> None:
162+
"""Test validation of None ground truth labels."""
163+
assert self.validator.validate_gt_label(None, batch_size=4) is None
164+
165+
def test_validate_gt_label_invalid_type(self) -> None:
166+
"""Test validation of ground truth labels with invalid type."""
167+
with pytest.raises(ValueError, match="too many dimensions 'str'"):
168+
self.validator.validate_gt_label(["0", "1"], batch_size=2)
169+
170+
def test_validate_gt_label_invalid_dimensions(self) -> None:
171+
"""Test validation of ground truth labels with invalid dimensions."""
172+
with pytest.raises(ValueError, match="Ground truth label must be a 1-dimensional vector"):
173+
self.validator.validate_gt_label(torch.tensor([[0, 1], [1, 0]]), batch_size=2)
174+
175+
def test_validate_gt_mask_valid(self) -> None:
176+
"""Test validation of valid ground truth masks."""
177+
masks = torch.randint(0, 2, (4, 224, 224))
178+
validated_masks = self.validator.validate_gt_mask(masks, batch_size=4)
179+
assert isinstance(validated_masks, Mask)
180+
assert validated_masks.shape == (4, 224, 224)
181+
assert validated_masks.dtype == torch.bool
182+
183+
def test_validate_gt_mask_none(self) -> None:
184+
"""Test validation of None ground truth masks."""
185+
assert self.validator.validate_gt_mask(None, batch_size=4) is None
186+
187+
def test_validate_gt_mask_invalid_type(self) -> None:
188+
"""Test validation of ground truth masks with invalid type."""
189+
with pytest.raises(TypeError, match="Ground truth mask must be a torch.Tensor"):
190+
self.validator.validate_gt_mask([torch.zeros(224, 224)], batch_size=1)
191+
192+
def test_validate_gt_mask_invalid_dimensions(self) -> None:
193+
"""Test validation of ground truth masks with invalid dimensions."""
194+
with pytest.raises(ValueError, match="Ground truth mask must have 1 channel, got 2."):
195+
self.validator.validate_gt_mask(torch.zeros(4, 2, 224, 224), batch_size=4)
196+
197+
def test_validate_anomaly_map_valid(self) -> None:
198+
"""Test validation of a valid anomaly map batch."""
199+
anomaly_map = torch.rand(4, 224, 224)
200+
validated_map = self.validator.validate_anomaly_map(anomaly_map, batch_size=4)
201+
assert isinstance(validated_map, Mask)
202+
assert validated_map.shape == (4, 224, 224)
203+
assert validated_map.dtype == torch.float32
204+
205+
def test_validate_anomaly_map_none(self) -> None:
206+
"""Test validation of a None anomaly map batch."""
207+
assert self.validator.validate_anomaly_map(None, batch_size=4) is None
208+
209+
def test_validate_anomaly_map_invalid_shape(self) -> None:
210+
"""Test validation of an anomaly map batch with invalid shape."""
211+
with pytest.raises(ValueError, match="Anomaly map must have 1 channel, got 2."):
212+
self.validator.validate_anomaly_map(torch.rand(4, 2, 224, 224), batch_size=4)
213+
214+
def test_validate_pred_score_valid(self) -> None:
215+
"""Test validation of valid prediction scores."""
216+
scores = torch.tensor([0.1, 0.2, 0.3, 0.4])
217+
validated_scores = self.validator.validate_pred_score(scores, anomaly_map=None)
218+
assert torch.equal(validated_scores, scores)
219+
220+
def test_validate_pred_score_none_with_anomaly_map(self) -> None:
221+
"""Test validation of None prediction scores with anomaly map."""
222+
anomaly_map = torch.rand(4, 224, 224)
223+
computed_scores = self.validator.validate_pred_score(None, anomaly_map)
224+
assert computed_scores.shape == (4,)
225+
226+
def test_validate_pred_label_valid(self) -> None:
227+
"""Test validation of valid prediction labels."""
228+
labels = torch.tensor([[1], [0], [1], [1]])
229+
validated_labels = self.validator.validate_pred_label(labels)
230+
assert torch.equal(validated_labels, torch.tensor([[True], [False], [True], [True]]))
231+
232+
def test_validate_pred_label_none(self) -> None:
233+
"""Test validation of None prediction labels."""
234+
assert self.validator.validate_pred_label(None) is None
235+
236+
def test_validate_pred_label_invalid_type(self) -> None:
237+
"""Test validation of prediction labels with invalid type."""
238+
with pytest.raises(TypeError, match="Predicted label must be a torch.Tensor"):
239+
self.validator.validate_pred_label([1, 0, 1, 1])
240+
241+
def test_validate_pred_label_invalid_shape(self) -> None:
242+
"""Test validation of prediction labels with invalid shape."""
243+
with pytest.raises(ValueError, match="Predicted label must be 1-dimensional or 2-dimensional"):
244+
self.validator.validate_pred_label(torch.tensor([[[1]], [[0]], [[1]], [[1]]]))

0 commit comments

Comments
 (0)