4040 >>> from anomalib.metrics import create_anomalib_metric
4141 >>> F1Score = create_anomalib_metric(BinaryF1Score)
4242 >>> f1_score = F1Score(fields=["pred_label", "gt_label"])
43+
44+ Strict mode vs non-strict mode::
45+
46+ >>> F1Score = create_anomalib_metric(BinaryF1Score)
47+ >>>
48+ >>> # create metric in strict mode (default), and non-strict mode
49+ >>> f1_score_strict = F1Score(fields=["pred_label", "gt_label"], strict=True)
50+ >>> f1_score_nonstrict = F1Score(fields=["pred_label", "gt_label"], strict=False)
51+ >>>
52+ >>> # create a batch in which 'pred_label' field is None
53+ >>> batch = ImageBatch(
54+ ... image=torch.rand(4, 3, 256, 256),
55+ ... gt_label=torch.tensor([0, 0, 1, 1])
56+ ... )
57+ >>>
58+ >>> f1_score_strict.update(batch) # ValueError
59+ >>> f1_score_strict.compute() # UserWarning, tensor(0.)
60+ >>>
61+ >>> f1_score_nonstrict.update(batch) # No error
62+ >>> f1_score_nonstrict.compute() # None
4363"""
4464
45- # Copyright (C) 2024 Intel Corporation
65+ # Copyright (C) 2024-2025 Intel Corporation
4666# SPDX-License-Identifier: Apache-2.0
4767
4868from collections .abc import Sequence
4969
70+ import torch
5071from torchmetrics import Metric , MetricCollection
5172
5273from anomalib .data import Batch
@@ -67,6 +88,7 @@ class AnomalibMetric:
6788 fields (Sequence[str] | None): Names of fields to extract from batch.
6889 If None, uses class's ``default_fields``. Required if no defaults.
6990 prefix (str): Prefix added to metric name. Defaults to "".
91+ strict (bool): Whether to raise an error if batch is missing fields.
7092 **kwargs: Additional arguments passed to parent metric class.
7193
7294 Raises:
@@ -97,6 +119,7 @@ def __init__(
97119 self ,
98120 fields : Sequence [str ] | None = None ,
99121 prefix : str = "" ,
122+ strict : bool = True ,
100123 ** kwargs ,
101124 ) -> None :
102125 fields = fields or getattr (self , "default_fields" , None )
@@ -109,6 +132,7 @@ def __init__(
109132 raise ValueError (msg )
110133 self .fields = fields
111134 self .name = prefix + self .__class__ .__name__
135+ self .strict = strict
112136 super ().__init__ (** kwargs )
113137
114138 def __init_subclass__ (cls , ** kwargs ) -> None :
@@ -132,11 +156,40 @@ def update(self, batch: Batch, *args, **kwargs) -> None:
132156 """
133157 for key in self .fields :
134158 if getattr (batch , key , None ) is None :
135- msg = f"Batch object is missing required field: { key } "
159+ # We cannot update the metric if the batch is missing required fields,
160+ # so we need to decrement the update count of the super class.
161+ self ._update_count -= 1 # type: ignore[attr-defined]
162+ if not self .strict :
163+ # If not in strict mode, skip updating the metric but don't raise an error
164+ return
165+ # otherwise, raise an error
166+ if not hasattr (batch , key ):
167+ msg = (
168+ f"Cannot update metric of type { type (self )} . Passed dataclass instance "
169+ f"is missing required field: { key } "
170+ )
171+ else :
172+ msg = (
173+ f"Cannot update metric of type { type (self )} . Passed dataclass instance "
174+ f"does not have a value for field with name { key } ."
175+ )
136176 raise ValueError (msg )
177+
137178 values = [getattr (batch , key ) for key in self .fields ]
138179 super ().update (* values , * args , ** kwargs ) # type: ignore[misc]
139180
181+ def compute (self ) -> torch .Tensor :
182+ """Compute the metric value.
183+
184+ If the metric has not been updated, and metric is not in strict mode, return None.
185+
186+ Returns:
187+ torch.Tensor: Computed metric value or None.
188+ """
189+ if self ._update_count == 0 and not self .strict : # type: ignore[attr-defined]
190+ return None
191+ return super ().compute () # type: ignore[misc]
192+
140193
141194def create_anomalib_metric (metric_cls : type ) -> type :
142195 """Create an Anomalib version of a torchmetrics metric.
0 commit comments