Merged
39 commits
bb6f4e8
Implement feature and fix a few bugs
sadra-barikbin May 14, 2022
65c8261
autopep8 fix
sadra-barikbin May 14, 2022
c960f7b
Fix MyPy issues
sadra-barikbin May 14, 2022
9f9d323
Remove unused imports
sadra-barikbin May 14, 2022
aa3fc9b
Fix flake8 issue and some bugs
sadra-barikbin May 14, 2022
288258c
Fix affected metrics
sadra-barikbin May 14, 2022
15e57a4
autopep8 fix
sadra-barikbin May 14, 2022
2b41176
Empty commit
sadra-barikbin May 14, 2022
db91bf2
Fix docstring
sadra-barikbin May 14, 2022
2a9d59f
Fix average parameter docstring
sadra-barikbin May 16, 2022
d7f7f2d
autopep8 fix
sadra-barikbin May 16, 2022
1c66ae0
Fix bug
sadra-barikbin May 16, 2022
1bd436e
Merge branch 'improve-precision-recall-metric-issue-2571' of https://…
sadra-barikbin May 16, 2022
bebacd4
Fix bug and classification_report
sadra-barikbin May 16, 2022
fdf1842
Fix bug in doctests and tests
sadra-barikbin May 16, 2022
711c86a
Fix bug in doctests and tests
sadra-barikbin May 16, 2022
b311002
Make recall like precision, undo classification_report changes
sadra-barikbin May 21, 2022
a3c7f5a
Merge branch 'improve-precision-recall-metric-issue-2571' of https://…
sadra-barikbin May 21, 2022
93728fc
Resolve mypy and flake issues
sadra-barikbin May 21, 2022
adcc1c5
Undo change
sadra-barikbin May 21, 2022
6de29b0
Merge branch 'master' into improve-precision-recall-metric-issue-2571
sadra-barikbin May 25, 2022
3ae3cdc
Add more description to docstrings
sadra-barikbin May 26, 2022
9cd7c4c
autopep8 fix
sadra-barikbin May 26, 2022
5341a8d
empty commit
sadra-barikbin May 26, 2022
c9df80c
Improve code
sadra-barikbin May 26, 2022
48e06c1
Add 'macro' option
sadra-barikbin Jun 1, 2022
c2637aa
Merge branch 'improve-precision-recall-metric-issue-2571' of https://…
sadra-barikbin Jun 1, 2022
cfbb04a
Merge branch 'master' into improve-precision-recall-metric-issue-2571
sadra-barikbin Jun 1, 2022
405634d
Add None option to average parameter
sadra-barikbin Jun 2, 2022
e206852
Merge branch 'improve-precision-recall-metric-issue-2571' of https://…
sadra-barikbin Jun 2, 2022
8a562ce
Fix affected tests
sadra-barikbin Jun 3, 2022
214088b
Fix affected doctests
sadra-barikbin Jun 3, 2022
d450903
Do some refactors and improvements
sadra-barikbin Jun 6, 2022
485e4e4
Reduce internal vars to three
sadra-barikbin Jun 7, 2022
5c813c8
Fix a few bugs and do a few improvements
sadra-barikbin Jun 7, 2022
da3853d
Fix bugs, tests and do a few refactors
sadra-barikbin Jun 9, 2022
4e007b0
Fix bug in doctests
sadra-barikbin Jun 9, 2022
300fbb9
Fix mypy issue
sadra-barikbin Jun 9, 2022
0a1bd52
A little improvement
sadra-barikbin Jun 9, 2022
236 changes: 137 additions & 99 deletions ignite/metrics/precision.py
@@ -1,4 +1,4 @@
from typing import Callable, cast, Sequence, Union
from typing import Callable, Sequence, Union

import torch

@@ -15,7 +15,7 @@ class _BasePrecisionRecall(_BaseClassification):
def __init__(
self,
output_transform: Callable = lambda x: x,
average: bool = False,
average: Union[bool, str] = False,
is_multilabel: bool = False,
device: Union[str, torch.device] = torch.device("cpu"),
):
@@ -27,16 +27,29 @@ def __init__(
output_transform=output_transform, is_multilabel=is_multilabel, device=device
)

def _check_type(self, output: Sequence[torch.Tensor]) -> None:
super()._check_type(output)

if self._average == "micro" and self._type in ["binary", "multiclass"]:
raise ValueError(
"`Precision` and `Recall` with average=='micro' and binary or multiclass "
"input data are equivalent with `Accuracy`, so use this metric."
)
if self._type in ["binary", "multiclass"] and self._average == "samples":
raise ValueError("Average == 'samples' is incompatible with binary and multiclass input data.")

@reinit__is_reduced
def reset(self) -> None:
self._true_positives = 0 # type: Union[int, torch.Tensor]
self._positives = 0 # type: Union[int, torch.Tensor]
self._updated = False
if self._average == "samples":
self._sum_samples_metric = 0 # type: Union[int, torch.Tensor]
self._samples_cnt = 0 # type: int
else:
self._true_positives = 0 # type: Union[int, torch.Tensor]
self._positives = 0 # type: Union[int, torch.Tensor]

if self._is_multilabel:
init_value = 0.0 if self._average else []
self._true_positives = torch.tensor(init_value, dtype=torch.float64, device=self._device)
self._positives = torch.tensor(init_value, dtype=torch.float64, device=self._device)
if self._average == "weighted":
self._actual_positives = 0 # type: Union[int, torch.Tensor]
self._updated = False

super(_BasePrecisionRecall, self).reset()

@@ -46,24 +59,31 @@ def compute(self) -> Union[torch.Tensor, float]:
f"{self.__class__.__name__} must have at least one example before it can be computed."
)
if not self._is_reduced:
if not (self._type == "multilabel" and not self._average):
if self._average == "samples":
self._sum_samples_metric = idist.all_reduce(self._sum_samples_metric) # type: ignore[assignment]
self._samples_cnt = idist.all_reduce(self._samples_cnt) # type: ignore[assignment]
else:
self._true_positives = idist.all_reduce(self._true_positives) # type: ignore[assignment]
self._positives = idist.all_reduce(self._positives) # type: ignore[assignment]
else:
self._true_positives = cast(torch.Tensor, idist.all_gather(self._true_positives))
self._positives = cast(torch.Tensor, idist.all_gather(self._positives))
if self._average == "weighted":
self._actual_positives = idist.all_reduce(self._actual_positives) # type: ignore[assignment]
self._is_reduced = True # type: bool

result = self._true_positives / (self._positives + self.eps)
if self._average == "samples":
return (self._sum_samples_metric / self._samples_cnt).item() # type: ignore

if self._average:
return cast(torch.Tensor, result).mean().item()
result = self._true_positives / (self._positives + self.eps)
if self._average == "weighted":
denominator = self._actual_positives.sum() + self.eps # type: ignore
return ((result @ self._actual_positives) / denominator).item() # type: ignore
elif self._average == "micro":
return result.item() # type: ignore
else:
return result
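For the ``'weighted'`` branch above, the reduction is a dot product of the per-class results with the class supports, divided by the total support. A short arithmetic sketch with made-up numbers (not taken from this diff):

.. code-block:: python

    import torch

    # Hypothetical per-class precision and class supports (actual positives per class).
    per_class = torch.tensor([0.50, 1.00, 0.25], dtype=torch.float64)
    support = torch.tensor([2.0, 1.0, 4.0], dtype=torch.float64)

    weighted = (per_class @ support) / support.sum()
    print(weighted.item())  # (0.5*2 + 1.0*1 + 0.25*4) / 7 = 3/7 ≈ 0.4286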


class Precision(_BasePrecisionRecall):
r"""Calculates precision for binary and multiclass data.
r"""Calculates precision for binary, multiclass and multilabel data.

.. math:: \text{Precision} = \frac{ TP }{ TP + FP }

@@ -78,10 +98,23 @@ class Precision(_BasePrecisionRecall):
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
average: if True, precision is computed as the unweighted average (across all classes
in multiclass case), otherwise, returns a tensor with the precision (for each class in multiclass case).
is_multilabel: flag to use in multilabel case. By default, value is False. If True, average
parameter should be True and the average is computed across samples, instead of classes.
average: available options are
``False``: default option. For multiclass and multilabel inputs, a per-class or
per-label metric is returned respectively. Calling `mean()` on the metric instance
yields the macro average, i.e. the unweighted mean across classes or labels.
``'micro'``: for multilabel input, every label of each sample is treated as a
sample of its own and precision is computed over that pool. For binary and
multiclass inputs this is equivalent to `Accuracy`, so use that metric instead.
``'samples'``: for multilabel input, precision is first computed per sample and
then averaged across samples. Incompatible with binary and multiclass inputs.
``'weighted'``: for binary and multiclass input, precision is computed for each
class and then averaged, weighted by class support (the number of actual samples
in each class). For multilabel input, precision is computed for each label and
then averaged, weighted by label support (the number of actual positive samples
for each label).
is_multilabel: flag to use in multilabel case. By default, value is False.
device: specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.
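A plain-torch sketch of the reductions described for ``average`` above (hypothetical multilabel batch with 0/1 tensors of shape (batch_size, num_labels); illustrative only, not taken from this diff):

.. code-block:: python

    import torch

    eps = 1e-20
    y = torch.tensor([[0, 0, 1], [0, 0, 0], [1, 0, 0], [0, 1, 1]], dtype=torch.float64)
    y_pred = torch.tensor([[1, 1, 0], [1, 0, 1], [1, 0, 1], [1, 1, 0]], dtype=torch.float64)
    tp = y * y_pred

    # average=False: per-label precision, shape (num_labels,)
    per_label = tp.sum(dim=0) / (y_pred.sum(dim=0) + eps)

    # average='micro': pool every (sample, label) pair before dividing
    micro = tp.sum() / (y_pred.sum() + eps)

    # average='samples': per-sample precision, then mean over samples
    samples = (tp.sum(dim=1) / (y_pred.sum(dim=1) + eps)).mean()

    # average='weighted': per-label precision weighted by label support (actual positives)
    support = y.sum(dim=0)
    weighted = (per_label @ support) / (support.sum() + eps)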
@@ -93,139 +126,137 @@ class Precision(_BasePrecisionRecall):
.. include:: defaults.rst
:start-after: :orphan:

Binary case
Binary case. In binary and multilabel cases, the elements of
`y` and `y_pred` should have 0 or 1 values.

.. testcode:: 1

metric = Precision(average=False)
metric = Precision()
weighted_metric = Precision(average='weighted')
metric.attach(default_evaluator, "precision")
weighted_metric.attach(default_evaluator, "weighted precision")
y_true = torch.Tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.Tensor([1, 0, 1, 0, 1, 1])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["precision"])
print(f"Precision: {state.metrics["precision"]}")
print(f"Weighted Precision: {state.metrics["weighted precision"]}")

.. testoutput:: 1

0.75
Precision: 0.75
Weighted Precision: 0.6666666666666666

Multiclass case

.. testcode:: 2

metric = Precision(average=False)
metric = Precision()
macro_metric = metric.mean()
weighted_metric = Precision(average='weighted')

metric.attach(default_evaluator, "precision")
y_true = torch.Tensor([2, 0, 2, 1, 0, 1]).long()
macro_metric.attach(default_evaluator, "macro precision")
weighted_metric.attach(default_evaluator, "weighted precision")

y_true = torch.Tensor([2, 0, 2, 1, 0]).long()
y_pred = torch.Tensor([
[0.0266, 0.1719, 0.3055],
[0.6886, 0.3978, 0.8176],
[0.9230, 0.0197, 0.8395],
[0.1785, 0.2670, 0.6084],
[0.8448, 0.7177, 0.7288],
[0.7748, 0.9542, 0.8573],
[0.8448, 0.7177, 0.7288]
])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["precision"])
print(f"Precision: {state.metrics["precision"]}")
print(f"Macro Precision: {state.metrics["macro precision"]}")
print(f"Weighted Precision: {state.metrics["weighted precision"]}")

.. testoutput:: 2

tensor([0.5000, 1.0000, 0.3333], dtype=torch.float64)
Precision: tensor([0.5000, 0.0000, 0.3333], dtype=torch.float64)
Macro Precision: 0.27777777777777773
Weighted Precision: 0.3333333333333333

Precision can be computed as the unweighted average across all classes:
Multilabel case, the shapes must be (batch_size, num_labels, ...)

.. testcode:: 3

metric = Precision(average=True)
metric.attach(default_evaluator, "precision")
y_true = torch.Tensor([2, 0, 2, 1, 0, 1]).long()
y_pred = torch.Tensor([
[0.0266, 0.1719, 0.3055],
[0.6886, 0.3978, 0.8176],
[0.9230, 0.0197, 0.8395],
[0.1785, 0.2670, 0.6084],
[0.8448, 0.7177, 0.7288],
[0.7748, 0.9542, 0.8573],
])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["precision"])

.. testoutput:: 3

0.6111...

Multilabel case, the shapes must be (batch_size, num_categories, ...)

.. testcode:: 4

metric = Precision(is_multilabel=True)
micro_metric = Precision(is_multilabel=True, average='micro')
macro_metric = metric.mean()
weighted_metric = Precision(is_multilabel=True, average='weighted')
samples_metric = Precision(is_multilabel=True, average='samples')

metric.attach(default_evaluator, "precision")
micro_metric.attach(default_evaluator, "micro precision")
macro_metric.attach(default_evaluator, "macro precision")
weighted_metric.attach(default_evaluator, "weighted precision")
samples_metric.attach(default_evaluator, "samples precision")

y_true = torch.Tensor([
[0, 0, 1],
[0, 0, 0],
[0, 0, 0],
[1, 0, 0],
[0, 1, 1],
]).unsqueeze(0)
])
y_pred = torch.Tensor([
[1, 1, 0],
[1, 0, 1],
[1, 0, 0],
[1, 0, 1],
[1, 1, 0],
]).unsqueeze(0)
])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["precision"])
print(f"Micro Precision: {state.metrics["micro precision"]}")
print(f"Macro Precision: {state.metrics["macro precision"]}")
print(f"Weighted Precision: {state.metrics["weighted precision"]}")
print(f"Samples Precision: {state.metrics["samples precision"]}")

.. testoutput:: 4
.. testoutput:: 3

tensor([0.2000, 0.5000, 0.0000], dtype=torch.float64)
Precision: tensor([0.2000, 0.5000, 0.0000], dtype=torch.float64)
Micro Precision: 0.2222222222222222
Macro Precision: 0.2333333333333333
Weighted Precision: 0.175
Samples Precision: 0.2

In binary and multilabel cases, the elements of `y` and `y_pred` should have 0 or 1 values. Thresholding of
predictions can be done as below:
Thresholding of predictions can be done as below:

.. testcode:: 5
.. testcode:: 4

def thresholded_output_transform(output):
y_pred, y = output
y_pred = torch.round(y_pred)
return y_pred, y

metric = Precision(average=False, output_transform=thresholded_output_transform)
metric = Precision(output_transform=thresholded_output_transform)
metric.attach(default_evaluator, "precision")
y_true = torch.Tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.Tensor([0.6, 0.2, 0.9, 0.4, 0.7, 0.65])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["precision"])

.. testoutput:: 5
.. testoutput:: 4

0.75

In multilabel cases, average parameter should be True. However, if user would like to compute F1 metric, for
example, average parameter should be False. This can be done as shown below:

.. code-block:: python

precision = Precision(average=False)
recall = Recall(average=False)
F1 = precision * recall * 2 / (precision + recall + 1e-20)
F1 = MetricsLambda(lambda t: torch.mean(t).item(), F1)

.. warning::

In multilabel cases, if average is False, current implementation stores all input data (output and target) in
as tensors before computing a metric. This can potentially lead to a memory error if the input data is larger
than available RAM.


.. versionchanged:: 0.5.0
`average` parameter's semantics changed and three options were added to it.
"""

def __init__(
self,
output_transform: Callable = lambda x: x,
average: bool = False,
average: Union[bool, str] = False,
is_multilabel: bool = False,
device: Union[str, torch.device] = torch.device("cpu"),
):

if average not in [False, "micro", "weighted", "samples"]:
raise ValueError("Argument average should be one of values " "False, 'micro', 'weighted' and 'samples'.")
super(Precision, self).__init__(
output_transform=output_transform, average=average, is_multilabel=is_multilabel, device=device
)
@@ -237,9 +268,15 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y = output[0].detach(), output[1].detach()

if self._type == "binary":

y_pred = y_pred.view(-1)
y = y.view(-1)

if self._average == "weighted":
y = to_onehot(y, num_classes=2)
y_pred = to_onehot(y_pred.long(), num_classes=2)
elif self._type == "multiclass":

num_classes = y_pred.size(1)
if y.max() + 1 > num_classes:
raise ValueError(
@@ -250,31 +287,32 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
indices = torch.argmax(y_pred, dim=1).view(-1)
y_pred = to_onehot(indices, num_classes=num_classes)
elif self._type == "multilabel":
# if y, y_pred shape is (N, C, ...) -> (C, N x ...)
num_classes = y_pred.size(1)
y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1)
y = torch.transpose(y, 1, 0).reshape(num_classes, -1)

num_labels = y_pred.size(1)
y_pred = torch.transpose(y_pred, 1, -1).reshape(-1, num_labels)
y = torch.transpose(y, 1, -1).reshape(-1, num_labels)

# Convert from int cuda/cpu to double on self._device
y_pred = y_pred.to(dtype=torch.float64, device=self._device)
y = y.to(dtype=torch.float64, device=self._device)
correct = y * y_pred
all_positives = y_pred.sum(dim=0)

if correct.sum() == 0:
true_positives = torch.zeros_like(all_positives)
else:
true_positives = correct.sum(dim=0)
if self._average == "samples":

if self._type == "multilabel":
if not self._average:
self._true_positives = torch.cat([self._true_positives, true_positives], dim=0) # type: torch.Tensor
self._positives = torch.cat([self._positives, all_positives], dim=0) # type: torch.Tensor
else:
self._true_positives += torch.sum(true_positives / (all_positives + self.eps))
self._positives += len(all_positives)
else:
self._true_positives += true_positives
self._positives += all_positives
all_positives = y_pred.sum(dim=1)
true_positives = correct.sum(dim=1)
self._sum_samples_metric += torch.sum(true_positives / (all_positives + self.eps))
self._samples_cnt += y.size(0)
elif self._average == "micro":

self._positives += y_pred.sum()
self._true_positives += correct.sum()
else: # _average in [False, 'weighted']

self._positives += y_pred.sum(dim=0)
self._true_positives += correct.sum(dim=0)

if self._average == "weighted":
self._actual_positives += y.sum(dim=0)

self._updated = True
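A standalone usage sketch of the new options (this assumes the usual manual ``update()``/``compute()`` metric API; tensors are made up, not taken from the tests in this diff):

.. code-block:: python

    import torch
    from ignite.metrics import Precision

    # Hypothetical multilabel batch: shape (batch_size, num_labels), 0/1 values.
    y_pred = torch.tensor([[1, 1, 0], [1, 0, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0]])
    y_true = torch.tensor([[0, 0, 1], [0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 1]])

    for avg in (False, "micro", "samples", "weighted"):
        metric = Precision(average=avg, is_multilabel=True)
        metric.update((y_pred, y_true))
        print(avg, metric.compute())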