Skip to content

Commit 0f2eb0e

Browse files
committed
add task type tests
1 parent 640f05a commit 0f2eb0e

File tree

2 files changed

+297
-0
lines changed

2 files changed

+297
-0
lines changed

tests/helpers/data.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,25 @@ def _generate_dummy_mvtec_dataset(
392392
mask_filename = mask_path / f"{i:03}{mask_suffix}{mask_extension}"
393393
self.image_generator.generate_image(label, image_filename, mask_filename)
394394

395+
def _generate_dummy_folder_dataset(self) -> None:
396+
"""Generate dummy folder dataset in a temporary directory."""
397+
# folder names
398+
normal_dir = self.root / self.normal_category
399+
abnormal_dir = self.root / self.abnormal_category
400+
mask_dir = self.root / "masks"
401+
402+
# generate images
403+
for i in range(self.num_train):
404+
label = LabelName.NORMAL
405+
image_filename = normal_dir / f"{self.normal_category}_{i:03}.png"
406+
self.image_generator.generate_image(label, image_filename)
407+
408+
for i in range(self.num_test):
409+
label = LabelName.ABNORMAL
410+
image_filename = abnormal_dir / f"{self.abnormal_category}_{i:03}.png"
411+
mask_filename = mask_dir / image_filename.name
412+
self.image_generator.generate_image(label, image_filename, mask_filename)
413+
395414
def _generate_dummy_btech_dataset(self) -> None:
396415
"""Generate dummy BeanTech dataset in directory using the same convention as BeanTech AD."""
397416
# BeanTech AD follows the same convention as MVTec AD.
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
"""Tests to check behaviour of the auxiliary components across different task types (classification, segmentation) ."""
2+
3+
# Copyright (C) 2025 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
import copy
7+
from pathlib import Path
8+
from typing import Any
9+
10+
import pytest
11+
import torch
12+
from torchmetrics import Metric
13+
14+
from anomalib import LearningType
15+
from anomalib.data import AnomalibDataModule, Batch, Folder, ImageDataFormat
16+
from anomalib.engine import Engine
17+
from anomalib.metrics import AnomalibMetric, Evaluator
18+
from anomalib.models import AnomalibModule
19+
from anomalib.post_processing import OneClassPostProcessor
20+
from anomalib.visualization import ImageVisualizer
21+
from tests.helpers.data import DummyImageDatasetGenerator
22+
23+
24+
class DummyBaseModel(AnomalibModule):
25+
"""Dummy model for testing.
26+
27+
No training, and all auxiliary components default to None. This allows testing of the different components
28+
in isolation.
29+
"""
30+
31+
def training_step(self, *args, **kwargs) -> None:
32+
"""Dummy training step."""
33+
34+
@property
35+
def trainer_arguments(self) -> dict[str, Any]:
36+
"""Run for single epoch."""
37+
return {"max_epochs": 1}
38+
39+
@property
40+
def learning_type(self) -> LearningType:
41+
"""Return the learning type of the model."""
42+
return LearningType.ONE_CLASS
43+
44+
def configure_optimizers(self) -> None:
45+
"""No optimizers needed."""
46+
47+
def configure_preprocessor(self) -> None:
48+
"""No default pre-processor needed."""
49+
50+
def configure_post_processor(self) -> None:
51+
"""No default post-processor needed."""
52+
53+
def configure_evaluator(self) -> None:
54+
"""No default evaluator needed."""
55+
56+
def configure_visualizer(self) -> None:
57+
"""No default visualizer needed."""
58+
59+
60+
class DummyClassificationModel(DummyBaseModel):
61+
"""Dummy classification model for testing.
62+
63+
Validation step returns random image-only scores, to simulate a model that performs classification.
64+
"""
65+
66+
def validation_step(self, batch: Batch, *args, **kwargs) -> Batch:
67+
"""Validation steps that returns random image-level scores."""
68+
del args, kwargs
69+
batch.pred_score = torch.rand(batch.batch_size, device=self.device)
70+
return batch
71+
72+
73+
class DummySegmentationModel(DummyBaseModel):
74+
"""Dummy segmentation model for testing.
75+
76+
Validation step returns random image- and pixel-level scores, to simulate a model that performs segmentation.
77+
"""
78+
79+
def validation_step(self, batch: Batch, *args, **kwargs) -> Batch:
80+
"""Validation steps that returns random image- and pixel-level scores."""
81+
del args, kwargs
82+
batch.pred_score = torch.rand(batch.batch_size, device=self.device)
83+
batch.anomaly_map = torch.rand(batch.batch_size, *batch.image.shape[-2:], device=self.device)
84+
return batch
85+
86+
87+
class _DummyMetric(Metric):
88+
"""Dummy metric for testing."""
89+
90+
def update(self, *args, **kwargs) -> None:
91+
"""Dummy update method."""
92+
93+
def compute(self) -> None:
94+
"""Dummy compute method."""
95+
assert self.update_called # simulate failure to compute if states are not updated
96+
97+
98+
class DummyMetric(AnomalibMetric, _DummyMetric):
99+
"""Dummy Anomalib metric for testing."""
100+
101+
102+
@pytest.fixture
103+
def folder_dataset_path(project_path: Path) -> Path:
104+
"""Create a dummy folder dataset for testing."""
105+
data_path = project_path / "dataset"
106+
dataset_generator = DummyImageDatasetGenerator(
107+
data_format=ImageDataFormat.FOLDER,
108+
root=data_path,
109+
num_train=10,
110+
num_test=10,
111+
)
112+
dataset_generator.generate_dataset()
113+
return data_path
114+
115+
116+
@pytest.fixture
117+
def classification_datamodule(folder_dataset_path: Path) -> AnomalibDataModule:
118+
"""Create a classification datamodule for testing.
119+
120+
The datamodule is created with a folder dataset, that does not have a mask directory.
121+
"""
122+
# create the folder datamodule
123+
return Folder(
124+
name="cls_dataset",
125+
root=folder_dataset_path,
126+
normal_dir="good",
127+
abnormal_dir="bad",
128+
train_batch_size=1,
129+
eval_batch_size=1,
130+
num_workers=0,
131+
)
132+
133+
134+
@pytest.fixture
135+
def segmentation_datamodule(folder_dataset_path: Path) -> AnomalibDataModule:
136+
"""Create a segmentation datamodule for testing.
137+
138+
The datamodule is created with a folder dataset, that has a mask directory.
139+
"""
140+
# create the folder datamodule
141+
return Folder(
142+
name="seg_dataset",
143+
root=folder_dataset_path,
144+
normal_dir="good",
145+
abnormal_dir="bad",
146+
mask_dir="masks", # include masks for segmentation dataset
147+
train_batch_size=1,
148+
eval_batch_size=1,
149+
num_workers=0,
150+
)
151+
152+
153+
@pytest.fixture
154+
def image_and_pixel_evaluator() -> Evaluator:
155+
"""Create an evaluator with image- and pixel-level metrics for testing."""
156+
image_metric = DummyMetric(fields=["pred_score", "gt_label"], prefix="image_")
157+
pixel_metric = DummyMetric(fields=["anomaly_map", "gt_mask"], prefix="pixel_", strict=False)
158+
val_metrics = [image_metric, pixel_metric]
159+
test_metrics = copy.deepcopy(val_metrics)
160+
return Evaluator(val_metrics=[image_metric, pixel_metric], test_metrics=test_metrics)
161+
162+
163+
@pytest.fixture
164+
def engine(project_path: Path) -> Engine:
165+
"""Create an engine for testing.
166+
167+
Run on cpu to speed up tests.
168+
"""
169+
return Engine(accelerator="cpu", default_root_dir=project_path)
170+
171+
172+
class TestEvaluation:
173+
"""Test evaluation across task types.
174+
175+
Tests if image- and/or pixel- metrics are computed without errors for models and datasets with different task types.
176+
"""
177+
178+
@staticmethod
179+
def test_cls_model_cls_dataset(
180+
engine: Engine,
181+
classification_datamodule: AnomalibDataModule,
182+
image_and_pixel_evaluator: Evaluator,
183+
) -> None:
184+
"""Test classification model with classification dataset."""
185+
model = DummyClassificationModel(evaluator=image_and_pixel_evaluator)
186+
engine.train(model, datamodule=classification_datamodule)
187+
188+
@staticmethod
189+
def test_cls_model_seg_dataset(
190+
engine: Engine,
191+
segmentation_datamodule: AnomalibDataModule,
192+
image_and_pixel_evaluator: Evaluator,
193+
) -> None:
194+
"""Test classification model with segmentation dataset."""
195+
model = DummyClassificationModel(evaluator=image_and_pixel_evaluator)
196+
engine.train(model, datamodule=segmentation_datamodule)
197+
198+
@staticmethod
199+
def test_seg_model_cls_dataset(
200+
engine: Engine,
201+
classification_datamodule: AnomalibDataModule,
202+
image_and_pixel_evaluator: Evaluator,
203+
) -> None:
204+
"""Test segmentation model with classification dataset."""
205+
model = DummySegmentationModel(evaluator=image_and_pixel_evaluator)
206+
engine.train(model, datamodule=classification_datamodule)
207+
208+
@staticmethod
209+
def test_seg_model_seg_dataset(
210+
engine: Engine,
211+
segmentation_datamodule: AnomalibDataModule,
212+
image_and_pixel_evaluator: Evaluator,
213+
) -> None:
214+
"""Test segmentation model with segmentation dataset."""
215+
model = DummySegmentationModel(evaluator=image_and_pixel_evaluator)
216+
engine.train(model, datamodule=segmentation_datamodule)
217+
218+
219+
class TestPostProcessing:
220+
"""Tests post-processing across task types.
221+
222+
Tests if post-processing is applied without errors for models and datasets with different task types.
223+
"""
224+
225+
@staticmethod
226+
def test_cls_model_cls_dataset(engine: Engine, classification_datamodule: AnomalibDataModule) -> None:
227+
"""Test classification model with classification dataset."""
228+
model = DummyClassificationModel(post_processor=OneClassPostProcessor())
229+
engine.train(model, datamodule=classification_datamodule)
230+
231+
@staticmethod
232+
def test_cls_model_seg_dataset(engine: Engine, segmentation_datamodule: AnomalibDataModule) -> None:
233+
"""Test classification model with segmentation dataset."""
234+
model = DummyClassificationModel(post_processor=OneClassPostProcessor())
235+
engine.train(model, datamodule=segmentation_datamodule)
236+
237+
@staticmethod
238+
def test_seg_model_cls_dataset(engine: Engine, classification_datamodule: AnomalibDataModule) -> None:
239+
"""Test segmentation model with classification dataset."""
240+
model = DummySegmentationModel(post_processor=OneClassPostProcessor())
241+
engine.train(model, datamodule=classification_datamodule)
242+
243+
@staticmethod
244+
def test_seg_model_seg_dataset(engine: Engine, segmentation_datamodule: AnomalibDataModule) -> None:
245+
"""Test segmentation model with segmentation dataset."""
246+
model = DummySegmentationModel(post_processor=OneClassPostProcessor())
247+
engine.train(model, datamodule=segmentation_datamodule)
248+
249+
250+
class TestVisualization:
251+
"""Tests visualization across task types.
252+
253+
Tests if visualizations are created without errors for models and datasets with different task types.
254+
"""
255+
256+
@staticmethod
257+
def test_cls_model_cls_dataset(engine: Engine, classification_datamodule: AnomalibDataModule) -> None:
258+
"""Test classification model with classification dataset."""
259+
model = DummyClassificationModel(visualizer=ImageVisualizer())
260+
engine.train(model, datamodule=classification_datamodule)
261+
262+
@staticmethod
263+
def test_cls_model_seg_dataset(engine: Engine, segmentation_datamodule: AnomalibDataModule) -> None:
264+
"""Test classification model with segmentation dataset."""
265+
model = DummyClassificationModel(visualizer=ImageVisualizer())
266+
engine.train(model, datamodule=segmentation_datamodule)
267+
268+
@staticmethod
269+
def test_seg_model_cls_dataset(engine: Engine, classification_datamodule: AnomalibDataModule) -> None:
270+
"""Test segmentation model with classification dataset."""
271+
model = DummySegmentationModel(visualizer=ImageVisualizer())
272+
engine.train(model, datamodule=classification_datamodule)
273+
274+
@staticmethod
275+
def test_seg_model_seg_dataset(engine: Engine, segmentation_datamodule: AnomalibDataModule) -> None:
276+
"""Test segmentation model with segmentation dataset."""
277+
model = DummySegmentationModel(visualizer=ImageVisualizer())
278+
engine.train(model, datamodule=segmentation_datamodule)

0 commit comments

Comments
 (0)