
Activate mypy in ignite.metrics #1391


Merged · 7 commits · Oct 28, 2020
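A note before the file-by-file diff: a few patterns recur throughout. The PR favors comment-style annotations (`x = None  # type: Optional[str]`) over variable annotations, presumably to stay parseable on older Python versions; it reaches for `typing.cast` where a library annotation is looser than the value actually passed; it scopes `# type: ignore[...]` comments to single error codes; and it rebinds `Optional` arguments to non-`Optional` locals before closures. A minimal sketch of the comment-annotation style (illustrative class, not ignite code):

```python
from typing import Optional

class Metric:
    def __init__(self) -> None:
        # Comment-style annotation: equivalent to
        # `self._type: Optional[str] = None`, but also valid on Python
        # versions that lack variable annotations.
        self._type = None  # type: Optional[str]
```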
36 changes: 20 additions & 16 deletions ignite/metrics/accumulation.py
@@ -1,5 +1,5 @@
import numbers
from typing import Any, Callable, Union
from typing import Any, Callable, Tuple, Union, cast

import torch

@@ -47,16 +47,15 @@ def __init__(
):
if not callable(op):
raise TypeError("Argument op should be a callable, but given {}".format(type(op)))
self.accumulator = None
self.num_examples = None

self._op = op

super(VariableAccumulation, self).__init__(output_transform=output_transform, device=device)

@reinit__is_reduced
def reset(self) -> None:
self.accumulator = torch.tensor(0.0, dtype=torch.float64, device=self._device)
self.num_examples = torch.tensor(0, dtype=torch.long, device=self._device)
self.num_examples = 0

def _check_output_type(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None:
if not (isinstance(output, numbers.Number) or isinstance(output, torch.Tensor)):
@@ -73,14 +72,14 @@ def update(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None:

self.accumulator = self._op(self.accumulator, output)

if hasattr(output, "shape"):
if isinstance(output, torch.Tensor):
self.num_examples += output.shape[0] if len(output.shape) > 1 else 1
else:
self.num_examples += 1

@sync_all_reduce("accumulator", "num_examples")
def compute(self) -> list:
return [self.accumulator, self.num_examples]
def compute(self) -> Tuple[torch.Tensor, int]:
return self.accumulator, self.num_examples


class Average(VariableAccumulation):
@@ -125,18 +124,18 @@ class Average(VariableAccumulation):
def __init__(
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
):
def _mean_op(a, x):
if isinstance(x, torch.Tensor) and x.ndim > 1:
def _mean_op(a: Union[float, torch.Tensor], x: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
if isinstance(x, torch.Tensor) and x.ndim > 1: # type: ignore[attr-defined]
x = x.sum(dim=0)
return a + x

super(Average, self).__init__(op=_mean_op, output_transform=output_transform, device=device)

@sync_all_reduce("accumulator", "num_examples")
def compute(self) -> Union[Any, torch.Tensor, numbers.Number]:
def compute(self) -> Union[torch.Tensor, numbers.Number]:
if self.num_examples < 1:
raise NotComputableError(
"{} must have at least one example before" " it can be computed.".format(self.__class__.__name__)
"{} must have at least one example before it can be computed.".format(self.__class__.__name__)
)

return self.accumulator / self.num_examples
@@ -173,21 +172,26 @@ class GeometricAverage(VariableAccumulation):
def __init__(
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
):
def _geom_op(a: torch.Tensor, x: Union[Any, numbers.Number, torch.Tensor]) -> torch.Tensor:
def _geom_op(a: torch.Tensor, x: Union[numbers.Number, torch.Tensor]) -> torch.Tensor:
if not isinstance(x, torch.Tensor):
x = torch.tensor(x)
x = torch.log(x)
if x.ndim > 1:
if x.ndim > 1: # type: ignore[attr-defined]
x = x.sum(dim=0)
return a + x

super(GeometricAverage, self).__init__(op=_geom_op, output_transform=output_transform, device=device)

@sync_all_reduce("accumulator", "num_examples")
def compute(self) -> torch.Tensor:
def compute(self) -> Union[torch.Tensor, numbers.Number]:
if self.num_examples < 1:
raise NotComputableError(
"{} must have at least one example before" " it can be computed.".format(self.__class__.__name__)
"{} must have at least one example before it can be computed.".format(self.__class__.__name__)
)

return torch.exp(self.accumulator / self.num_examples)
tensor = torch.exp(self.accumulator / self.num_examples)

if tensor.numel() == 1:
return cast(numbers.Number, tensor.item())

return tensor
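
The net effect of the accumulation changes shows up in the return types: `VariableAccumulation.compute()` now returns an `(accumulator, num_examples)` tuple, and `GeometricAverage.compute()` unwraps single-element tensors into plain numbers. A usage sketch under that reading (values are illustrative, and the metrics are driven manually rather than through an `Engine`):

```python
import torch

from ignite.metrics import Average, GeometricAverage

avg = Average()
avg.update(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))  # batch of 2, summed over dim 0
print(avg.compute())  # tensor([2., 3.], dtype=torch.float64)

geo = GeometricAverage()
geo.update(2.0)
geo.update(8.0)
print(geo.compute())  # 4.0 -- a plain number now that the result has numel() == 1
```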
12 changes: 5 additions & 7 deletions ignite/metrics/accuracy.py
@@ -1,4 +1,4 @@
from typing import Callable, Sequence, Union
from typing import Callable, Optional, Sequence, Tuple, Union

import torch

@@ -16,8 +16,8 @@ def __init__(
device: Union[str, torch.device] = torch.device("cpu"),
):
self._is_multilabel = is_multilabel
self._type = None
self._num_classes = None
self._type = None # type: Optional[str]
self._num_classes = None # type: Optional[int]
super(_BaseClassification, self).__init__(output_transform=output_transform, device=device)

def reset(self) -> None:
@@ -35,7 +35,7 @@ def _check_shape(self, output: Sequence[torch.Tensor]) -> None:
)

y_shape = y.shape
y_pred_shape = y_pred.shape
y_pred_shape = y_pred.shape # type: Tuple[int, ...]

if y.ndimension() + 1 == y_pred.ndimension():
y_pred_shape = (y_pred_shape[0],) + y_pred_shape[2:]
@@ -134,8 +134,6 @@ def __init__(
is_multilabel: bool = False,
device: Union[str, torch.device] = torch.device("cpu"),
):
self._num_correct = None
self._num_examples = None
super(Accuracy, self).__init__(output_transform=output_transform, is_multilabel=is_multilabel, device=device)

@reinit__is_reduced
@@ -167,7 +165,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
self._num_examples += correct.shape[0]

@sync_all_reduce("_num_examples", "_num_correct")
def compute(self) -> torch.Tensor:
def compute(self) -> float:
if self._num_examples == 0:
raise NotComputableError("Accuracy must have at least one example before it can be computed.")
return self._num_correct.item() / self._num_examples
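
With `_type` and `_num_classes` declared as `Optional`, mypy flags any use that is not guarded by a `None` check. A minimal sketch of the narrowing such attributes require (hypothetical class, not ignite code):

```python
from typing import Optional

class Classifier:
    def __init__(self) -> None:
        self._type = None  # type: Optional[str]

    def kind(self) -> str:
        if self._type is None:
            raise RuntimeError("update() must be called at least once")
        return self._type  # mypy has narrowed Optional[str] to str here
```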
28 changes: 14 additions & 14 deletions ignite/metrics/confusion_matrix.py
@@ -1,5 +1,5 @@
import numbers
from typing import Callable, Optional, Sequence, Union
from typing import Callable, Optional, Sequence, Tuple, Union

import torch

@@ -54,7 +54,6 @@ def __init__(
self.num_classes = num_classes
self._num_examples = 0
self.average = average
self.confusion_matrix = None
super(ConfusionMatrix, self).__init__(output_transform=output_transform, device=device)

@reinit__is_reduced
@@ -67,7 +66,7 @@ def _check_shape(self, output: Sequence[torch.Tensor]) -> None:

if y_pred.ndimension() < 2:
raise ValueError(
"y_pred must have shape (batch_size, num_categories, ...), " "but given {}".format(y_pred.shape)
"y_pred must have shape (batch_size, num_categories, ...), but given {}".format(y_pred.shape)
)

if y_pred.shape[1] != self.num_classes:
@@ -83,7 +82,7 @@ def _check_shape(self, output: Sequence[torch.Tensor]) -> None:
)

y_shape = y.shape
y_pred_shape = y_pred.shape
y_pred_shape = y_pred.shape # type: Tuple[int, ...]

if y.ndimension() + 1 == y_pred.ndimension():
y_pred_shape = (y_pred_shape[0],) + y_pred_shape[2:]
@@ -124,13 +123,12 @@ def compute(self) -> torch.Tensor:

@staticmethod
def normalize(matrix: torch.Tensor, average: str) -> torch.Tensor:
if average not in ("recall", "precision"):
raise ValueError("Argument average one of 'samples', 'recall', 'precision'")

if average == "recall":
return matrix / (matrix.sum(dim=1).unsqueeze(1) + 1e-15)
elif average == "precision":
return matrix / (matrix.sum(dim=0) + 1e-15)
else:
raise ValueError("Argument average should be one of 'samples', 'recall', 'precision'")


def IoU(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda:
@@ -170,14 +168,15 @@ def IoU(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda:
cm = cm.type(torch.DoubleTensor)
iou = cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) - cm.diag() + 1e-15)
if ignore_index is not None:
ignore_idx = ignore_index  # type: int # used due to typing issues with mypy

def ignore_index_fn(iou_vector):
if ignore_index >= len(iou_vector):
def ignore_index_fn(iou_vector: torch.Tensor) -> torch.Tensor:
if ignore_idx >= len(iou_vector):
raise ValueError(
"ignore_index {} is larger than the length of IoU vector {}".format(ignore_index, len(iou_vector))
"ignore_index {} is larger than the length of IoU vector {}".format(ignore_idx, len(iou_vector))
)
indices = list(range(len(iou_vector)))
indices.remove(ignore_index)
indices.remove(ignore_idx)
return iou_vector[indices]

return MetricsLambda(ignore_index_fn, iou)
@@ -282,14 +281,15 @@ def DiceCoefficient(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda:
dice = 2.0 * cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) + 1e-15)

if ignore_index is not None:
ignore_idx = ignore_index  # type: int # used due to typing issues with mypy

def ignore_index_fn(dice_vector: torch.Tensor) -> torch.Tensor:
if ignore_index >= len(dice_vector):
if ignore_idx >= len(dice_vector):
raise ValueError(
"ignore_index {} is larger than the length of Dice vector {}".format(ignore_index, len(dice_vector))
"ignore_index {} is larger than the length of Dice vector {}".format(ignore_idx, len(dice_vector))
)
indices = list(range(len(dice_vector)))
indices.remove(ignore_index)
indices.remove(ignore_idx)
return dice_vector[indices]

return MetricsLambda(ignore_index_fn, dice)
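Both `IoU` and `DiceCoefficient` copy `ignore_index` into a local `ignore_idx` before defining the nested function because mypy does not carry an `is not None` narrowing into a closure: the captured variable could, in principle, be reassigned before the closure runs. A stripped-down sketch of the pattern (hypothetical helper, not ignite code):

```python
from typing import Callable, List, Optional

def make_dropper(ignore_index: Optional[int]) -> Callable[[List[float]], List[float]]:
    if ignore_index is None:
        return lambda values: values

    ignore_idx = ignore_index  # type: int  # non-Optional copy the closure can use

    def drop(values: List[float]) -> List[float]:
        # Drops the entry at ignore_idx, mirroring ignore_index_fn above.
        return [v for i, v in enumerate(values) if i != ignore_idx]

    return drop
```

The `ConfusionMatrix.normalize` rewrite serves the same goal from the other direction: with the `raise` moved into a final `else`, every path visibly returns a `Tensor`, so the declared return type checks without a cast.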
24 changes: 11 additions & 13 deletions ignite/metrics/epoch_metric.py
@@ -1,5 +1,5 @@
import warnings
from typing import Callable, Sequence, Union
from typing import Callable, List, Sequence, Union, cast

import torch

@@ -54,24 +54,22 @@ def __init__(
output_transform: Callable = lambda x: x,
check_compute_fn: bool = True,
device: Union[str, torch.device] = torch.device("cpu"),
):
) -> None:

if not callable(compute_fn):
raise TypeError("Argument compute_fn should be callable.")

self._predictions = None
self._targets = None
self.compute_fn = compute_fn
self._check_compute_fn = check_compute_fn

super(EpochMetric, self).__init__(output_transform=output_transform, device=device)

@reinit__is_reduced
def reset(self) -> None:
self._predictions = []
self._targets = []
self._predictions = [] # type: List[torch.Tensor]
self._targets = [] # type: List[torch.Tensor]

def _check_shape(self, output):
def _check_shape(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y = output
if y_pred.ndimension() not in (1, 2):
raise ValueError("Predictions should be of shape (batch_size, n_classes) or (batch_size, ).")
@@ -83,7 +81,7 @@ def _check_shape(self, output):
if not torch.equal(y ** 2, y):
raise ValueError("Targets should be binary (0 or 1).")

def _check_type(self, output):
def _check_type(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y = output
if len(self._predictions) < 1:
return
@@ -97,7 +95,7 @@ def _check_type(self, output):
dtype_targets = self._targets[-1].dtype
if dtype_targets != y.dtype:
raise ValueError(
"Incoherent types between input y and stored targets: " "{} vs {}".format(dtype_targets, y.dtype)
"Incoherent types between input y and stored targets: {} vs {}".format(dtype_targets, y.dtype)
)

@reinit__is_reduced
@@ -125,7 +123,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
except Exception as e:
warnings.warn("Probably, there can be a problem with `compute_fn`:\n {}.".format(e), EpochMetricWarning)

def compute(self) -> None:
def compute(self) -> float:
if len(self._predictions) < 1 or len(self._targets) < 1:
raise NotComputableError("EpochMetric must have at least one example before it can be computed.")

@@ -136,8 +134,8 @@ def compute(self) -> None:

if ws > 1 and not self._is_reduced:
# All gather across all processes
_prediction_tensor = idist.all_gather(_prediction_tensor)
_target_tensor = idist.all_gather(_target_tensor)
_prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
_target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))
self._is_reduced = True

result = 0.0
@@ -147,7 +145,7 @@ def compute(self) -> None:

if ws > 1:
# broadcast result to all processes
result = idist.broadcast(result, src=0)
result = cast(float, idist.broadcast(result, src=0)) # type: ignore[arg-type]

return result

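`typing.cast` appears in `compute()` presumably because `idist.all_gather` and `idist.broadcast` are annotated more broadly than the tensor and float this call site actually receives; `cast` changes nothing at runtime and only informs the checker. A generic illustration (not ignite's API):

```python
from typing import cast

import torch

def first_row(value: object) -> torch.Tensor:
    # cast() is a no-op at runtime; it only tells mypy to treat `value`
    # as a Tensor so the indexing below type-checks.
    tensor = cast(torch.Tensor, value)
    return tensor[0]

print(first_row(torch.eye(2)))  # tensor([1., 0.])
```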
4 changes: 2 additions & 2 deletions ignite/metrics/fbeta.py
@@ -46,7 +46,7 @@ def Fbeta(

if precision is None:
precision = Precision(
output_transform=(lambda x: x) if output_transform is None else output_transform,
output_transform=(lambda x: x) if output_transform is None else output_transform, # type: ignore[arg-type]
average=False,
device=device,
)
@@ -55,7 +55,7 @@ def Fbeta(

if recall is None:
recall = Recall(
output_transform=(lambda x: x) if output_transform is None else output_transform,
output_transform=(lambda x: x) if output_transform is None else output_transform, # type: ignore[arg-type]
average=False,
device=device,
)
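The `# type: ignore[arg-type]` comments here are scoped to a single mypy error code, presumably because mypy infers an inconvenient type for the conditional lambda expression; only that diagnostic is silenced, and any other error on the line is still reported. A toy demonstration of the mechanism (deliberate mismatch, not ignite code):

```python
def double(x: int) -> int:
    return 2 * x

# The scoped comment suppresses only the arg-type error on this line;
# any other error code would still be reported. Note it also hides a real
# mismatch: at runtime this prints "22", since 2 * "2" is "22".
print(double("2"))  # type: ignore[arg-type]
```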
25 changes: 12 additions & 13 deletions ignite/metrics/frequency.py
@@ -1,11 +1,11 @@
from typing import Callable, Optional, Union
from typing import Any, Callable, Union

import torch

import ignite.distributed as idist
from ignite.engine import Events
from ignite.engine import Engine, Events
from ignite.handlers.timing import Timer
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
from ignite.metrics.metric import EpochWise, Metric, MetricUsage, reinit__is_reduced, sync_all_reduce


class Frequency(Metric):
@@ -39,29 +39,25 @@ class Frequency(Metric):

def __init__(
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
):
self._timer = None
self._acc = None
self._n = None
self._elapsed = None
) -> None:
super(Frequency, self).__init__(output_transform=output_transform, device=device)

@reinit__is_reduced
def reset(self):
def reset(self) -> None:
self._timer = Timer()
self._acc = 0
self._n = 0
self._elapsed = 0.0
super(Frequency, self).reset()

@reinit__is_reduced
def update(self, output):
def update(self, output: int) -> None:
self._acc += output
self._n = self._acc
self._elapsed = self._timer.value()

@sync_all_reduce("_n", "_elapsed")
def compute(self):
def compute(self) -> float:
time_divisor = 1.0

if idist.get_world_size() > 1:
@@ -70,10 +66,13 @@ def compute(self):
# Returns the average processed objects per second across all workers
return self._n / self._elapsed * time_divisor

def completed(self, engine, name):
def completed(self, engine: Engine, name: str) -> None:
engine.state.metrics[name] = int(self.compute())

def attach(self, engine, name, event_name=Events.ITERATION_COMPLETED):
# TODO: see issue https://github.com/pytorch/ignite/issues/1405
def attach( # type: ignore
self, engine: Engine, name: str, event_name: Events = Events.ITERATION_COMPLETED
) -> None:
engine.add_event_handler(Events.EPOCH_STARTED, self.started)
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
engine.add_event_handler(event_name, self.completed, name)
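
The `# type: ignore` on `attach` (tracked in the linked issue #1405) is needed because the override narrows the base signature: `Metric.attach` accepts a `MetricUsage` (note the new `EpochWise`/`MetricUsage` imports above), while `Frequency.attach` takes an event instead, which mypy flags as an incompatible override. A stripped-down illustration with hypothetical classes:

```python
class Base:
    def attach(self, engine: object, name: str, usage: object = "epoch-wise") -> None:
        ...

class Derived(Base):
    # Without the ignore, mypy reports:
    #   Signature of "attach" incompatible with supertype "Base"  [override]
    def attach(self, engine: object, name: str, event: str = "iteration") -> None:  # type: ignore[override]
        ...
```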