diff --git a/ignite/metrics/accumulation.py b/ignite/metrics/accumulation.py index 62708ec53a78..8ae7dc72a7d5 100644 --- a/ignite/metrics/accumulation.py +++ b/ignite/metrics/accumulation.py @@ -1,5 +1,5 @@ import numbers -from typing import Any, Callable, Union +from typing import Any, Callable, Tuple, Union, cast import torch @@ -47,8 +47,7 @@ def __init__( ): if not callable(op): raise TypeError("Argument op should be a callable, but given {}".format(type(op))) - self.accumulator = None - self.num_examples = None + self._op = op super(VariableAccumulation, self).__init__(output_transform=output_transform, device=device) @@ -56,7 +55,7 @@ def __init__( @reinit__is_reduced def reset(self) -> None: self.accumulator = torch.tensor(0.0, dtype=torch.float64, device=self._device) - self.num_examples = torch.tensor(0, dtype=torch.long, device=self._device) + self.num_examples = 0 def _check_output_type(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None: if not (isinstance(output, numbers.Number) or isinstance(output, torch.Tensor)): @@ -73,14 +72,14 @@ def update(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None: self.accumulator = self._op(self.accumulator, output) - if hasattr(output, "shape"): + if isinstance(output, torch.Tensor): self.num_examples += output.shape[0] if len(output.shape) > 1 else 1 else: self.num_examples += 1 @sync_all_reduce("accumulator", "num_examples") - def compute(self) -> list: - return [self.accumulator, self.num_examples] + def compute(self) -> Tuple[torch.Tensor, int]: + return self.accumulator, self.num_examples class Average(VariableAccumulation): @@ -125,18 +124,18 @@ class Average(VariableAccumulation): def __init__( self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu") ): - def _mean_op(a, x): - if isinstance(x, torch.Tensor) and x.ndim > 1: + def _mean_op(a: Union[float, torch.Tensor], x: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: + if isinstance(x, torch.Tensor) and x.ndim > 1: # type: ignore[attr-defined] x = x.sum(dim=0) return a + x super(Average, self).__init__(op=_mean_op, output_transform=output_transform, device=device) @sync_all_reduce("accumulator", "num_examples") - def compute(self) -> Union[Any, torch.Tensor, numbers.Number]: + def compute(self) -> Union[torch.Tensor, numbers.Number]: if self.num_examples < 1: raise NotComputableError( - "{} must have at least one example before" " it can be computed.".format(self.__class__.__name__) + "{} must have at least one example before it can be computed.".format(self.__class__.__name__) ) return self.accumulator / self.num_examples @@ -173,21 +172,26 @@ class GeometricAverage(VariableAccumulation): def __init__( self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu") ): - def _geom_op(a: torch.Tensor, x: Union[Any, numbers.Number, torch.Tensor]) -> torch.Tensor: + def _geom_op(a: torch.Tensor, x: Union[numbers.Number, torch.Tensor]) -> torch.Tensor: if not isinstance(x, torch.Tensor): x = torch.tensor(x) x = torch.log(x) - if x.ndim > 1: + if x.ndim > 1: # type: ignore[attr-defined] x = x.sum(dim=0) return a + x super(GeometricAverage, self).__init__(op=_geom_op, output_transform=output_transform, device=device) @sync_all_reduce("accumulator", "num_examples") - def compute(self) -> torch.Tensor: + def compute(self) -> Union[torch.Tensor, numbers.Number]: if self.num_examples < 1: raise NotComputableError( - "{} must have at least one example before" " it can be computed.".format(self.__class__.__name__) + "{} must have at least one example before it can be computed.".format(self.__class__.__name__) ) - return torch.exp(self.accumulator / self.num_examples) + tensor = torch.exp(self.accumulator / self.num_examples) + + if tensor.numel() == 1: + return cast(numbers.Number, tensor.item()) + + return tensor diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py index 7d6c939e4b53..47e357e62eda 100644 --- a/ignite/metrics/accuracy.py +++ b/ignite/metrics/accuracy.py @@ -1,4 +1,4 @@ -from typing import Callable, Sequence, Union +from typing import Callable, Optional, Sequence, Tuple, Union import torch @@ -16,8 +16,8 @@ def __init__( device: Union[str, torch.device] = torch.device("cpu"), ): self._is_multilabel = is_multilabel - self._type = None - self._num_classes = None + self._type = None # type: Optional[str] + self._num_classes = None # type: Optional[int] super(_BaseClassification, self).__init__(output_transform=output_transform, device=device) def reset(self) -> None: @@ -35,7 +35,7 @@ def _check_shape(self, output: Sequence[torch.Tensor]) -> None: ) y_shape = y.shape - y_pred_shape = y_pred.shape + y_pred_shape = y_pred.shape # type: Tuple[int, ...] if y.ndimension() + 1 == y_pred.ndimension(): y_pred_shape = (y_pred_shape[0],) + y_pred_shape[2:] @@ -134,8 +134,6 @@ def __init__( is_multilabel: bool = False, device: Union[str, torch.device] = torch.device("cpu"), ): - self._num_correct = None - self._num_examples = None super(Accuracy, self).__init__(output_transform=output_transform, is_multilabel=is_multilabel, device=device) @reinit__is_reduced @@ -167,7 +165,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None: self._num_examples += correct.shape[0] @sync_all_reduce("_num_examples", "_num_correct") - def compute(self) -> torch.Tensor: + def compute(self) -> float: if self._num_examples == 0: raise NotComputableError("Accuracy must have at least one example before it can be computed.") return self._num_correct.item() / self._num_examples diff --git a/ignite/metrics/confusion_matrix.py b/ignite/metrics/confusion_matrix.py index 3c797efaf3e4..f4fd840699b7 100644 --- a/ignite/metrics/confusion_matrix.py +++ b/ignite/metrics/confusion_matrix.py @@ -1,5 +1,5 @@ import numbers -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Optional, Sequence, Tuple, Union import torch @@ -54,7 +54,6 @@ def __init__( self.num_classes = num_classes self._num_examples = 0 self.average = average - self.confusion_matrix = None super(ConfusionMatrix, self).__init__(output_transform=output_transform, device=device) @reinit__is_reduced @@ -67,7 +66,7 @@ def _check_shape(self, output: Sequence[torch.Tensor]) -> None: if y_pred.ndimension() < 2: raise ValueError( - "y_pred must have shape (batch_size, num_categories, ...), " "but given {}".format(y_pred.shape) + "y_pred must have shape (batch_size, num_categories, ...), but given {}".format(y_pred.shape) ) if y_pred.shape[1] != self.num_classes: @@ -83,7 +82,7 @@ def _check_shape(self, output: Sequence[torch.Tensor]) -> None: ) y_shape = y.shape - y_pred_shape = y_pred.shape + y_pred_shape = y_pred.shape # type: Tuple[int, ...] if y.ndimension() + 1 == y_pred.ndimension(): y_pred_shape = (y_pred_shape[0],) + y_pred_shape[2:] @@ -124,13 +123,12 @@ def compute(self) -> torch.Tensor: @staticmethod def normalize(matrix: torch.Tensor, average: str) -> torch.Tensor: - if average not in ("recall", "precision"): - raise ValueError("Argument average one of 'samples', 'recall', 'precision'") - if average == "recall": return matrix / (matrix.sum(dim=1).unsqueeze(1) + 1e-15) elif average == "precision": return matrix / (matrix.sum(dim=0) + 1e-15) + else: + raise ValueError("Argument average should be one of 'samples', 'recall', 'precision'") def IoU(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda: @@ -170,14 +168,15 @@ def IoU(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambd cm = cm.type(torch.DoubleTensor) iou = cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) - cm.diag() + 1e-15) if ignore_index is not None: + ignore_idx = ignore_index # type: int # used due to typing issues with mympy - def ignore_index_fn(iou_vector): - if ignore_index >= len(iou_vector): + def ignore_index_fn(iou_vector: torch.Tensor) -> torch.Tensor: + if ignore_idx >= len(iou_vector): raise ValueError( - "ignore_index {} is larger than the length of IoU vector {}".format(ignore_index, len(iou_vector)) + "ignore_index {} is larger than the length of IoU vector {}".format(ignore_idx, len(iou_vector)) ) indices = list(range(len(iou_vector))) - indices.remove(ignore_index) + indices.remove(ignore_idx) return iou_vector[indices] return MetricsLambda(ignore_index_fn, iou) @@ -282,14 +281,15 @@ def DiceCoefficient(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> dice = 2.0 * cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) + 1e-15) if ignore_index is not None: + ignore_idx = ignore_index # type: int # used due to typing issues with mympy def ignore_index_fn(dice_vector: torch.Tensor) -> torch.Tensor: - if ignore_index >= len(dice_vector): + if ignore_idx >= len(dice_vector): raise ValueError( - "ignore_index {} is larger than the length of Dice vector {}".format(ignore_index, len(dice_vector)) + "ignore_index {} is larger than the length of Dice vector {}".format(ignore_idx, len(dice_vector)) ) indices = list(range(len(dice_vector))) - indices.remove(ignore_index) + indices.remove(ignore_idx) return dice_vector[indices] return MetricsLambda(ignore_index_fn, dice) diff --git a/ignite/metrics/epoch_metric.py b/ignite/metrics/epoch_metric.py index dab209834915..73f9a7343242 100644 --- a/ignite/metrics/epoch_metric.py +++ b/ignite/metrics/epoch_metric.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable, Sequence, Union +from typing import Callable, List, Sequence, Union, cast import torch @@ -54,13 +54,11 @@ def __init__( output_transform: Callable = lambda x: x, check_compute_fn: bool = True, device: Union[str, torch.device] = torch.device("cpu"), - ): + ) -> None: if not callable(compute_fn): raise TypeError("Argument compute_fn should be callable.") - self._predictions = None - self._targets = None self.compute_fn = compute_fn self._check_compute_fn = check_compute_fn @@ -68,10 +66,10 @@ def __init__( @reinit__is_reduced def reset(self) -> None: - self._predictions = [] - self._targets = [] + self._predictions = [] # type: List[torch.Tensor] + self._targets = [] # type: List[torch.Tensor] - def _check_shape(self, output): + def _check_shape(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output if y_pred.ndimension() not in (1, 2): raise ValueError("Predictions should be of shape (batch_size, n_classes) or (batch_size, ).") @@ -83,7 +81,7 @@ def _check_shape(self, output): if not torch.equal(y ** 2, y): raise ValueError("Targets should be binary (0 or 1).") - def _check_type(self, output): + def _check_type(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output if len(self._predictions) < 1: return @@ -97,7 +95,7 @@ def _check_type(self, output): dtype_targets = self._targets[-1].dtype if dtype_targets != y.dtype: raise ValueError( - "Incoherent types between input y and stored targets: " "{} vs {}".format(dtype_targets, y.dtype) + "Incoherent types between input y and stored targets: {} vs {}".format(dtype_targets, y.dtype) ) @reinit__is_reduced @@ -125,7 +123,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None: except Exception as e: warnings.warn("Probably, there can be a problem with `compute_fn`:\n {}.".format(e), EpochMetricWarning) - def compute(self) -> None: + def compute(self) -> float: if len(self._predictions) < 1 or len(self._targets) < 1: raise NotComputableError("EpochMetric must have at least one example before it can be computed.") @@ -136,8 +134,8 @@ def compute(self) -> None: if ws > 1 and not self._is_reduced: # All gather across all processes - _prediction_tensor = idist.all_gather(_prediction_tensor) - _target_tensor = idist.all_gather(_target_tensor) + _prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor)) + _target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor)) self._is_reduced = True result = 0.0 @@ -147,7 +145,7 @@ def compute(self) -> None: if ws > 1: # broadcast result to all processes - result = idist.broadcast(result, src=0) + result = cast(float, idist.broadcast(result, src=0)) # type: ignore[arg-type] return result diff --git a/ignite/metrics/fbeta.py b/ignite/metrics/fbeta.py index eb6776a17eb8..769c0febfc10 100644 --- a/ignite/metrics/fbeta.py +++ b/ignite/metrics/fbeta.py @@ -46,7 +46,7 @@ def Fbeta( if precision is None: precision = Precision( - output_transform=(lambda x: x) if output_transform is None else output_transform, + output_transform=(lambda x: x) if output_transform is None else output_transform, # type: ignore[arg-type] average=False, device=device, ) @@ -55,7 +55,7 @@ def Fbeta( if recall is None: recall = Recall( - output_transform=(lambda x: x) if output_transform is None else output_transform, + output_transform=(lambda x: x) if output_transform is None else output_transform, # type: ignore[arg-type] average=False, device=device, ) diff --git a/ignite/metrics/frequency.py b/ignite/metrics/frequency.py index 447cbbf63fd8..b34c33710646 100644 --- a/ignite/metrics/frequency.py +++ b/ignite/metrics/frequency.py @@ -1,11 +1,11 @@ -from typing import Callable, Optional, Union +from typing import Any, Callable, Union import torch import ignite.distributed as idist -from ignite.engine import Events +from ignite.engine import Engine, Events from ignite.handlers.timing import Timer -from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce +from ignite.metrics.metric import EpochWise, Metric, MetricUsage, reinit__is_reduced, sync_all_reduce class Frequency(Metric): @@ -39,15 +39,11 @@ class Frequency(Metric): def __init__( self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu") - ): - self._timer = None - self._acc = None - self._n = None - self._elapsed = None + ) -> None: super(Frequency, self).__init__(output_transform=output_transform, device=device) @reinit__is_reduced - def reset(self): + def reset(self) -> None: self._timer = Timer() self._acc = 0 self._n = 0 @@ -55,13 +51,13 @@ def reset(self): super(Frequency, self).reset() @reinit__is_reduced - def update(self, output): + def update(self, output: int) -> None: self._acc += output self._n = self._acc self._elapsed = self._timer.value() @sync_all_reduce("_n", "_elapsed") - def compute(self): + def compute(self) -> float: time_divisor = 1.0 if idist.get_world_size() > 1: @@ -70,10 +66,13 @@ def compute(self): # Returns the average processed objects per second across all workers return self._n / self._elapsed * time_divisor - def completed(self, engine, name): + def completed(self, engine: Engine, name: str) -> None: engine.state.metrics[name] = int(self.compute()) - def attach(self, engine, name, event_name=Events.ITERATION_COMPLETED): + # TODO: see issue https://github.com/pytorch/ignite/issues/1405 + def attach( # type: ignore + self, engine: Engine, name: str, event_name: Events = Events.ITERATION_COMPLETED + ) -> None: engine.add_event_handler(Events.EPOCH_STARTED, self.started) engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) engine.add_event_handler(event_name, self.completed, name) diff --git a/ignite/metrics/loss.py b/ignite/metrics/loss.py index 8fc3aaba3002..18fc34a36d87 100644 --- a/ignite/metrics/loss.py +++ b/ignite/metrics/loss.py @@ -1,4 +1,4 @@ -from typing import Callable, Sequence, Union +from typing import Callable, Dict, Sequence, Tuple, Union, cast import torch @@ -51,12 +51,12 @@ def reset(self) -> None: self._num_examples = 0 @reinit__is_reduced - def update(self, output: Sequence[Union[torch.Tensor, dict]]) -> None: + def update(self, output: Sequence[Union[torch.Tensor, Dict]]) -> None: if len(output) == 2: - y_pred, y = output - kwargs = {} + y_pred, y = cast(Tuple[torch.Tensor, torch.Tensor], output) + kwargs = {} # type: Dict else: - y_pred, y, kwargs = output + y_pred, y, kwargs = cast(Tuple[torch.Tensor, torch.Tensor, Dict], output) average_loss = self._loss_fn(y_pred.detach(), y.detach(), **kwargs) if len(average_loss.shape) != 0: @@ -67,7 +67,7 @@ def update(self, output: Sequence[Union[torch.Tensor, dict]]) -> None: self._num_examples += n @sync_all_reduce("_sum", "_num_examples") - def compute(self) -> None: + def compute(self) -> float: if self._num_examples == 0: raise NotComputableError("Loss must have at least one example before it can be computed.") return self._sum.item() / self._num_examples diff --git a/ignite/metrics/mean_pairwise_distance.py b/ignite/metrics/mean_pairwise_distance.py index 0473fc07ff50..bf39e3036b83 100644 --- a/ignite/metrics/mean_pairwise_distance.py +++ b/ignite/metrics/mean_pairwise_distance.py @@ -22,13 +22,13 @@ def __init__( eps: float = 1e-6, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), - ): + ) -> None: super(MeanPairwiseDistance, self).__init__(output_transform, device=device) self._p = p self._eps = eps @reinit__is_reduced - def reset(self): + def reset(self) -> None: self._sum_of_distances = torch.tensor(0.0, device=self._device) self._num_examples = 0 diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index b85a39103b18..b076d5e2b8e2 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -2,12 +2,15 @@ from abc import ABCMeta, abstractmethod from collections.abc import Mapping from functools import wraps -from typing import Any, Callable, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union import torch import ignite.distributed as idist -from ignite.engine import Engine, Events +from ignite.engine import CallableEventWithFilter, Engine, Events + +if TYPE_CHECKING: + from ignite.metrics.metrics_lambda import MetricsLambda __all__ = ["Metric", "MetricUsage", "EpochWise", "BatchWise", "BatchFiltered"] @@ -28,21 +31,21 @@ class MetricUsage: :meth:`~ignite.metrics.Metric.iteration_completed`. """ - def __init__(self, started, completed, iteration_completed): + def __init__(self, started: Events, completed: Events, iteration_completed: CallableEventWithFilter) -> None: self.__started = started self.__completed = completed self.__iteration_completed = iteration_completed @property - def STARTED(self): + def STARTED(self) -> Events: return self.__started @property - def COMPLETED(self): + def COMPLETED(self) -> Events: return self.__completed @property - def ITERATION_COMPLETED(self): + def ITERATION_COMPLETED(self) -> CallableEventWithFilter: return self.__iteration_completed @@ -62,7 +65,7 @@ class EpochWise(MetricUsage): usage_name = "epoch_wise" - def __init__(self): + def __init__(self) -> None: super(EpochWise, self).__init__( started=Events.EPOCH_STARTED, completed=Events.EPOCH_COMPLETED, @@ -86,7 +89,7 @@ class BatchWise(MetricUsage): usage_name = "batch_wise" - def __init__(self): + def __init__(self) -> None: super(BatchWise, self).__init__( started=Events.ITERATION_STARTED, completed=Events.ITERATION_COMPLETED, @@ -111,7 +114,7 @@ class BatchFiltered(MetricUsage): """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super(BatchFiltered, self).__init__( started=Events.EPOCH_STARTED, completed=Events.EPOCH_COMPLETED, @@ -191,7 +194,7 @@ def compute(self): """ # public class attribute - required_output_keys = ("y_pred", "y") + required_output_keys = ("y_pred", "y") # type: Optional[Tuple] # for backward compatibility _required_output_keys = required_output_keys @@ -212,10 +215,10 @@ def __init__( ) # Some metrics have a large performance regression when run on XLA devices, so for now, we disallow it. - if torch.device(device).type == "xla": + if torch.device(device).type == "xla": # type: ignore[arg-type] raise ValueError("Cannot create metric on an XLA device. Use device='cpu' instead.") - self._device = torch.device(device) + self._device = torch.device(device) # type: ignore[arg-type] self._is_reduced = False self.reset() @@ -229,7 +232,7 @@ def reset(self) -> None: pass @abstractmethod - def update(self, output) -> None: + def update(self, output: Any) -> None: """ Updates the metric's state using the passed batch output. @@ -420,72 +423,72 @@ def is_attached(self, engine: Engine, usage: Union[str, MetricUsage] = EpochWise usage = self._check_usage(usage) return engine.has_event_handler(self.completed, usage.COMPLETED) - def __add__(self, other): + def __add__(self, other: Any) -> "MetricsLambda": from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x + y, self, other) - def __radd__(self, other): + def __radd__(self, other: Any) -> "MetricsLambda": from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x + y, other, self) - def __sub__(self, other): + def __sub__(self, other: Any) -> "MetricsLambda": from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x - y, self, other) - def __rsub__(self, other): + def __rsub__(self, other: Any) -> "MetricsLambda": from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x - y, other, self) - def __mul__(self, other): + def __mul__(self, other: Any) -> "MetricsLambda": from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x * y, self, other) - def __rmul__(self, other): + def __rmul__(self, other: Any) -> "MetricsLambda": from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x * y, other, self) - def __pow__(self, other): + def __pow__(self, other: Any) -> "MetricsLambda": from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x ** y, self, other) - def __rpow__(self, other): + def __rpow__(self, other: Any) -> "MetricsLambda": from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x ** y, other, self) - def __mod__(self, other): + def __mod__(self, other: Any) -> "MetricsLambda": from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x % y, self, other) - def __div__(self, other): + def __div__(self, other: Any) -> "MetricsLambda": from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x.__div__(y), self, other) - def __rdiv__(self, other): + def __rdiv__(self, other: Any) -> "MetricsLambda": from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x.__div__(y), other, self) - def __truediv__(self, other): + def __truediv__(self, other: Any) -> "MetricsLambda": from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x.__truediv__(y), self, other) - def __rtruediv__(self, other): + def __rtruediv__(self, other: Any) -> "MetricsLambda": from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x.__truediv__(y), other, self) - def __floordiv__(self, other): + def __floordiv__(self, other: Any) -> "MetricsLambda": from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x // y, self, other) @@ -493,27 +496,27 @@ def __floordiv__(self, other): def __getattr__(self, attr: str) -> Callable: from ignite.metrics.metrics_lambda import MetricsLambda - def fn(x, *args, **kwargs): + def fn(x: Metric, *args: Any, **kwargs: Any) -> Any: return getattr(x, attr)(*args, **kwargs) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> "MetricsLambda": return MetricsLambda(fn, self, *args, **kwargs) return wrapper - def __getitem__(self, index: Any): + def __getitem__(self, index: Any) -> "MetricsLambda": from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x: x[index], self) - def __getstate__(self): + def __getstate__(self) -> Dict: return self.__dict__ - def __setstate__(self, d): + def __setstate__(self, d: Dict) -> None: self.__dict__.update(d) -def sync_all_reduce(*attrs) -> Callable: +def sync_all_reduce(*attrs: Any) -> Callable: """Helper decorator for distributed configuration to collect instance attribute value across all participating processes. @@ -526,7 +529,7 @@ def sync_all_reduce(*attrs) -> Callable: def wrapper(func: Callable) -> Callable: @wraps(func) - def another_wrapper(self: Metric, *args, **kwargs) -> Callable: + def another_wrapper(self: Metric, *args: Any, **kwargs: Any) -> Callable: if not isinstance(self, Metric): raise RuntimeError( "Decorator sync_all_reduce should be used on ignite.metric.Metric class methods only" @@ -547,7 +550,7 @@ def another_wrapper(self: Metric, *args, **kwargs) -> Callable: return another_wrapper - wrapper._decorated = True + setattr(wrapper, "_decorated", True) return wrapper @@ -559,9 +562,9 @@ def reinit__is_reduced(func: Callable) -> Callable: """ @wraps(func) - def wrapper(self, *args, **kwargs): + def wrapper(self: Metric, *args: Any, **kwargs: Any) -> None: func(self, *args, **kwargs) self._is_reduced = False - wrapper._decorated = True + setattr(wrapper, "_decorated", True) return wrapper diff --git a/ignite/metrics/metrics_lambda.py b/ignite/metrics/metrics_lambda.py index 14348ebf51b1..6514b71cfd00 100644 --- a/ignite/metrics/metrics_lambda.py +++ b/ignite/metrics/metrics_lambda.py @@ -1,5 +1,5 @@ import itertools -from typing import Any, Callable, Union +from typing import Any, Callable, Optional, Union import torch @@ -64,11 +64,11 @@ def Fbeta(r, p, beta): """ - def __init__(self, f: Callable, *args, **kwargs): + def __init__(self, f: Callable, *args: Any, **kwargs: Any) -> None: self.function = f self.args = args self.kwargs = kwargs - self.engine = None + self.engine = None # type: Optional[Engine] super(MetricsLambda, self).__init__(device="cpu") @reinit__is_reduced @@ -78,7 +78,7 @@ def reset(self) -> None: i.reset() @reinit__is_reduced - def update(self, output) -> None: + def update(self, output: Any) -> None: # NB: this method does not recursively update dependency metrics, # which might cause duplicate update issue. To update this metric, # users should manually update its dependencies. @@ -138,7 +138,7 @@ def _internal_is_attached(self, engine: Engine, usage: MetricUsage) -> bool: return not is_detached -def _get_value_on_cpu(v: Any): +def _get_value_on_cpu(v: Any) -> Any: if isinstance(v, Metric): v = v.compute() if isinstance(v, torch.Tensor): diff --git a/ignite/metrics/precision.py b/ignite/metrics/precision.py index f1d417ee015f..7301fb15f54c 100644 --- a/ignite/metrics/precision.py +++ b/ignite/metrics/precision.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable, Sequence, Union +from typing import Callable, Sequence, Union, cast import torch @@ -30,8 +30,6 @@ def __init__( ) self._average = average - self._true_positives = None - self._positives = None self.eps = 1e-20 super(_BasePrecisionRecall, self).__init__( output_transform=output_transform, is_multilabel=is_multilabel, device=device @@ -39,19 +37,20 @@ def __init__( @reinit__is_reduced def reset(self) -> None: + self._true_positives = 0 # type: Union[int, torch.Tensor] + self._positives = 0 # type: Union[int, torch.Tensor] + if self._is_multilabel: init_value = 0.0 if self._average else [] - kws = {"dtype": torch.float64, "device": self._device} - self._true_positives = torch.tensor(init_value, **kws) - self._positives = torch.tensor(init_value, **kws) - else: - self._true_positives = 0 - self._positives = 0 + self._true_positives = torch.tensor(init_value, dtype=torch.float64, device=self._device) + self._positives = torch.tensor(init_value, dtype=torch.float64, device=self._device) super(_BasePrecisionRecall, self).reset() def compute(self) -> Union[torch.Tensor, float]: - is_scalar = not isinstance(self._positives, torch.Tensor) or self._positives.ndim == 0 + is_scalar = ( + not isinstance(self._positives, torch.Tensor) or self._positives.ndim == 0 # type: ignore[attr-defined] + ) if is_scalar and self._positives == 0: raise NotComputableError( "{} must have at least one example before it can be computed.".format(self.__class__.__name__) @@ -59,14 +58,14 @@ def compute(self) -> Union[torch.Tensor, float]: if not (self._type == "multilabel" and not self._average): if not self._is_reduced: - self._true_positives = idist.all_reduce(self._true_positives) - self._positives = idist.all_reduce(self._positives) - self._is_reduced = True + self._true_positives = idist.all_reduce(self._true_positives) # type: ignore[arg-type, assignment] + self._positives = idist.all_reduce(self._positives) # type: ignore[arg-type, assignment] + self._is_reduced = True # type: bool result = self._true_positives / (self._positives + self.eps) if self._average: - return result.mean().item() + return cast(torch.Tensor, result).mean().item() else: return result @@ -177,8 +176,8 @@ def update(self, output: Sequence[torch.Tensor]) -> None: if self._type == "multilabel": if not self._average: - self._true_positives = torch.cat([self._true_positives, true_positives], dim=0) - self._positives = torch.cat([self._positives, all_positives], dim=0) + self._true_positives = torch.cat([self._true_positives, true_positives], dim=0) # type: torch.Tensor + self._positives = torch.cat([self._positives, all_positives], dim=0) # type: torch.Tensor else: self._true_positives += torch.sum(true_positives / (all_positives + self.eps)) self._positives += len(all_positives) diff --git a/ignite/metrics/recall.py b/ignite/metrics/recall.py index ad391705a004..19a087febb64 100644 --- a/ignite/metrics/recall.py +++ b/ignite/metrics/recall.py @@ -115,8 +115,8 @@ def update(self, output: Sequence[torch.Tensor]) -> None: if self._type == "multilabel": if not self._average: - self._true_positives = torch.cat([self._true_positives, true_positives], dim=0) - self._positives = torch.cat([self._positives, actual_positives], dim=0) + self._true_positives = torch.cat([self._true_positives, true_positives], dim=0) # type: torch.Tensor + self._positives = torch.cat([self._positives, actual_positives], dim=0) # type: torch.Tensor else: self._true_positives += torch.sum(true_positives / (actual_positives + self.eps)) self._positives += len(actual_positives) diff --git a/ignite/metrics/running_average.py b/ignite/metrics/running_average.py index 419743ec15a5..db36b1a90483 100644 --- a/ignite/metrics/running_average.py +++ b/ignite/metrics/running_average.py @@ -1,10 +1,10 @@ -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Optional, Sequence, Union, cast import torch import ignite.distributed as idist from ignite.engine import Engine, Events -from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce +from ignite.metrics.metric import EpochWise, Metric, MetricUsage, reinit__is_reduced, sync_all_reduce __all__ = ["RunningAverage"] @@ -66,7 +66,7 @@ def __init__( raise ValueError("Argument device should be None if src is a Metric.") self.src = src self._get_src_value = self._get_metric_value - self.iteration_completed = self._metric_iteration_completed + setattr(self, "iteration_completed", self._metric_iteration_completed) device = src._device else: if output_transform is None: @@ -75,17 +75,17 @@ def __init__( "to the output of process function." ) self._get_src_value = self._get_output_value - self.update = self._output_update + setattr(self, "update", self._output_update) if device is None: device = torch.device("cpu") self.alpha = alpha self.epoch_bound = epoch_bound - super(RunningAverage, self).__init__(output_transform=output_transform, device=device) + super(RunningAverage, self).__init__(output_transform=output_transform, device=device) # type: ignore[arg-type] @reinit__is_reduced def reset(self) -> None: - self._value = None + self._value = None # type: Optional[Union[float, torch.Tensor]] @reinit__is_reduced def update(self, output: Sequence) -> None: @@ -100,7 +100,7 @@ def compute(self) -> Union[torch.Tensor, float]: return self._value - def attach(self, engine: Engine, name: str): + def attach(self, engine: Engine, name: str, _usage: Union[str, MetricUsage] = EpochWise()) -> None: if self.epoch_bound: # restart average every epoch engine.add_event_handler(Events.EPOCH_STARTED, self.started) @@ -115,7 +115,7 @@ def _get_metric_value(self) -> Union[torch.Tensor, float]: @sync_all_reduce("src") def _get_output_value(self) -> Union[torch.Tensor, float]: # we need to compute average instead of sum produced by @sync_all_reduce("src") - output = self.src / idist.get_world_size() + output = cast(Union[torch.Tensor, float], self.src) / idist.get_world_size() return output def _metric_iteration_completed(self, engine: Engine) -> None: @@ -126,4 +126,4 @@ def _metric_iteration_completed(self, engine: Engine) -> None: def _output_update(self, output: Union[torch.Tensor, float]) -> None: if isinstance(output, torch.Tensor): output = output.detach().to(self._device, copy=True) - self.src = output + self.src = output # type: ignore[assignment] diff --git a/ignite/metrics/ssim.py b/ignite/metrics/ssim.py index 256ce516b589..69a9e861f394 100644 --- a/ignite/metrics/ssim.py +++ b/ignite/metrics/ssim.py @@ -56,14 +56,14 @@ def __init__( device: Union[str, torch.device] = torch.device("cpu"), ): if isinstance(kernel_size, int): - self.kernel_size = [kernel_size, kernel_size] + self.kernel_size = [kernel_size, kernel_size] # type: Sequence[int] elif isinstance(kernel_size, Sequence): self.kernel_size = kernel_size else: raise ValueError("Argument kernel_size should be either int or a sequence of int.") if isinstance(sigma, float): - self.sigma = [sigma, sigma] + self.sigma = [sigma, sigma] # type: Sequence[float] elif isinstance(sigma, Sequence): self.sigma = sigma else: @@ -85,11 +85,12 @@ def __init__( @reinit__is_reduced def reset(self) -> None: - self._sum_of_batchwise_ssim = 0.0 # Not a tensor because batch size is not known in advance. + # Not a tensor because batch size is not known in advance. + self._sum_of_batchwise_ssim = 0.0 # type: Union[float, torch.Tensor] self._num_examples = 0 self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma) - def _uniform(self, kernel_size): + def _uniform(self, kernel_size: int) -> torch.Tensor: max, min = 2.5, -2.5 ksize_half = (kernel_size - 1) * 0.5 kernel = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, device=self._device) @@ -101,13 +102,13 @@ def _uniform(self, kernel_size): return kernel.unsqueeze(dim=0) # (1, kernel_size) - def _gaussian(self, kernel_size, sigma): + def _gaussian(self, kernel_size: int, sigma: float) -> torch.Tensor: ksize_half = (kernel_size - 1) * 0.5 kernel = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, device=self._device) gauss = torch.exp(-0.5 * (kernel / sigma).pow(2)) return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) - def _gaussian_or_uniform_kernel(self, kernel_size, sigma): + def _gaussian_or_uniform_kernel(self, kernel_size: Sequence[int], sigma: Sequence[float]) -> torch.Tensor: if self.gaussian: kernel_x = self._gaussian(kernel_size[0], sigma[0]) kernel_y = self._gaussian(kernel_size[1], sigma[1]) @@ -142,8 +143,8 @@ def update(self, output: Sequence[torch.Tensor]) -> None: if len(self._kernel.shape) < 4: self._kernel = self._kernel.expand(channel, 1, -1, -1).to(device=y_pred.device) - y_pred = F.pad(y_pred, (self.pad_w, self.pad_w, self.pad_h, self.pad_h), mode="reflect") - y = F.pad(y, (self.pad_w, self.pad_w, self.pad_h, self.pad_h), mode="reflect") + y_pred = F.pad(y_pred, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect") + y = F.pad(y, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect") input_list = torch.cat([y_pred, y, y_pred * y_pred, y * y, y_pred * y]) outputs = F.conv2d(input_list, self._kernel, groups=channel) @@ -171,4 +172,4 @@ def update(self, output: Sequence[torch.Tensor]) -> None: def compute(self) -> torch.Tensor: if self._num_examples == 0: raise NotComputableError("SSIM must have at least one example before it can be computed.") - return torch.sum(self._sum_of_batchwise_ssim / self._num_examples) + return torch.sum(self._sum_of_batchwise_ssim / self._num_examples) # type: ignore[arg-type] diff --git a/ignite/metrics/top_k_categorical_accuracy.py b/ignite/metrics/top_k_categorical_accuracy.py index dad423f86af1..20d948003ebc 100644 --- a/ignite/metrics/top_k_categorical_accuracy.py +++ b/ignite/metrics/top_k_categorical_accuracy.py @@ -16,8 +16,11 @@ class TopKCategoricalAccuracy(Metric): """ def __init__( - self, k=5, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), - ): + self, + k: int = 5, + output_transform: Callable = lambda x: x, + device: Union[str, torch.device] = torch.device("cpu"), + ) -> None: super(TopKCategoricalAccuracy, self).__init__(output_transform, device=device) self._k = k @@ -27,7 +30,7 @@ def reset(self) -> None: self._num_examples = 0 @reinit__is_reduced - def update(self, output: Sequence) -> None: + def update(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output[0].detach(), output[1].detach() sorted_indices = torch.topk(y_pred, self._k, dim=1)[1] expanded_y = y.view(-1, 1).expand(-1, self._k) @@ -40,6 +43,6 @@ def update(self, output: Sequence) -> None: def compute(self) -> Union[float, torch.Tensor]: if self._num_examples == 0: raise NotComputableError( - "TopKCategoricalAccuracy must have at" "least one example before it can be computed." + "TopKCategoricalAccuracy must have at least one example before it can be computed." ) return self._num_correct.item() / self._num_examples diff --git a/mypy.ini b/mypy.ini index 33b53407df0c..699768e7680d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -7,10 +7,6 @@ show_error_codes = True ignore_errors = True -[mypy-ignite.metrics.*] - -ignore_errors = True - [mypy-horovod.*] ignore_missing_imports = True diff --git a/tests/ignite/metrics/test_accumulation.py b/tests/ignite/metrics/test_accumulation.py index adc7acac85d1..0be2a0e98e68 100644 --- a/tests/ignite/metrics/test_accumulation.py +++ b/tests/ignite/metrics/test_accumulation.py @@ -108,7 +108,7 @@ def test_geom_average(): mean_var.update(y.item()) m = mean_var.compute() - assert m.item() == pytest.approx(_geom_mean(y_true)) + assert m == pytest.approx(_geom_mean(y_true)) mean_var = GeometricAverage() y_true = torch.rand(100, 10) + torch.randint(0, 10, size=(100, 10)).float() @@ -293,7 +293,7 @@ def _test(metric_device): log_y_true = torch.log(y_true) log_y_true = idist.all_reduce(log_y_true) np.testing.assert_almost_equal( - m.item(), torch.exp(log_y_true.mean(dim=0) / idist.get_world_size()).item(), decimal=decimal + m, torch.exp(log_y_true.mean(dim=0) / idist.get_world_size()).item(), decimal=decimal ) mean_var = GeometricAverage(device=metric_device) diff --git a/tests/ignite/metrics/test_confusion_matrix.py b/tests/ignite/metrics/test_confusion_matrix.py index 8da41e07a411..8925b3905c84 100644 --- a/tests/ignite/metrics/test_confusion_matrix.py +++ b/tests/ignite/metrics/test_confusion_matrix.py @@ -39,7 +39,7 @@ def test_multiclass_wrong_inputs(): with pytest.raises(ValueError, match=r"Argument average can None or one of"): ConfusionMatrix(num_classes=10, average="abc") - with pytest.raises(ValueError, match=r"Argument average one of 'samples', 'recall', 'precision'"): + with pytest.raises(ValueError, match=r"Argument average should be one of 'samples', 'recall', 'precision'"): ConfusionMatrix.normalize(None, None)