diff --git a/ignite/metrics/classification_report.py b/ignite/metrics/classification_report.py index 8c83156d5cc6..40809b0eef5c 100644 --- a/ignite/metrics/classification_report.py +++ b/ignite/metrics/classification_report.py @@ -4,7 +4,6 @@ import torch from ignite.metrics.fbeta import Fbeta -from ignite.metrics.metric import Metric from ignite.metrics.metrics_lambda import MetricsLambda from ignite.metrics.precision import Precision from ignite.metrics.recall import Recall @@ -85,14 +84,14 @@ def ClassificationReport( [0, 0, 0], [1, 0, 0], [0, 1, 1], - ]).unsqueeze(0) + ]) y_pred = torch.tensor([ [1, 1, 0], [1, 0, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], - ]).unsqueeze(0) + ]) state = default_evaluator.run([[y_pred, y_true]]) print(state.metrics["cr"].keys()) print(state.metrics["cr"]["0"]) @@ -119,25 +118,24 @@ def ClassificationReport( averaged_fbeta = fbeta.mean() def _wrapper( - recall_metric: Metric, precision_metric: Metric, f: Metric, a_recall: Metric, a_precision: Metric, a_f: Metric + re: torch.Tensor, pr: torch.Tensor, f: torch.Tensor, a_re: torch.Tensor, a_pr: torch.Tensor, a_f: torch.Tensor ) -> Union[Collection[str], Dict]: - p_tensor, r_tensor, f_tensor = precision_metric, recall_metric, f - if p_tensor.shape != r_tensor.shape: + if pr.shape != re.shape: raise ValueError( "Internal error: Precision and Recall have mismatched shapes: " - f"{p_tensor.shape} vs {r_tensor.shape}. Please, open an issue " + f"{pr.shape} vs {re.shape}. Please, open an issue " "with a reference on this error. Thank you!" ) dict_obj = {} - for idx, p_label in enumerate(p_tensor): + for idx, p_label in enumerate(pr): dict_obj[_get_label_for_class(idx)] = { "precision": p_label.item(), - "recall": r_tensor[idx].item(), - "f{0}-score".format(beta): f_tensor[idx].item(), + "recall": re[idx].item(), + "f{0}-score".format(beta): f[idx].item(), } dict_obj["macro avg"] = { - "precision": a_precision.item(), - "recall": a_recall.item(), + "precision": a_pr.item(), + "recall": a_re.item(), "f{0}-score".format(beta): a_f.item(), } return dict_obj if output_dict else json.dumps(dict_obj) diff --git a/ignite/metrics/precision.py b/ignite/metrics/precision.py index f5180d1736a6..3772521f1a01 100644 --- a/ignite/metrics/precision.py +++ b/ignite/metrics/precision.py @@ -1,4 +1,5 @@ -from typing import Callable, cast, Sequence, Union +import warnings +from typing import Callable, cast, Optional, Sequence, Union import torch @@ -15,55 +16,150 @@ class _BasePrecisionRecall(_BaseClassification): def __init__( self, output_transform: Callable = lambda x: x, - average: bool = False, + average: Optional[Union[bool, str]] = False, is_multilabel: bool = False, device: Union[str, torch.device] = torch.device("cpu"), ): - self._average = average + if not (average is None or isinstance(average, bool) or average in ["macro", "micro", "weighted", "samples"]): + raise ValueError( + "Argument average should be None or a boolean or one of values" + " 'macro', 'micro', 'weighted' and 'samples'." 
+ ) + + if average is True: + self._average = "macro" # type: Optional[Union[bool, str]] + else: + self._average = average self.eps = 1e-20 self._updated = False super(_BasePrecisionRecall, self).__init__( output_transform=output_transform, is_multilabel=is_multilabel, device=device ) + def _check_type(self, output: Sequence[torch.Tensor]) -> None: + super()._check_type(output) + + if self._type in ["binary", "multiclass"] and self._average == "samples": + raise ValueError("Argument average='samples' is incompatible with binary and multiclass input data.") + + y_pred, y = output + if self._type == "multiclass" and y.dtype != torch.long: + warnings.warn("`y` should be of dtype long when entry type is multiclass", RuntimeWarning) + if ( + self._type == "binary" + and self._average is not False + and (y.dtype != torch.long or y_pred.dtype != torch.long) + ): + warnings.warn( + "`y` and `y_pred` should be of dtype long when entry type is binary and average!=False", RuntimeWarning + ) + + def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tensor]: + y_pred, y = output[0].detach(), output[1].detach() + + if self._type == "binary" or self._type == "multiclass": + + num_classes = 2 if self._type == "binary" else y_pred.size(1) + if self._type == "multiclass" and y.max() + 1 > num_classes: + raise ValueError( + f"y_pred contains less classes than y. Number of predicted classes is {num_classes}" + f" and element in y has invalid class = {y.max().item() + 1}." + ) + y = y.view(-1) + if self._type == "binary" and self._average is False: + y_pred = y_pred.view(-1) + else: + y = to_onehot(y.long(), num_classes=num_classes) + indices = torch.argmax(y_pred, dim=1) if self._type == "multiclass" else y_pred.long() + y_pred = to_onehot(indices.view(-1), num_classes=num_classes) + elif self._type == "multilabel": + # if y, y_pred shape is (N, C, ...) -> (N * ..., C) + num_labels = y_pred.size(1) + y_pred = torch.transpose(y_pred, 1, -1).reshape(-1, num_labels) + y = torch.transpose(y, 1, -1).reshape(-1, num_labels) + + # Convert from int cuda/cpu to double on self._device + y_pred = y_pred.to(dtype=torch.float64, device=self._device) + y = y.to(dtype=torch.float64, device=self._device) + correct = y * y_pred + + return y_pred, y, correct + @reinit__is_reduced def reset(self) -> None: - self._true_positives = 0 # type: Union[int, torch.Tensor] - self._positives = 0 # type: Union[int, torch.Tensor] - self._updated = False - if self._is_multilabel: - init_value = 0.0 if self._average else [] - self._true_positives = torch.tensor(init_value, dtype=torch.float64, device=self._device) - self._positives = torch.tensor(init_value, dtype=torch.float64, device=self._device) + # `numerator`, `denominator` and `weight` are three variables chosen to be abstract + # representatives of the ones that are measured for cases with different `average` parameters. + # `weight` is only used when `average='weighted'`. Actual value of these three variables is + # as follows. 
+ # + # average='samples': + # numerator (torch.Tensor): sum of metric value for samples + # denominator (int): number of samples + # + # average='weighted': + # numerator (torch.Tensor): number of true positives per class/label + # denominator (torch.Tensor): number of predicted(for precision) or actual(for recall) + # positives per class/label + # weight (torch.Tensor): number of actual positives per class + # + # average='micro': + # numerator (torch.Tensor): sum of number of true positives for classes/labels + # denominator (torch.Tensor): sum of number of predicted(for precision) or actual(for recall) positives + # for classes/labels + # + # average='macro' or boolean or None: + # numerator (torch.Tensor): number of true positives per class/label + # denominator (torch.Tensor): number of predicted(for precision) or actual(for recall) + # positives per class/label + + self._numerator = 0 # type: Union[int, torch.Tensor] + self._denominator = 0 # type: Union[int, torch.Tensor] + self._weight = 0 # type: Union[int, torch.Tensor] + self._updated = False super(_BasePrecisionRecall, self).reset() def compute(self) -> Union[torch.Tensor, float]: + + # Return value of the metric for `average` options `'weighted'` and `'macro'` is computed as follows. + # + # .. math:: \text{Precision/Recall} = \frac{ numerator }{ denominator } \cdot weight + # + # wherein `weight` is the internal variable `weight` for `'weighted'` option and :math:`1/C` + # for the `macro` one. :math:`C` is the number of classes/labels. + # + # Return value of the metric for `average` options `'micro'`, `'samples'`, `False` and None is as follows. + # + # .. math:: \text{Precision/Recall} = \frac{ numerator }{ denominator } + if not self._updated: raise NotComputableError( f"{self.__class__.__name__} must have at least one example before it can be computed." ) if not self._is_reduced: - if not (self._type == "multilabel" and not self._average): - self._true_positives = idist.all_reduce(self._true_positives) # type: ignore[assignment] - self._positives = idist.all_reduce(self._positives) # type: ignore[assignment] - else: - self._true_positives = cast(torch.Tensor, idist.all_gather(self._true_positives)) - self._positives = cast(torch.Tensor, idist.all_gather(self._positives)) + self._numerator = idist.all_reduce(self._numerator) # type: ignore[assignment] + self._denominator = idist.all_reduce(self._denominator) # type: ignore[assignment] + if self._average == "weighted": + self._weight = idist.all_reduce(self._weight) # type: ignore[assignment] self._is_reduced = True # type: bool - result = self._true_positives / (self._positives + self.eps) + fraction = self._numerator / (self._denominator + (self.eps if self._average != "samples" else 0)) - if self._average: - return cast(torch.Tensor, result).mean().item() + if self._average == "weighted": + sum_of_weights = cast(torch.Tensor, self._weight).sum() + self.eps + return ((fraction @ self._weight) / sum_of_weights).item() # type: ignore + elif self._average == "micro" or self._average == "samples": + return cast(torch.Tensor, fraction).item() + elif self._average == "macro": + return cast(torch.Tensor, fraction).mean().item() else: - return result + return fraction class Precision(_BasePrecisionRecall): - r"""Calculates precision for binary and multiclass data. + r"""Calculates precision for binary, multiclass and multilabel data. .. 
math:: \text{Precision} = \frac{ TP }{ TP + FP } @@ -78,10 +174,70 @@ class Precision(_BasePrecisionRecall): :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. - average: if True, precision is computed as the unweighted average (across all classes - in multiclass case), otherwise, returns a tensor with the precision (for each class in multiclass case). - is_multilabel: flag to use in multilabel case. By default, value is False. If True, average - parameter should be True and the average is computed across samples, instead of classes. + average: available options are + + False + default option. For multiclass and multilabel inputs, the per-class and per-label + metric is returned, respectively. + + None + like the `False` option, except that the per-class metric is returned for binary data as well. + For compatibility with the Scikit-Learn API. + + 'micro' + Metric is computed by counting stats of classes/labels altogether. + + .. math:: + \text{Micro Precision} = \frac{\sum_{k=1}^C TP_k}{\sum_{k=1}^C TP_k+FP_k} + + where :math:`C` is the number of classes/labels (2 in binary case). :math:`k` in :math:`TP_k` + and :math:`FP_k` means that the measures are computed for class/label :math:`k` (in a one-vs-rest + sense in multiclass case). + + For binary and multiclass inputs, this is equivalent to accuracy, + so use :class:`~ignite.metrics.accuracy.Accuracy`. + + 'samples' + For multilabel input, precision is first computed on a + per-sample basis and then the average across samples is returned. + + .. math:: + \text{Sample-averaged Precision} = \frac{\sum_{n=1}^N \frac{TP_n}{TP_n+FP_n}}{N} + + where :math:`N` is the number of samples. :math:`n` in :math:`TP_n` and :math:`FP_n` + means that the measures are computed for sample :math:`n`, across labels. + + Incompatible with binary and multiclass inputs. + + 'weighted' + like macro precision but accounts for class/label imbalance. For binary and multiclass + input, it computes the metric for each class, then returns their average weighted by + the support of the classes (number of actual samples in each class). For multilabel input, + it computes precision for each label, then returns their average weighted by the support + of the labels (number of actual positive samples in each label). + + .. math:: + Precision_k = \frac{TP_k}{TP_k+FP_k} + + .. math:: + \text{Weighted Precision} = \frac{\sum_{k=1}^C P_k * Precision_k}{N} + + where :math:`C` is the number of classes (2 in binary case). :math:`P_k` is the number + of samples belonging to class :math:`k` in the binary and multiclass cases, and the number of + positive samples belonging to label :math:`k` in the multilabel case. + + macro + computes macro precision, which is the unweighted average of the metric computed across + classes/labels. + + .. math:: + \text{Macro Precision} = \frac{\sum_{k=1}^C Precision_k}{C} + + where :math:`C` is the number of classes (2 in binary case). + + True + like the macro option; kept for backward compatibility. + is_multilabel: flag to use in multilabel case. By default, value is False. device: specifies which device updates are accumulated on. Setting the metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU. @@ -93,188 +249,153 @@ class Precision(_BasePrecisionRecall): .. include:: defaults.rst :start-after: :orphan: - Binary case + Binary case.
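The tests further down in this diff check these options against scikit-learn through an ``ignite_average_to_scikit_average`` helper. A small standalone sketch of that correspondence, using made-up multiclass data (not part of the patch), could look like this; ``zero_division=0`` matches ignite's behaviour for classes that are never predicted (the ``eps`` in the denominator):

.. code-block:: python

    import torch
    from sklearn.metrics import precision_score

    from ignite.metrics import Precision

    y_true = torch.tensor([2, 0, 2, 1, 0])
    y_pred = torch.tensor([
        [0.1, 0.2, 0.7],
        [0.8, 0.1, 0.1],
        [0.2, 0.2, 0.6],
        [0.3, 0.2, 0.5],
        [0.6, 0.3, 0.1],
    ])

    # (ignite average, scikit-learn average) pairs; False maps to per-class results.
    for ignite_avg, sk_avg in [(False, None), ("macro", "macro"), ("micro", "micro"), ("weighted", "weighted")]:
        metric = Precision(average=ignite_avg)
        metric.update((y_pred, y_true))
        sk_result = precision_score(
            y_true.numpy(), y_pred.argmax(dim=1).numpy(), average=sk_avg, zero_division=0
        )
        print(ignite_avg, metric.compute(), sk_result)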
In binary and multilabel cases, the elements of + `y` and `y_pred` should have 0 or 1 values. .. testcode:: 1 - metric = Precision(average=False) + metric = Precision() + weighted_metric = Precision(average='weighted') + two_class_metric = Precision(average=None) # Returns precision for both classes metric.attach(default_evaluator, "precision") - y_true = torch.Tensor([1, 0, 1, 1, 0, 1]) - y_pred = torch.Tensor([1, 0, 1, 0, 1, 1]) + weighted_metric.attach(default_evaluator, "weighted precision") + two_class_metric.attach(default_evaluator, "both classes precision") + y_true = torch.tensor([1, 0, 1, 1, 0, 1]) + y_pred = torch.tensor([1, 0, 1, 0, 1, 1]) state = default_evaluator.run([[y_pred, y_true]]) - print(state.metrics["precision"]) + print(f"Precision: {state.metrics['precision']}") + print(f"Weighted Precision: {state.metrics['weighted precision']}") + print(f"Precision for class 0 and class 1: {state.metrics['both classes precision']}") .. testoutput:: 1 - 0.75 + Precision: 0.75 + Weighted Precision: 0.6666666666666666 + Precision for class 0 and class 1: tensor([0.5000, 0.7500], dtype=torch.float64) Multiclass case .. testcode:: 2 - metric = Precision(average=False) + metric = Precision() + macro_metric = Precision(average=True) + weighted_metric = Precision(average='weighted') + metric.attach(default_evaluator, "precision") - y_true = torch.Tensor([2, 0, 2, 1, 0, 1]).long() + macro_metric.attach(default_evaluator, "macro precision") + weighted_metric.attach(default_evaluator, "weighted precision") + + y_true = torch.tensor([2, 0, 2, 1, 0]) y_pred = torch.Tensor([ [0.0266, 0.1719, 0.3055], [0.6886, 0.3978, 0.8176], [0.9230, 0.0197, 0.8395], [0.1785, 0.2670, 0.6084], - [0.8448, 0.7177, 0.7288], - [0.7748, 0.9542, 0.8573], + [0.8448, 0.7177, 0.7288] ]) state = default_evaluator.run([[y_pred, y_true]]) - print(state.metrics["precision"]) + print(f"Precision: {state.metrics['precision']}") + print(f"Macro Precision: {state.metrics['macro precision']}") + print(f"Weighted Precision: {state.metrics['weighted precision']}") .. testoutput:: 2 - tensor([0.5000, 1.0000, 0.3333], dtype=torch.float64) + Precision: tensor([0.5000, 0.0000, 0.3333], dtype=torch.float64) + Macro Precision: 0.27777777777777773 + Weighted Precision: 0.3333333333333333 - Precision can be computed as the unweighted average across all classes: + Multilabel case, the shapes must be (batch_size, num_labels, ...) .. testcode:: 3 - metric = Precision(average=True) - metric.attach(default_evaluator, "precision") - y_true = torch.Tensor([2, 0, 2, 1, 0, 1]).long() - y_pred = torch.Tensor([ - [0.0266, 0.1719, 0.3055], - [0.6886, 0.3978, 0.8176], - [0.9230, 0.0197, 0.8395], - [0.1785, 0.2670, 0.6084], - [0.8448, 0.7177, 0.7288], - [0.7748, 0.9542, 0.8573], - ]) - state = default_evaluator.run([[y_pred, y_true]]) - print(state.metrics["precision"]) - - .. testoutput:: 3 - - 0.6111... - - Multilabel case, the shapes must be (batch_size, num_categories, ...) - - .. 
testcode:: 4 - metric = Precision(is_multilabel=True) + micro_metric = Precision(is_multilabel=True, average='micro') + macro_metric = Precision(is_multilabel=True, average=True) + weighted_metric = Precision(is_multilabel=True, average='weighted') + samples_metric = Precision(is_multilabel=True, average='samples') + metric.attach(default_evaluator, "precision") + micro_metric.attach(default_evaluator, "micro precision") + macro_metric.attach(default_evaluator, "macro precision") + weighted_metric.attach(default_evaluator, "weighted precision") + samples_metric.attach(default_evaluator, "samples precision") + y_true = torch.Tensor([ [0, 0, 1], [0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 1], - ]).unsqueeze(0) + ]) y_pred = torch.Tensor([ [1, 1, 0], [1, 0, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], - ]).unsqueeze(0) + ]) state = default_evaluator.run([[y_pred, y_true]]) - print(state.metrics["precision"]) + print(f"Precision: {state.metrics['precision']}") + print(f"Micro Precision: {state.metrics['micro precision']}") + print(f"Macro Precision: {state.metrics['macro precision']}") + print(f"Weighted Precision: {state.metrics['weighted precision']}") + print(f"Samples Precision: {state.metrics['samples precision']}") - .. testoutput:: 4 + .. testoutput:: 3 - tensor([0.2000, 0.5000, 0.0000], dtype=torch.float64) + Precision: tensor([0.2000, 0.5000, 0.0000], dtype=torch.float64) + Micro Precision: 0.2222222222222222 + Macro Precision: 0.2333333333333333 + Weighted Precision: 0.175 + Samples Precision: 0.2 - In binary and multilabel cases, the elements of `y` and `y_pred` should have 0 or 1 values. Thresholding of - predictions can be done as below: + Thresholding of predictions can be done as below: - .. testcode:: 5 + .. testcode:: 4 def thresholded_output_transform(output): y_pred, y = output y_pred = torch.round(y_pred) return y_pred, y - metric = Precision(average=False, output_transform=thresholded_output_transform) + metric = Precision(output_transform=thresholded_output_transform) metric.attach(default_evaluator, "precision") - y_true = torch.Tensor([1, 0, 1, 1, 0, 1]) + y_true = torch.tensor([1, 0, 1, 1, 0, 1]) y_pred = torch.Tensor([0.6, 0.2, 0.9, 0.4, 0.7, 0.65]) state = default_evaluator.run([[y_pred, y_true]]) print(state.metrics["precision"]) - .. testoutput:: 5 + .. testoutput:: 4 0.75 - In multilabel cases, average parameter should be True. However, if user would like to compute F1 metric, for - example, average parameter should be False. This can be done as shown below: - - .. code-block:: python - - precision = Precision(average=False) - recall = Recall(average=False) - F1 = precision * recall * 2 / (precision + recall + 1e-20) - F1 = MetricsLambda(lambda t: torch.mean(t).item(), F1) - - .. warning:: - - In multilabel cases, if average is False, current implementation stores all input data (output and target) in - as tensors before computing a metric. This can potentially lead to a memory error if the input data is larger - than available RAM. - + .. versionchanged:: 0.5.0 + Some new options were added to `average` parameter. 
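The recipe from the removed code-block above (building F1 from unaveraged precision and recall) remains valid with the new ``average`` options; a self-contained sketch, assuming the same ``default_evaluator`` engine used in the examples above, would be:

.. code-block:: python

    import torch

    from ignite.metrics import MetricsLambda, Precision, Recall

    precision = Precision(average=False)
    recall = Recall(average=False)
    F1 = precision * recall * 2 / (precision + recall + 1e-20)
    F1 = MetricsLambda(lambda t: torch.mean(t).item(), F1)
    F1.attach(default_evaluator, "f1")  # assumes the doctest's default_evaluator

This is essentially what ``ignite.metrics.Fbeta`` wraps, which is why the Fbeta tests in this diff require precision and recall inputs constructed with ``average=False``.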
""" - def __init__( - self, - output_transform: Callable = lambda x: x, - average: bool = False, - is_multilabel: bool = False, - device: Union[str, torch.device] = torch.device("cpu"), - ): - super(Precision, self).__init__( - output_transform=output_transform, average=average, is_multilabel=is_multilabel, device=device - ) - @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: self._check_shape(output) self._check_type(output) - y_pred, y = output[0].detach(), output[1].detach() + y_pred, y, correct = self._prepare_output(output) - if self._type == "binary": - y_pred = y_pred.view(-1) - y = y.view(-1) - elif self._type == "multiclass": - num_classes = y_pred.size(1) - if y.max() + 1 > num_classes: - raise ValueError( - f"y_pred contains less classes than y. Number of predicted classes is {num_classes}" - f" and element in y has invalid class = {y.max().item() + 1}." - ) - y = to_onehot(y.view(-1), num_classes=num_classes) - indices = torch.argmax(y_pred, dim=1).view(-1) - y_pred = to_onehot(indices, num_classes=num_classes) - elif self._type == "multilabel": - # if y, y_pred shape is (N, C, ...) -> (C, N x ...) - num_classes = y_pred.size(1) - y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1) - y = torch.transpose(y, 1, 0).reshape(num_classes, -1) + if self._average == "samples": - # Convert from int cuda/cpu to double on self._device - y_pred = y_pred.to(dtype=torch.float64, device=self._device) - y = y.to(dtype=torch.float64, device=self._device) - correct = y * y_pred - all_positives = y_pred.sum(dim=0) + all_positives = y_pred.sum(dim=1) + true_positives = correct.sum(dim=1) + self._numerator += torch.sum(true_positives / (all_positives + self.eps)) + self._denominator += y.size(0) + elif self._average == "micro": - if correct.sum() == 0: - true_positives = torch.zeros_like(all_positives) - else: - true_positives = correct.sum(dim=0) + self._denominator += y_pred.sum() + self._numerator += correct.sum() + else: # _average in [False, None, 'macro', 'weighted'] - if self._type == "multilabel": - if not self._average: - self._true_positives = torch.cat([self._true_positives, true_positives], dim=0) # type: torch.Tensor - self._positives = torch.cat([self._positives, all_positives], dim=0) # type: torch.Tensor - else: - self._true_positives += torch.sum(true_positives / (all_positives + self.eps)) - self._positives += len(all_positives) - else: - self._true_positives += true_positives - self._positives += all_positives + self._denominator += y_pred.sum(dim=0) + self._numerator += correct.sum(dim=0) + + if self._average == "weighted": + self._weight += y.sum(dim=0) self._updated = True diff --git a/ignite/metrics/recall.py b/ignite/metrics/recall.py index f566b9848d2f..c18762696020 100644 --- a/ignite/metrics/recall.py +++ b/ignite/metrics/recall.py @@ -1,16 +1,15 @@ -from typing import Callable, Sequence, Union +from typing import Sequence import torch from ignite.metrics.metric import reinit__is_reduced from ignite.metrics.precision import _BasePrecisionRecall -from ignite.utils import to_onehot __all__ = ["Recall"] class Recall(_BasePrecisionRecall): - r"""Calculates recall for binary and multiclass data. + r"""Calculates recall for binary, multiclass and multilabel data. .. math:: \text{Recall} = \frac{ TP }{ TP + FN } @@ -25,10 +24,73 @@ class Recall(_BasePrecisionRecall): :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. 
This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. - average: if True, precision is computed as the unweighted average (across all classes - in multiclass case), otherwise, returns a tensor with the precision (for each class in multiclass case). - is_multilabel: flag to use in multilabel case. By default, value is False. If True, average - parameter should be True and the average is computed across samples, instead of classes. + average: available options are + + False + default option. For multiclass and multilabel inputs, the per-class and per-label + metric is returned, respectively. + + None + like the `False` option, except that the per-class metric is returned for binary data as well. + For compatibility with the Scikit-Learn API. + + 'micro' + Metric is computed by counting stats of classes/labels altogether. + + .. math:: + \text{Micro Recall} = \frac{\sum_{k=1}^C TP_k}{\sum_{k=1}^C TP_k+FN_k} + + where :math:`C` is the number of classes/labels (2 in binary case). :math:`k` in + :math:`TP_k` and :math:`FN_k` means that the measures are computed for class/label :math:`k` (in + a one-vs-rest sense in multiclass case). + + For binary and multiclass inputs, this is equivalent to accuracy, + so use :class:`~ignite.metrics.accuracy.Accuracy`. + + 'samples' + For multilabel input, recall is first computed on a + per-sample basis and then the average across samples is returned. + + .. math:: + \text{Sample-averaged Recall} = \frac{\sum_{n=1}^N \frac{TP_n}{TP_n+FN_n}}{N} + + where :math:`N` is the number of samples. :math:`n` in :math:`TP_n` and :math:`FN_n` + means that the measures are computed for sample :math:`n`, across labels. + + Incompatible with binary and multiclass inputs. + + 'weighted' + like macro recall but accounts for class/label imbalance. For binary and multiclass + input, it computes the metric for each class, then returns their average weighted by + the support of the classes (number of actual samples in each class). For multilabel input, + it computes recall for each label, then returns their average weighted by the support + of the labels (number of actual positive samples in each label). + + .. math:: + Recall_k = \frac{TP_k}{TP_k+FN_k} + + .. math:: + \text{Weighted Recall} = \frac{\sum_{k=1}^C P_k * Recall_k}{N} + + where :math:`C` is the number of classes (2 in binary case). :math:`P_k` is the number + of samples belonging to class :math:`k` in the binary and multiclass cases, and the number of + positive samples belonging to label :math:`k` in the multilabel case. + + Note that for binary and multiclass data, weighted recall is equivalent + to accuracy, so use :class:`~ignite.metrics.accuracy.Accuracy`. + + macro + computes macro recall, which is the unweighted average of the metric computed across + classes or labels. + + .. math:: + \text{Macro Recall} = \frac{\sum_{k=1}^C Recall_k}{C} + + where :math:`C` is the number of classes (2 in binary case). + + True + like the macro option; kept for backward compatibility. + is_multilabel: flag to use in multilabel case. By default, value is False. device: specifies which device updates are accumulated on. Setting the metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU. @@ -40,186 +102,141 @@ class Recall(_BasePrecisionRecall): .. include:: defaults.rst :start-after: :orphan: - Binary case + Binary case. In binary and multilabel cases, the elements of + `y` and `y_pred` should have 0 or 1 values. ..
testcode:: 1 - metric = Recall(average=False) + metric = Recall() + two_class_metric = Recall(average=None) # Returns recall for both classes metric.attach(default_evaluator, "recall") + two_class_metric.attach(default_evaluator, "both classes recall") y_true = torch.tensor([1, 0, 1, 1, 0, 1]) y_pred = torch.tensor([1, 0, 1, 0, 1, 1]) state = default_evaluator.run([[y_pred, y_true]]) - print(state.metrics["recall"]) + print(f"Recall: {state.metrics['recall']}") + print(f"Recall for class 0 and class 1: {state.metrics['both classes recall']}") .. testoutput:: 1 - 0.75 + Recall: 0.75 + Recall for class 0 and class 1: tensor([0.5000, 0.7500], dtype=torch.float64) Multiclass case .. testcode:: 2 - metric = Recall(average=False) - metric.attach(default_evaluator, "recall") - y_true = torch.tensor([2, 0, 2, 1, 0, 1]) - y_pred = torch.tensor([ - [0.0266, 0.1719, 0.3055], - [0.6886, 0.3978, 0.8176], - [0.9230, 0.0197, 0.8395], - [0.1785, 0.2670, 0.6084], - [0.8448, 0.7177, 0.7288], - [0.7748, 0.9542, 0.8573], - ]) - state = default_evaluator.run([[y_pred, y_true]]) - print(state.metrics["recall"]) - - .. testoutput:: 2 - - tensor([0.5000, 0.5000, 0.5000], dtype=torch.float64) - - Precision can be computed as the unweighted average across all classes: + metric = Recall() + macro_metric = Recall(average=True) - .. testcode:: 3 - - metric = Recall(average=True) metric.attach(default_evaluator, "recall") - y_true = torch.tensor([2, 0, 2, 1, 0, 1]) + macro_metric.attach(default_evaluator, "macro recall") + + y_true = torch.tensor([2, 0, 2, 1, 0]) y_pred = torch.tensor([ [0.0266, 0.1719, 0.3055], [0.6886, 0.3978, 0.8176], [0.9230, 0.0197, 0.8395], [0.1785, 0.2670, 0.6084], - [0.8448, 0.7177, 0.7288], - [0.7748, 0.9542, 0.8573], + [0.8448, 0.7177, 0.7288] ]) state = default_evaluator.run([[y_pred, y_true]]) - print(state.metrics["recall"]) + print(f"Recall: {state.metrics['recall']}") + print(f"Macro Recall: {state.metrics['macro recall']}") - .. testoutput:: 3 + .. testoutput:: 2 - 0.5 + Recall: tensor([0.5000, 0.0000, 0.5000], dtype=torch.float64) + Macro Recall: 0.3333333333333333 Multilabel case, the shapes must be (batch_size, num_categories, ...) - .. testcode:: 4 + .. testcode:: 3 metric = Recall(is_multilabel=True) + micro_metric = Recall(is_multilabel=True, average='micro') + macro_metric = Recall(is_multilabel=True, average=True) + samples_metric = Recall(is_multilabel=True, average='samples') + metric.attach(default_evaluator, "recall") + micro_metric.attach(default_evaluator, "micro recall") + macro_metric.attach(default_evaluator, "macro recall") + samples_metric.attach(default_evaluator, "samples recall") + y_true = torch.tensor([ [0, 0, 1], [0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 1], - ]).unsqueeze(0) + ]) y_pred = torch.tensor([ [1, 1, 0], [1, 0, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], - ]).unsqueeze(0) + ]) state = default_evaluator.run([[y_pred, y_true]]) - print(state.metrics["recall"]) + print(f"Recall: {state.metrics['recall']}") + print(f"Micro Recall: {state.metrics['micro recall']}") + print(f"Macro Recall: {state.metrics['macro recall']}") + print(f"Samples Recall: {state.metrics['samples recall']}") - .. testoutput:: 4 + .. testoutput:: 3 - tensor([1., 1., 0.], dtype=torch.float64) + Recall: tensor([1., 1., 0.], dtype=torch.float64) + Micro Recall: 0.5 + Macro Recall: 0.6666666666666666 + Samples Recall: 0.3 - In binary and multilabel cases, the elements of `y` and `y_pred` should have 0 or 1 values. 
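The multilabel recall figures above can be reproduced with a few lines of plain torch that mirror what ``update``/``compute`` accumulate; this is only an illustrative cross-check, not part of the patch:

.. code-block:: python

    import torch

    y_true = torch.tensor(
        [[0, 0, 1], [0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 1]], dtype=torch.float64
    )
    y_pred = torch.tensor(
        [[1, 1, 0], [1, 0, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0]], dtype=torch.float64
    )
    correct = y_true * y_pred
    eps = 1e-20  # guards against empty labels/samples, as in _BasePrecisionRecall

    per_label = correct.sum(dim=0) / (y_true.sum(dim=0) + eps)          # tensor([1., 1., 0.])
    micro = correct.sum() / (y_true.sum() + eps)                        # 0.5
    macro = per_label.mean()                                            # 0.6666...
    samples = (correct.sum(dim=1) / (y_true.sum(dim=1) + eps)).mean()   # 0.3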
Thresholding of - predictions can be done as below: + Thresholding of predictions can be done as below: - .. testcode:: 5 + .. testcode:: 4 def thresholded_output_transform(output): y_pred, y = output y_pred = torch.round(y_pred) return y_pred, y - metric = Recall(average=False, output_transform=thresholded_output_transform) + metric = Recall(output_transform=thresholded_output_transform) metric.attach(default_evaluator, "recall") y_true = torch.tensor([1, 0, 1, 1, 0, 1]) y_pred = torch.tensor([0.6, 0.2, 0.9, 0.4, 0.7, 0.65]) state = default_evaluator.run([[y_pred, y_true]]) - print(state.metrics["recall"]) + print(state.metrics['recall']) - .. testoutput:: 5 + .. testoutput:: 4 0.75 - In multilabel cases, average parameter should be True. However, if user would like to compute F1 metric, for - example, average parameter should be False. This can be done as shown below: - - .. code-block:: python - - precision = Precision(average=False) - recall = Recall(average=False) - F1 = precision * recall * 2 / (precision + recall + 1e-20) - F1 = MetricsLambda(lambda t: torch.mean(t).item(), F1) - .. warning:: - - In multilabel cases, if average is False, current implementation stores all input data (output and target) in - as tensors before computing a metric. This can potentially lead to a memory error if the input data is larger - than available RAM. + .. versionchanged:: 0.5.0 + Some new options were added to `average` parameter. """ - def __init__( - self, - output_transform: Callable = lambda x: x, - average: bool = False, - is_multilabel: bool = False, - device: Union[str, torch.device] = torch.device("cpu"), - ): - super(Recall, self).__init__( - output_transform=output_transform, average=average, is_multilabel=is_multilabel, device=device - ) - @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: self._check_shape(output) self._check_type(output) - y_pred, y = output[0].detach(), output[1].detach() - - if self._type == "binary": - y_pred = y_pred.view(-1) - y = y.view(-1) - elif self._type == "multiclass": - num_classes = y_pred.size(1) - if y.max() + 1 > num_classes: - raise ValueError( - f"y_pred contains less classes than y. Number of predicted classes is {num_classes}" - f" and element in y has invalid class = {y.max().item() + 1}." - ) - y = to_onehot(y.view(-1), num_classes=num_classes) - indices = torch.argmax(y_pred, dim=1).view(-1) - y_pred = to_onehot(indices, num_classes=num_classes) - elif self._type == "multilabel": - # if y, y_pred shape is (N, C, ...) -> (C, N x ...) 
- num_classes = y_pred.size(1) - y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1) - y = torch.transpose(y, 1, 0).reshape(num_classes, -1) - - # Convert from int cuda/cpu to double on self._device - y_pred = y_pred.to(dtype=torch.float64, device=self._device) - y = y.to(dtype=torch.float64, device=self._device) - correct = y * y_pred - actual_positives = y.sum(dim=0) - - if correct.sum() == 0: - true_positives = torch.zeros_like(actual_positives) - else: - true_positives = correct.sum(dim=0) - - if self._type == "multilabel": - if not self._average: - self._true_positives = torch.cat([self._true_positives, true_positives], dim=0) # type: torch.Tensor - self._positives = torch.cat([self._positives, actual_positives], dim=0) # type: torch.Tensor - else: - self._true_positives += torch.sum(true_positives / (actual_positives + self.eps)) - self._positives += len(actual_positives) - else: - self._true_positives += true_positives - self._positives += actual_positives + _, y, correct = self._prepare_output(output) + + if self._average == "samples": + + actual_positives = y.sum(dim=1) + true_positives = correct.sum(dim=1) + self._numerator += torch.sum(true_positives / (actual_positives + self.eps)) + self._denominator += y.size(0) + elif self._average == "micro": + + self._denominator += y.sum() + self._numerator += correct.sum() + else: # _average in [False, 'macro', 'weighted'] + + self._denominator += y.sum(dim=0) + self._numerator += correct.sum(dim=0) + + if self._average == "weighted": + self._weight += y.sum(dim=0) self._updated = True diff --git a/tests/ignite/metrics/test_classification_report.py b/tests/ignite/metrics/test_classification_report.py index 74a2d9532115..e1ea1d7c5900 100644 --- a/tests/ignite/metrics/test_classification_report.py +++ b/tests/ignite/metrics/test_classification_report.py @@ -141,7 +141,6 @@ def update(engine, i): _test(metric_device, 2, ["0", "1", "2", "3", "4", "5", "6"]) -@pytest.mark.xfail @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") @@ -154,7 +153,6 @@ def test_distrib_nccl_gpu(distributed_context_single_node_nccl): _test_integration_multilabel(device, False) -@pytest.mark.xfail @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") def test_distrib_gloo_cpu_or_gpu(local_rank, distributed_context_single_node_gloo): @@ -166,7 +164,6 @@ def test_distrib_gloo_cpu_or_gpu(local_rank, distributed_context_single_node_glo _test_integration_multilabel(device, False) -@pytest.mark.xfail @pytest.mark.distributed @pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support") @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") @@ -190,7 +187,6 @@ def _test_distrib_xla_nprocs(index): _test_integration_multilabel(device, False) -@pytest.mark.xfail @pytest.mark.tpu @pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars") @pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") @@ -207,7 +203,6 @@ def to_numpy_multilabel(y): return y -@pytest.mark.xfail @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") @@ -220,7 +215,6 @@ 
def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo): _test_integration_multilabel(device, False) -@pytest.mark.xfail @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") diff --git a/tests/ignite/metrics/test_fbeta.py b/tests/ignite/metrics/test_fbeta.py index 27eb28905189..a44e0f408d2d 100644 --- a/tests/ignite/metrics/test_fbeta.py +++ b/tests/ignite/metrics/test_fbeta.py @@ -18,11 +18,11 @@ def test_wrong_inputs(): Fbeta(0.0) with pytest.raises(ValueError, match=r"Input precision metric should have average=False"): - p = Precision(average=True) + p = Precision(average="micro") Fbeta(1.0, precision=p) with pytest.raises(ValueError, match=r"Input recall metric should have average=False"): - r = Recall(average=True) + r = Recall(average="samples") Fbeta(1.0, recall=r) with pytest.raises(ValueError, match=r"If precision argument is provided, output_transform should be None"): diff --git a/tests/ignite/metrics/test_metrics_lambda.py b/tests/ignite/metrics/test_metrics_lambda.py index 72fce39885c6..9bfe8cada754 100644 --- a/tests/ignite/metrics/test_metrics_lambda.py +++ b/tests/ignite/metrics/test_metrics_lambda.py @@ -195,8 +195,8 @@ def Fbeta(r, p, beta): positives = all_positives1 + all_positives2 assert precision._type == "binary" - assert precision._true_positives == true_positives - assert precision._positives == positives + assert precision._numerator == true_positives + assert precision._denominator == positives # Computing positivies for recall is different positives1 = y1.sum(dim=0) @@ -204,8 +204,8 @@ def Fbeta(r, p, beta): positives = positives1 + positives2 assert recall._type == "binary" - assert recall._true_positives == true_positives - assert recall._positives == positives + assert recall._numerator == true_positives + assert recall._denominator == positives """ Test compute diff --git a/tests/ignite/metrics/test_precision.py b/tests/ignite/metrics/test_precision.py index 5f07d5413526..659d4ccf4d56 100644 --- a/tests/ignite/metrics/test_precision.py +++ b/tests/ignite/metrics/test_precision.py @@ -20,11 +20,27 @@ def test_no_update(): precision.compute() assert precision._updated is False - precision = Precision(is_multilabel=True, average=True) - assert precision._updated is False - with pytest.raises(NotComputableError, match=r"Precision must have at least one example before it can be computed"): - precision.compute() - assert precision._updated is False + +def test_average_parameter(): + with pytest.raises(ValueError, match="Argument average should be None or a boolean or one of values"): + Precision(average=1) + + pr = Precision(average="samples") + with pytest.raises( + ValueError, match=r"Argument average='samples' is incompatible with binary and multiclass input data." + ): + pr.update((torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10,)).long())) + assert pr._updated is False + + pr = Precision(average="samples") + with pytest.raises( + ValueError, match=r"Argument average='samples' is incompatible with binary and multiclass input data." 
+ ): + pr.update((torch.rand(10, 3), torch.randint(0, 3, size=(10,)).long())) + assert pr._updated is False + + pr = Precision(average=True) + assert pr._average == "macro" def test_binary_wrong_inputs(): @@ -56,8 +72,36 @@ def test_binary_wrong_inputs(): pr.update((torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10, 5, 6)).long())) assert pr._updated is False + with pytest.warns( + RuntimeWarning, + match="`y` and `y_pred` should be of dtype long when entry type is binary and average!=False", + ): + pr = Precision(average=None) + pr.update((torch.randint(0, 2, size=(10,)).float(), torch.randint(0, 2, size=(10,)))) + + with pytest.warns( + RuntimeWarning, + match="`y` and `y_pred` should be of dtype long when entry type is binary and average!=False", + ): + pr = Precision(average=None) + pr.update((torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)).float())) + + +def ignite_average_to_scikit_average(average, data_type: str): + if average in [None, "micro", "samples", "weighted", "macro"]: + return average + if average is False: + if data_type == "binary": + return "binary" + else: + return None + elif average is True: + return "macro" + else: + raise ValueError(f"Wrong average parameter `{average}`") + -@pytest.mark.parametrize("average", [False, True]) +@pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted"]) def test_binary_input(average): pr = Precision(average=average) @@ -80,9 +124,10 @@ def _test(y_pred, y, batch_size): assert pr._type == "binary" assert pr._updated is True - assert isinstance(pr.compute(), float if average else torch.Tensor) - pr_compute = pr.compute() if average else pr.compute().numpy() - assert precision_score(np_y, np_y_pred, average="binary") == pytest.approx(pr_compute) + assert isinstance(pr.compute(), torch.Tensor if not average else float) + pr_compute = pr.compute().numpy() if not average else pr.compute() + sk_average_parameter = ignite_average_to_scikit_average(average, "binary") + assert precision_score(np_y, np_y_pred, average=sk_average_parameter) == pytest.approx(pr_compute) def get_test_cases(): @@ -116,7 +161,7 @@ def get_test_cases(): # check multiple random inputs as random exact occurencies are rare test_cases = get_test_cases() for y_pred, y, batch_size in test_cases: - _test(y, y_pred, batch_size) + _test(y_pred, y, batch_size) def test_multiclass_wrong_inputs(): @@ -168,8 +213,15 @@ def test_multiclass_wrong_inputs(): pr.update((torch.rand(10, 6, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long())) assert pr._updated is True + with pytest.warns( + RuntimeWarning, + match="`y` should be of dtype long when entry type is multiclass", + ): + pr = Precision() + pr.update((torch.rand(10, 5), torch.randint(0, 5, size=(10,)).float())) + -@pytest.mark.parametrize("average", [False, True]) +@pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted"]) def test_multiclass_input(average): pr = Precision(average=average) @@ -193,9 +245,9 @@ def _test(y_pred, y, batch_size): assert pr._type == "multiclass" assert pr._updated is True - assert isinstance(pr.compute(), float if average else torch.Tensor) - pr_compute = pr.compute() if average else pr.compute().numpy() - sk_average_parameter = "macro" if average else None + assert isinstance(pr.compute(), torch.Tensor if not average else float) + pr_compute = pr.compute().numpy() if not average else pr.compute() + sk_average_parameter = ignite_average_to_scikit_average(average, "multiclass") with warnings.catch_warnings(): 
warnings.simplefilter("ignore", category=UndefinedMetricWarning) sk_compute = precision_score(np_y, np_y_pred, labels=range(0, num_classes), average=sk_average_parameter) @@ -237,7 +289,7 @@ def get_test_cases(): def test_multilabel_wrong_inputs(): - pr = Precision(average=True, is_multilabel=True) + pr = Precision(is_multilabel=True) assert pr._updated is False with pytest.raises(ValueError): @@ -270,7 +322,7 @@ def to_numpy_multilabel(y): return y -@pytest.mark.parametrize("average", [False, True]) +@pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted", "samples"]) def test_multilabel_input(average): pr = Precision(average=average, is_multilabel=True) @@ -293,22 +345,11 @@ def _test(y_pred, y, batch_size): assert pr._type == "multilabel" assert pr._updated is True - pr_compute = pr.compute() if average else pr.compute().mean().item() + pr_compute = pr.compute().numpy() if not average else pr.compute() + sk_average_parameter = ignite_average_to_scikit_average(average, "multilabel") with warnings.catch_warnings(): warnings.simplefilter("ignore", category=UndefinedMetricWarning) - assert precision_score(np_y, np_y_pred, average="samples") == pytest.approx(pr_compute) - - pr1 = Precision(is_multilabel=True, average=True) - pr2 = Precision(is_multilabel=True, average=False) - assert pr1._updated is False - assert pr2._updated is False - pr1.update((y_pred, y)) - pr2.update((y_pred, y)) - assert pr1._updated is True - assert pr2._updated is True - assert pr1.compute() == pytest.approx(pr2.compute().mean().item()) - assert pr1._updated is True - assert pr2._updated is True + assert precision_score(np_y, np_y_pred, average=sk_average_parameter) == pytest.approx(pr_compute) def get_test_cases(): @@ -325,7 +366,7 @@ def get_test_cases(): # updated batches (torch.randint(0, 2, size=(50, 5, 10)), torch.randint(0, 2, size=(50, 5, 10)), 16), (torch.randint(0, 2, size=(50, 4, 10)), torch.randint(0, 2, size=(50, 4, 10)), 16), - # Multilabel input data of shape (N, H, W, ...) and (N, C, H, W, ...) 
+ # Multilabel input data of shape (N, C, H, W) (torch.randint(0, 2, size=(10, 5, 18, 16)), torch.randint(0, 2, size=(10, 5, 18, 16)), 1), (torch.randint(0, 2, size=(10, 4, 20, 23)), torch.randint(0, 2, size=(10, 4, 20, 23)), 1), # updated batches @@ -345,58 +386,40 @@ def get_test_cases(): _test(y_pred, y, batch_size) -def test_incorrect_type(): +@pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted"]) +def test_incorrect_type(average): # Tests changing of type during training - def _test(average): - pr = Precision(average=average) - assert pr._updated is False - - y_pred = torch.softmax(torch.rand(4, 4), dim=1) - y = torch.ones(4).long() - pr.update((y_pred, y)) - assert pr._updated is True - - y_pred = torch.randint(0, 2, size=(4,)) - y = torch.ones(4).long() - - with pytest.raises(RuntimeError): - pr.update((y_pred, y)) + pr = Precision(average=average) + assert pr._updated is False - assert pr._updated is True + y_pred = torch.softmax(torch.rand(4, 4), dim=1) + y = torch.ones(4).long() + pr.update((y_pred, y)) + assert pr._updated is True - _test(average=True) - _test(average=False) + y_pred = torch.randint(0, 2, size=(4,)) + y = torch.ones(4).long() - pr1 = Precision(is_multilabel=True, average=True) - pr2 = Precision(is_multilabel=True, average=False) - assert pr1._updated is False - assert pr2._updated is False - y_pred = torch.randint(0, 2, size=(10, 4, 20, 23)) - y = torch.randint(0, 2, size=(10, 4, 20, 23)).long() - pr1.update((y_pred, y)) - pr2.update((y_pred, y)) - assert pr1._updated is True - assert pr2._updated is True - assert pr1.compute() == pytest.approx(pr2.compute().mean().item()) + with pytest.raises(RuntimeError): + pr.update((y_pred, y)) + assert pr._updated is True -def test_incorrect_y_classes(): - def _test(average): - pr = Precision(average=average) - assert pr._updated is False +@pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted"]) +def test_incorrect_y_classes(average): + pr = Precision(average=average) - y_pred = torch.randint(0, 2, size=(10, 4)).float() - y = torch.randint(4, 5, size=(10,)).long() + assert pr._updated is False - with pytest.raises(ValueError): - pr.update((y_pred, y)) + y_pred = torch.randint(0, 2, size=(10, 4)).float() + y = torch.randint(4, 5, size=(10,)).long() - assert pr._updated is False + with pytest.raises(ValueError): + pr.update((y_pred, y)) - _test(average=True) - _test(average=False) + assert pr._updated is False def _test_distrib_integration_multiclass(device): @@ -437,8 +460,9 @@ def update(engine, i): assert res.device.type == "cpu" res = res.cpu().numpy() + sk_average_parameter = ignite_average_to_scikit_average(average, "multiclass") true_res = precision_score( - y_true.cpu().numpy(), torch.argmax(y_preds, dim=1).cpu().numpy(), average="macro" if average else None + y_true.cpu().numpy(), torch.argmax(y_preds, dim=1).cpu().numpy(), average=sk_average_parameter ) assert pytest.approx(res) == true_res @@ -448,10 +472,14 @@ def update(engine, i): metric_devices.append(idist.device()) for _ in range(2): for metric_device in metric_devices: - _test(average=True, n_epochs=1, metric_device=metric_device) - _test(average=True, n_epochs=2, metric_device=metric_device) _test(average=False, n_epochs=1, metric_device=metric_device) _test(average=False, n_epochs=2, metric_device=metric_device) + _test(average="macro", n_epochs=1, metric_device=metric_device) + _test(average="macro", n_epochs=2, metric_device=metric_device) + _test(average="weighted", n_epochs=1, 
metric_device=metric_device) + _test(average="weighted", n_epochs=2, metric_device=metric_device) + _test(average="micro", n_epochs=1, metric_device=metric_device) + _test(average="micro", n_epochs=2, metric_device=metric_device) def _test_distrib_integration_multilabel(device): @@ -499,32 +527,26 @@ def update(engine, i): np_y_preds = to_numpy_multilabel(y_preds) np_y_true = to_numpy_multilabel(y_true) assert pr._type == "multilabel" - res = res if average else res.mean().item() + sk_average_parameter = ignite_average_to_scikit_average(average, "multilabel") with warnings.catch_warnings(): warnings.simplefilter("ignore", category=UndefinedMetricWarning) - assert precision_score(np_y_true, np_y_preds, average="samples") == pytest.approx(res) + assert precision_score(np_y_true, np_y_preds, average=sk_average_parameter) == pytest.approx(res) metric_devices = ["cpu"] if device.type != "xla": metric_devices.append(idist.device()) for _ in range(2): for metric_device in metric_devices: - _test(average=True, n_epochs=1, metric_device=metric_device) - _test(average=True, n_epochs=2, metric_device=metric_device) _test(average=False, n_epochs=1, metric_device=metric_device) _test(average=False, n_epochs=2, metric_device=metric_device) - - pr1 = Precision(is_multilabel=True, average=True) - pr2 = Precision(is_multilabel=True, average=False) - assert pr1._updated is False - assert pr2._updated is False - y_pred = torch.randint(0, 2, size=(10, 4, 20, 23)) - y = torch.randint(0, 2, size=(10, 4, 20, 23)).long() - pr1.update((y_pred, y)) - pr2.update((y_pred, y)) - assert pr1._updated is True - assert pr2._updated is True - assert pr1.compute() == pytest.approx(pr2.compute().mean().item()) + _test(average="macro", n_epochs=1, metric_device=metric_device) + _test(average="macro", n_epochs=2, metric_device=metric_device) + _test(average="micro", n_epochs=1, metric_device=metric_device) + _test(average="micro", n_epochs=2, metric_device=metric_device) + _test(average="weighted", n_epochs=1, metric_device=metric_device) + _test(average="weighted", n_epochs=2, metric_device=metric_device) + _test(average="samples", n_epochs=1, metric_device=metric_device) + _test(average="samples", n_epochs=2, metric_device=metric_device) def _test_distrib_accumulator_device(device): @@ -542,19 +564,29 @@ def _test(average, metric_device): pr.update((y_pred, y)) assert pr._updated is True + assert ( - pr._true_positives.device == metric_device - ), f"{type(pr._true_positives.device)}:{pr._true_positives.device} vs {type(metric_device)}:{metric_device}" - assert ( - pr._positives.device == metric_device - ), f"{type(pr._positives.device)}:{pr._positives.device} vs {type(metric_device)}:{metric_device}" + pr._numerator.device == metric_device + ), f"{type(pr._numerator.device)}:{pr._numerator.device} vs {type(metric_device)}:{metric_device}" + + if average != "samples": + # For average='samples', `_denominator` is of type `int` so it has not `device` member. 
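+            # (With average='samples', update() accumulates the sample count as a plain Python int:
+            #  `self._denominator += y.size(0)`.)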
+ assert ( + pr._denominator.device == metric_device + ), f"{type(pr._denominator.device)}:{pr._denominator.device} vs {type(metric_device)}:{metric_device}" + + if average == "weighted": + assert pr._weight.device == metric_device, f"{type(pr._weight.device)}:{pr._weight.device} vs " + f"{type(metric_device)}:{metric_device}" metric_devices = [torch.device("cpu")] if device.type != "xla": metric_devices.append(idist.device()) for metric_device in metric_devices: - _test(True, metric_device=metric_device) _test(False, metric_device=metric_device) + _test("macro", metric_device=metric_device) + _test("micro", metric_device=metric_device) + _test("weighted", metric_device=metric_device) def _test_distrib_multilabel_accumulator_device(device): @@ -565,31 +597,36 @@ def _test(average, metric_device): assert pr._updated is False assert pr._device == metric_device - assert ( - pr._true_positives.device == metric_device - ), f"{type(pr._true_positives.device)}:{pr._true_positives.device} vs {type(metric_device)}:{metric_device}" - assert ( - pr._positives.device == metric_device - ), f"{type(pr._positives.device)}:{pr._positives.device} vs {type(metric_device)}:{metric_device}" y_pred = torch.randint(0, 2, size=(10, 4, 20, 23)) y = torch.randint(0, 2, size=(10, 4, 20, 23)).long() pr.update((y_pred, y)) assert pr._updated is True + assert ( - pr._true_positives.device == metric_device - ), f"{type(pr._true_positives.device)}:{pr._true_positives.device} vs {type(metric_device)}:{metric_device}" - assert ( - pr._positives.device == metric_device - ), f"{type(pr._positives.device)}:{pr._positives.device} vs {type(metric_device)}:{metric_device}" + pr._numerator.device == metric_device + ), f"{type(pr._numerator.device)}:{pr._numerator.device} vs {type(metric_device)}:{metric_device}" + + if average != "samples": + # For average='samples', `_denominator` is of type `int` so it has not `device` member. + assert ( + pr._denominator.device == metric_device + ), f"{type(pr._denominator.device)}:{pr._denominator.device} vs {type(metric_device)}:{metric_device}" + + if average == "weighted": + assert pr._weight.device == metric_device, f"{type(pr._weight.device)}:{pr._weight.device} vs " + f"{type(metric_device)}:{metric_device}" metric_devices = [torch.device("cpu")] if device.type != "xla": metric_devices.append(idist.device()) for metric_device in metric_devices: - _test(True, metric_device=metric_device) _test(False, metric_device=metric_device) + _test("macro", metric_device=metric_device) + _test("micro", metric_device=metric_device) + _test("weighted", metric_device=metric_device) + _test("samples", metric_device=metric_device) @pytest.mark.distributed diff --git a/tests/ignite/metrics/test_recall.py b/tests/ignite/metrics/test_recall.py index 12c47e9f27e2..766d19dc5157 100644 --- a/tests/ignite/metrics/test_recall.py +++ b/tests/ignite/metrics/test_recall.py @@ -20,13 +20,33 @@ def test_no_update(): recall.compute() assert recall._updated is False - recall = Recall(is_multilabel=True, average=True) + recall = Recall(is_multilabel=True) assert recall._updated is False with pytest.raises(NotComputableError, match=r"Recall must have at least one example before it can be computed"): recall.compute() assert recall._updated is False +def test_average_parameter(): + + re = Recall(average="samples") + with pytest.raises( + ValueError, match=r"Argument average='samples' is incompatible with binary and multiclass input data." 
+ ): + re.update((torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10,)).long())) + assert re._updated is False + + re = Recall(average="samples") + with pytest.raises( + ValueError, match=r"Argument average='samples' is incompatible with binary and multiclass input data." + ): + re.update((torch.rand(10, 3), torch.randint(0, 3, size=(10,)).long())) + assert re._updated is False + + re = Recall(average=True) + assert re._average == "macro" + + def test_binary_wrong_inputs(): re = Recall() @@ -56,8 +76,36 @@ def test_binary_wrong_inputs(): re.update((torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10, 5, 6)).long())) assert re._updated is False + with pytest.warns( + RuntimeWarning, + match="`y` and `y_pred` should be of dtype long when entry type is binary and average!=False", + ): + re = Recall(average=None) + re.update((torch.randint(0, 2, size=(10,)).float(), torch.randint(0, 2, size=(10,)))) + + with pytest.warns( + RuntimeWarning, + match="`y` and `y_pred` should be of dtype long when entry type is binary and average!=False", + ): + re = Recall(average=None) + re.update((torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)).float())) + + +def ignite_average_to_scikit_average(average, data_type: str): + if average in [None, "micro", "samples", "weighted", "macro"]: + return average + if average is False: + if data_type == "binary": + return "binary" + else: + return None + elif average is True: + return "macro" + else: + raise ValueError(f"Wrong average parameter `{average}`") -@pytest.mark.parametrize("average", [False, True]) + +@pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted"]) def test_binary_input(average): re = Recall(average=average) @@ -80,9 +128,10 @@ def _test(y_pred, y, batch_size): assert re._type == "binary" assert re._updated is True - assert isinstance(re.compute(), float if average else torch.Tensor) - re_compute = re.compute() if average else re.compute().numpy() - assert recall_score(np_y, np_y_pred, average="binary") == pytest.approx(re_compute) + assert isinstance(re.compute(), torch.Tensor if not average else float) + re_compute = re.compute().numpy() if not average else re.compute() + sk_average_parameter = ignite_average_to_scikit_average(average, "binary") + assert recall_score(np_y, np_y_pred, average=sk_average_parameter) == pytest.approx(re_compute) def get_test_cases(): @@ -168,8 +217,15 @@ def test_multiclass_wrong_inputs(): re.update((torch.rand(10, 6, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long())) assert re._updated is True + with pytest.warns( + RuntimeWarning, + match="`y` should be of dtype long when entry type is multiclass", + ): + re = Recall() + re.update((torch.rand(10, 5), torch.randint(0, 5, size=(10,)).float())) + -@pytest.mark.parametrize("average", [False, True]) +@pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted"]) def test_multiclass_input(average): re = Recall(average=average) @@ -193,9 +249,9 @@ def _test(y_pred, y, batch_size): assert re._type == "multiclass" assert re._updated is True - assert isinstance(re.compute(), float if average else torch.Tensor) - re_compute = re.compute() if average else re.compute().numpy() - sk_average_parameter = "macro" if average else None + assert isinstance(re.compute(), torch.Tensor if not average else float) + re_compute = re.compute().numpy() if not average else re.compute() + sk_average_parameter = ignite_average_to_scikit_average(average, "multiclass") with warnings.catch_warnings(): 
             warnings.simplefilter("ignore", category=UndefinedMetricWarning)
             sk_compute = recall_score(np_y, np_y_pred, labels=range(0, num_classes), average=sk_average_parameter)
@@ -237,7 +293,7 @@ def get_test_cases():


 def test_multilabel_wrong_inputs():
-    re = Recall(average=True, is_multilabel=True)
+    re = Recall(is_multilabel=True)
     assert re._updated is False

     with pytest.raises(ValueError):
@@ -270,7 +326,7 @@ def to_numpy_multilabel(y):
     return y


-@pytest.mark.parametrize("average", [False, True])
+@pytest.mark.parametrize("average", [None, False, "macro", "micro", "samples"])
 def test_multilabel_input(average):
     re = Recall(average=average, is_multilabel=True)

@@ -293,22 +349,11 @@ def _test(y_pred, y, batch_size):
         assert re._type == "multilabel"
         assert re._updated is True

-        re_compute = re.compute() if average else re.compute().mean().item()
+        re_compute = re.compute().numpy() if not average else re.compute()
+        sk_average_parameter = ignite_average_to_scikit_average(average, "multilabel")
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", category=UndefinedMetricWarning)
-            assert recall_score(np_y, np_y_pred, average="samples") == pytest.approx(re_compute)
-
-        re1 = Recall(is_multilabel=True, average=True)
-        re2 = Recall(is_multilabel=True, average=False)
-        assert re1._updated is False
-        assert re2._updated is False
-        re1.update((y_pred, y))
-        re2.update((y_pred, y))
-        assert re1._updated is True
-        assert re2._updated is True
-        assert re1.compute() == pytest.approx(re2.compute().mean().item())
-        assert re1._updated is True
-        assert re2._updated is True
+            assert recall_score(np_y, np_y_pred, average=sk_average_parameter) == pytest.approx(re_compute)


     def get_test_cases():
@@ -345,58 +390,40 @@ def get_test_cases():
         _test(y_pred, y, batch_size)


-def test_incorrect_type():
+@pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted"])
+def test_incorrect_type(average):
     # Tests changing of type during training
-    def _test(average):
-        re = Recall(average=average)
-        assert re._updated is False
-
-        y_pred = torch.softmax(torch.rand(4, 4), dim=1)
-        y = torch.ones(4).long()
-        re.update((y_pred, y))
-        assert re._updated is True
-
-        y_pred = torch.zeros(4)
-        y = torch.ones(4).long()
-
-        with pytest.raises(RuntimeError):
-            re.update((y_pred, y))
+    re = Recall(average=average)
+    assert re._updated is False

-        assert re._updated is True
+    y_pred = torch.softmax(torch.rand(4, 4), dim=1)
+    y = torch.ones(4).long()
+    re.update((y_pred, y))
+    assert re._updated is True

-    _test(average=True)
-    _test(average=False)
+    y_pred = torch.zeros(4)
+    y = torch.ones(4).long()

-    re1 = Recall(is_multilabel=True, average=True)
-    re2 = Recall(is_multilabel=True, average=False)
-    assert re1._updated is False
-    assert re2._updated is False
-    y_pred = torch.randint(0, 2, size=(10, 4, 20, 23))
-    y = torch.randint(0, 2, size=(10, 4, 20, 23)).long()
-    re1.update((y_pred, y))
-    re2.update((y_pred, y))
-    assert re1._updated is True
-    assert re2._updated is True
-    assert re1.compute() == pytest.approx(re2.compute().mean().item())
+    with pytest.raises(RuntimeError):
+        re.update((y_pred, y))

+    assert re._updated is True

-def test_incorrect_y_classes():
-    def _test(average):
-        re = Recall(average=average)
-        assert re._updated is False
+@pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted"])
+def test_incorrect_y_classes(average):
+    re = Recall(average=average)

-        y_pred = torch.randint(0, 2, size=(10, 4)).float()
-        y = torch.randint(4, 5, size=(10,)).long()
+    assert re._updated is False

-        with pytest.raises(ValueError):
-            re.update((y_pred, y))
+    y_pred = torch.randint(0, 2, size=(10, 4)).float()
+    y = torch.randint(4, 5, size=(10,)).long()

-        assert re._updated is False
+    with pytest.raises(ValueError):
+        re.update((y_pred, y))

-    _test(average=True)
-    _test(average=False)
+    assert re._updated is False


 def _test_distrib_integration_multiclass(device):
@@ -438,8 +465,9 @@ def update(engine, i):

         assert res.device.type == "cpu"
         res = res.cpu().numpy()
+        sk_average_parameter = ignite_average_to_scikit_average(average, "multiclass")
         true_res = recall_score(
-            y_true.cpu().numpy(), torch.argmax(y_preds, dim=1).cpu().numpy(), average="macro" if average else None
+            y_true.cpu().numpy(), torch.argmax(y_preds, dim=1).cpu().numpy(), average=sk_average_parameter
         )

         assert pytest.approx(res) == true_res
@@ -449,10 +477,14 @@ def update(engine, i):
         metric_devices.append(idist.device())
     for _ in range(2):
         for metric_device in metric_devices:
-            _test(average=True, n_epochs=1, metric_device=metric_device)
-            _test(average=True, n_epochs=2, metric_device=metric_device)
             _test(average=False, n_epochs=1, metric_device=metric_device)
             _test(average=False, n_epochs=2, metric_device=metric_device)
+            _test(average="macro", n_epochs=1, metric_device=metric_device)
+            _test(average="macro", n_epochs=2, metric_device=metric_device)
+            _test(average="weighted", n_epochs=1, metric_device=metric_device)
+            _test(average="weighted", n_epochs=2, metric_device=metric_device)
+            _test(average="micro", n_epochs=1, metric_device=metric_device)
+            _test(average="micro", n_epochs=2, metric_device=metric_device)


 def _test_distrib_integration_multilabel(device):
@@ -500,32 +532,26 @@ def update(engine, i):
         np_y_preds = to_numpy_multilabel(y_preds)
         np_y_true = to_numpy_multilabel(y_true)
         assert re._type == "multilabel"
-        res = res if average else res.mean().item()
+        sk_average_parameter = ignite_average_to_scikit_average(average, "multilabel")
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", category=UndefinedMetricWarning)
-            assert recall_score(np_y_true, np_y_preds, average="samples") == pytest.approx(res)
+            assert recall_score(np_y_true, np_y_preds, average=sk_average_parameter) == pytest.approx(res)

     metric_devices = ["cpu"]
     if device.type != "xla":
         metric_devices.append(idist.device())
     for _ in range(2):
         for metric_device in metric_devices:
-            _test(average=True, n_epochs=1, metric_device=metric_device)
-            _test(average=True, n_epochs=2, metric_device=metric_device)
             _test(average=False, n_epochs=1, metric_device=metric_device)
             _test(average=False, n_epochs=2, metric_device=metric_device)
-
-    re1 = Recall(is_multilabel=True, average=True)
-    re2 = Recall(is_multilabel=True, average=False)
-    assert re1._updated is False
-    assert re2._updated is False
-    y_pred = torch.randint(0, 2, size=(10, 4, 20, 23))
-    y = torch.randint(0, 2, size=(10, 4, 20, 23)).long()
-    re1.update((y_pred, y))
-    re2.update((y_pred, y))
-    assert re1._updated is True
-    assert re2._updated is True
-    assert re1.compute() == pytest.approx(re2.compute().mean().item())
+            _test(average="macro", n_epochs=1, metric_device=metric_device)
+            _test(average="macro", n_epochs=2, metric_device=metric_device)
+            _test(average="micro", n_epochs=1, metric_device=metric_device)
+            _test(average="micro", n_epochs=2, metric_device=metric_device)
+            _test(average="weighted", n_epochs=1, metric_device=metric_device)
+            _test(average="weighted", n_epochs=2, metric_device=metric_device)
+            _test(average="samples", n_epochs=1, metric_device=metric_device)
+            _test(average="samples", n_epochs=2, metric_device=metric_device)


 def _test_distrib_accumulator_device(device):
@@ -543,19 +569,29 @@ def _test(average, metric_device):
         re.update((y_reed, y))

         assert re._updated is True
+
         assert (
-            re._true_positives.device == metric_device
-        ), f"{type(re._true_positives.device)}:{re._true_positives.device} vs {type(metric_device)}:{metric_device}"
-        assert (
-            re._positives.device == metric_device
-        ), f"{type(re._positives.device)}:{re._positives.device} vs {type(metric_device)}:{metric_device}"
+            re._numerator.device == metric_device
+        ), f"{type(re._numerator.device)}:{re._numerator.device} vs {type(metric_device)}:{metric_device}"
+
+        if average != "samples":
+            # For average='samples', `_denominator` is of type `int` so it has no `device` member.
+            assert (
+                re._denominator.device == metric_device
+            ), f"{type(re._denominator.device)}:{re._denominator.device} vs {type(metric_device)}:{metric_device}"
+
+        if average == "weighted":
+            assert re._weight.device == metric_device, \
+                f"{type(re._weight.device)}:{re._weight.device} vs {type(metric_device)}:{metric_device}"

     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
         metric_devices.append(idist.device())
     for metric_device in metric_devices:
-        _test(True, metric_device=metric_device)
         _test(False, metric_device=metric_device)
+        _test("macro", metric_device=metric_device)
+        _test("micro", metric_device=metric_device)
+        _test("weighted", metric_device=metric_device)


 def _test_distrib_multilabel_accumulator_device(device):
@@ -566,31 +602,36 @@ def _test(average, metric_device):
         assert re._updated is False
         assert re._device == metric_device
-        assert (
-            re._true_positives.device == metric_device
-        ), f"{type(re._true_positives.device)}:{re._true_positives.device} vs {type(metric_device)}:{metric_device}"
-        assert (
-            re._positives.device == metric_device
-        ), f"{type(re._positives.device)}:{re._positives.device} vs {type(metric_device)}:{metric_device}"

         y_reed = torch.randint(0, 2, size=(10, 4, 20, 23))
         y = torch.randint(0, 2, size=(10, 4, 20, 23)).long()
         re.update((y_reed, y))

         assert re._updated is True
+
         assert (
-            re._true_positives.device == metric_device
-        ), f"{type(re._true_positives.device)}:{re._true_positives.device} vs {type(metric_device)}:{metric_device}"
-        assert (
-            re._positives.device == metric_device
-        ), f"{type(re._positives.device)}:{re._positives.device} vs {type(metric_device)}:{metric_device}"
+            re._numerator.device == metric_device
+        ), f"{type(re._numerator.device)}:{re._numerator.device} vs {type(metric_device)}:{metric_device}"
+
+        if average != "samples":
+            # For average='samples', `_denominator` is of type `int` so it has no `device` member.
+            assert (
+                re._denominator.device == metric_device
+            ), f"{type(re._denominator.device)}:{re._denominator.device} vs {type(metric_device)}:{metric_device}"
+
+        if average == "weighted":
+            assert re._weight.device == metric_device, \
+                f"{type(re._weight.device)}:{re._weight.device} vs {type(metric_device)}:{metric_device}"

     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
         metric_devices.append(idist.device())
     for metric_device in metric_devices:
-        _test(True, metric_device=metric_device)
         _test(False, metric_device=metric_device)
+        _test("macro", metric_device=metric_device)
+        _test("micro", metric_device=metric_device)
+        _test("weighted", metric_device=metric_device)
+        _test("samples", metric_device=metric_device)


 @pytest.mark.distributed