Skip to content

Commit 6df0233

Browse files
authored
Merge branch 'main' into v2
2 parents c7a2b45 + 9357f9c commit 6df0233

File tree

9 files changed

+91
-11
lines changed

9 files changed

+91
-11
lines changed

.github/actions/pytest/action.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ runs:
180180
181181
- name: Upload test results
182182
if: always() && steps.test-execution.outcome == 'failure'
183-
uses: actions/upload-artifact@v3
183+
uses: actions/upload-artifact@v4
184184
with:
185185
name: pytest-results-${{ inputs.test-type }}
186186
path: pytest_output.log

.github/workflows/pr.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ name: PR Checks
4040

4141
on:
4242
pull_request:
43-
branches: [main, "feature/**"]
43+
branches: ["main", "feature/**"]
44+
merge_group:
45+
branches: [main]
4446

4547
concurrency:
4648
group: ${{ github.workflow }}-${{ github.ref }}

.github/workflows/pre_merge.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ name: Pre-Merge Checks
22
permissions: read-all
33

44
on:
5-
push:
5+
merge_group:
66
branches: [main]
77
pull_request:
88
types:
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:adee94133a01a9b5502d2c77e0016aee4e1f2954c28eb5b7d928f5df66db42a7
3-
size 378257
2+
oid sha256:83fd473e527968ff540f506b2e63f013e4216e57f488fb1f1e6c8f8ef5632572
3+
size 378250

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ lint.ignore = [
191191
# End of disable rules
192192

193193
# flake8-annotations
194-
"ANN101", # Missing-type-self
195194
"ANN002", # Missing type annotation for *args
196195
"ANN003", # Missing type annotation for **kwargs
197196

src/anomalib/engine/engine.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,7 @@ def export(
728728
model: AnomalibModule,
729729
export_type: ExportType | str,
730730
export_root: str | Path | None = None,
731+
model_file_name: str = "model",
731732
input_size: tuple[int, int] | None = None,
732733
compression_type: CompressionType | None = None,
733734
datamodule: AnomalibDataModule | None = None,
@@ -742,6 +743,8 @@ def export(
742743
export_type (ExportType): Export type.
743744
export_root (str | Path | None, optional): Path to the output directory. If it is not set, the model is
744745
exported to trainer.default_root_dir. Defaults to None.
746+
model_file_name (str = "model"): Name of the exported model file. If it is not set, the model is
747+
is called "model". Defaults to "model".
745748
input_size (tuple[int, int] | None, optional): A statis input shape for the model, which is exported to ONNX
746749
and OpenVINO format. Defaults to None.
747750
compression_type (CompressionType | None, optional): Compression type for OpenVINO exporting only.
@@ -798,15 +801,18 @@ def export(
798801
if export_type == ExportType.TORCH:
799802
exported_model_path = model.to_torch(
800803
export_root=export_root,
804+
model_file_name=model_file_name,
801805
)
802806
elif export_type == ExportType.ONNX:
803807
exported_model_path = model.to_onnx(
804808
export_root=export_root,
809+
model_file_name=model_file_name,
805810
input_size=input_size,
806811
)
807812
elif export_type == ExportType.OPENVINO:
808813
exported_model_path = model.to_openvino(
809814
export_root=export_root,
815+
model_file_name=model_file_name,
810816
input_size=input_size,
811817
compression_type=compression_type,
812818
datamodule=datamodule,

src/anomalib/metrics/min_max.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def update(self, predictions: torch.Tensor, *args, **kwargs) -> None:
8585
del args, kwargs # These variables are not used.
8686

8787
self.min = torch.min(self.min, torch.min(predictions))
88-
self.max = torch.max(self.min, torch.max(predictions))
88+
self.max = torch.max(self.max, torch.max(predictions))
8989

9090
def compute(self) -> tuple[torch.Tensor, torch.Tensor]:
9191
"""Compute final minimum and maximum values.

src/anomalib/models/components/base/export_mixin.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,13 @@ class ExportMixin:
8080
def to_torch(
8181
self,
8282
export_root: Path | str,
83+
model_file_name: str = "model",
8384
) -> Path:
8485
"""Export model to PyTorch format.
8586
8687
Args:
8788
export_root (Path | str): Path to the output folder
89+
model_file_name (str): Name of the exported model
8890
8991
Returns:
9092
Path: Path to the exported PyTorch model (.pt file)
@@ -99,7 +101,7 @@ def to_torch(
99101
PosixPath('./exports/weights/torch/model.pt')
100102
"""
101103
export_root = _create_export_root(export_root, ExportType.TORCH)
102-
pt_model_path = export_root / "model.pt"
104+
pt_model_path = export_root / (model_file_name + ".pt")
103105
torch.save(
104106
obj={"model": self},
105107
f=pt_model_path,
@@ -109,12 +111,14 @@ def to_torch(
109111
def to_onnx(
110112
self,
111113
export_root: Path | str,
114+
model_file_name: str = "model",
112115
input_size: tuple[int, int] | None = None,
113116
) -> Path:
114117
"""Export model to ONNX format.
115118
116119
Args:
117120
export_root (Path | str): Path to the output folder
121+
model_file_name (str): Name of the exported model.
118122
input_size (tuple[int, int] | None): Input image dimensions (height, width).
119123
If ``None``, uses dynamic input shape. Defaults to ``None``
120124
@@ -143,7 +147,7 @@ def to_onnx(
143147
if input_size
144148
else {"input": {0: "batch_size", 2: "height", 3: "weight"}, "output": {0: "batch_size"}}
145149
)
146-
onnx_path = export_root / "model.onnx"
150+
onnx_path = export_root / (model_file_name + ".onnx")
147151
# apply pass through the model to get the output names
148152
assert isinstance(self, LightningModule) # mypy
149153
output_names = [name for name, value in self.eval()(input_shape)._asdict().items() if value is not None]
@@ -162,6 +166,7 @@ def to_onnx(
162166
def to_openvino(
163167
self,
164168
export_root: Path | str,
169+
model_file_name: str = "model",
165170
input_size: tuple[int, int] | None = None,
166171
compression_type: CompressionType | None = None,
167172
datamodule: AnomalibDataModule | None = None,
@@ -173,6 +178,7 @@ def to_openvino(
173178
174179
Args:
175180
export_root (Path | str): Path to the output folder
181+
model_file_name (str): Name of the exported model
176182
input_size (tuple[int, int] | None): Input image dimensions (height, width).
177183
If ``None``, uses dynamic input shape. Defaults to ``None``
178184
compression_type (CompressionType | None): Type of compression to apply.
@@ -218,9 +224,9 @@ def to_openvino(
218224
import openvino as ov
219225

220226
with TemporaryDirectory() as onnx_directory:
221-
model_path = self.to_onnx(onnx_directory, input_size)
227+
model_path = self.to_onnx(onnx_directory, model_file_name, input_size)
222228
export_root = _create_export_root(export_root, ExportType.OPENVINO)
223-
ov_model_path = export_root / "model.xml"
229+
ov_model_path = export_root / (model_file_name + ".xml")
224230
ov_args = {} if ov_args is None else ov_args
225231

226232
model = ov.convert_model(model_path, **ov_args)

tests/unit/metrics/test_min_max.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Test MinMax metric."""
2+
3+
# Copyright (C) 2025 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
import torch
7+
8+
from anomalib.metrics import MinMax # Assuming the metric is part of `anomalib`
9+
10+
11+
def test_initialization() -> None:
12+
"""Test if the metric initializes with correct default values."""
13+
metric = MinMax()
14+
assert torch.isinf(metric.min), "Initial min should be positive infinity."
15+
assert metric.min > 0, "Initial min should be positive infinity."
16+
assert torch.isinf(metric.max), "Initial max should be negative infinity."
17+
assert metric.max < 0, "Initial max should be negative infinity."
18+
19+
20+
def test_update_single_batch() -> None:
21+
"""Test updating the metric with a single batch."""
22+
metric = MinMax()
23+
batch = torch.tensor([1.0, 2.0, 3.0, -1.0])
24+
metric.update(batch)
25+
26+
assert metric.min.item() == -1.0, "Min should be -1.0 after single batch update."
27+
assert metric.max.item() == 3.0, "Max should be 3.0 after single batch update."
28+
29+
30+
def test_update_multiple_batches() -> None:
31+
"""Test updating the metric with multiple batches."""
32+
metric = MinMax()
33+
batch1 = torch.tensor([0.5, 1.5, 3.0])
34+
batch2 = torch.tensor([-0.5, 0.0, 2.5])
35+
36+
metric.update(batch1)
37+
metric.update(batch2)
38+
39+
assert metric.min.item() == -0.5, "Min should be -0.5 after multiple batch updates."
40+
assert metric.max.item() == 3.0, "Max should be 3.0 after multiple batch updates."
41+
42+
43+
def test_compute() -> None:
44+
"""Test computation of the min and max values after updates."""
45+
metric = MinMax()
46+
batch1 = torch.tensor([1.0, 2.0])
47+
batch2 = torch.tensor([-1.0, 0.0])
48+
49+
metric.update(batch1)
50+
metric.update(batch2)
51+
52+
min_val, max_val = metric.compute()
53+
54+
assert min_val.item() == -1.0, "Computed min should be -1.0."
55+
assert max_val.item() == 2.0, "Computed max should be 2.0."
56+
57+
58+
def test_no_updates() -> None:
59+
"""Test behavior when no updates are made to the metric."""
60+
metric = MinMax()
61+
62+
min_val, max_val = metric.compute()
63+
64+
assert torch.isinf(min_val), "Min should remain positive infinity with no updates."
65+
assert min_val > 0, "Min should remain positive infinity with no updates."
66+
assert torch.isinf(max_val), "Max should remain negative infinity with no updates."
67+
assert max_val < 0, "Max should remain negative infinity with no updates."

0 commit comments

Comments
 (0)