Skip to content
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e51550d
adds available_device to test_entropy.py #3335
BanzaiTokyo Mar 27, 2025
9410e68
fix test_case producing large values
BanzaiTokyo Mar 27, 2025
158498c
adds available_device to test_fbeta.py #3335
BanzaiTokyo Mar 27, 2025
819576d
pass available_device to Precision and Recall
BanzaiTokyo Mar 28, 2025
cc10d87
convert y_true and y_pred to PyTorch tensors
BanzaiTokyo Mar 28, 2025
2cd05f5
Merge branch 'master' into test_fbeta_available_device
BanzaiTokyo Mar 28, 2025
13761c9
correctly transfer tensors to device
BanzaiTokyo Mar 28, 2025
3da78a9
adds available_device to test_fbeta.py #3335
BanzaiTokyo Mar 27, 2025
b382daa
pass available_device to Precision and Recall
BanzaiTokyo Mar 28, 2025
7839ca8
convert y_true and y_pred to PyTorch tensors
BanzaiTokyo Mar 28, 2025
41a8297
changes the fixture for Precision and Recall
BanzaiTokyo Mar 28, 2025
ced2c25
Merge remote-tracking branch 'origin/test_fbeta_available_device' int…
BanzaiTokyo Mar 28, 2025
7e1aff9
removes check for used device
BanzaiTokyo Mar 28, 2025
8c0a97d
Merge branch 'master' into test_fbeta_available_device
vfdev-5 Mar 28, 2025
b4b4375
takes into account PR comments
BanzaiTokyo Mar 29, 2025
4440d0e
updates the Non checks in the Fbeta constructor
BanzaiTokyo Mar 29, 2025
be54733
updates the error check for received an invalid combination of arguments
BanzaiTokyo Mar 29, 2025
2b8e22d
Merge branch 'master' into test_fbeta_available_device
BanzaiTokyo Mar 29, 2025
0e551c8
apply PR comments
BanzaiTokyo Mar 30, 2025
221ecd7
Merge branch 'master' into test_fbeta_available_device
BanzaiTokyo Apr 14, 2025
27c55b3
Merge branch 'master' into test_fbeta_available_device
BanzaiTokyo Apr 15, 2025
3c85f9d
adds missed optional in typing
BanzaiTokyo Apr 15, 2025
2ede234
conversion to float32
BanzaiTokyo Apr 15, 2025
68b9e72
fixes typing in fbeta.py
BanzaiTokyo Apr 15, 2025
d2431e9
adds check for: If precision argument is provided, device should be None
BanzaiTokyo Apr 15, 2025
0a069f8
pytest.approx istead of np.testing.assert_allclose
BanzaiTokyo Apr 16, 2025
ce980ee
Merge branch 'master' into test_fbeta_available_device
BanzaiTokyo Apr 16, 2025
078ff60
explicitely detach y_true and y_pred to cpu before passing them to fb…
BanzaiTokyo Apr 16, 2025
0e14526
removes detach
BanzaiTokyo Apr 16, 2025
4449c20
If recall argument is provided, device should be None
BanzaiTokyo Apr 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions ignite/metrics/fbeta.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Union
from typing import Callable, cast, Optional, Union

import torch

Expand All @@ -15,7 +15,7 @@ def Fbeta(
precision: Optional[Precision] = None,
recall: Optional[Recall] = None,
output_transform: Optional[Callable] = None,
device: Union[str, torch.device] = torch.device("cpu"),
device: Optional[Union[str, torch.device]] = None,
) -> MetricsLambda:
r"""Calculates F-beta score.

Expand Down Expand Up @@ -143,17 +143,26 @@ def thresholded_output_transform(output):
if not (beta > 0):
raise ValueError(f"Beta should be a positive integer, but given {beta}")

if precision is not None and output_transform is not None:
raise ValueError("If precision argument is provided, output_transform should be None")
if precision is not None:
if output_transform is not None:
raise ValueError("If precision argument is provided, output_transform should be None")
if device is not None:
raise ValueError("If precision argument is provided, device should be None")

if recall is not None and output_transform is not None:
raise ValueError("If recall argument is provided, output_transform should be None")
if recall is not None:
if output_transform is not None:
raise ValueError("If recall argument is provided, output_transform should be None")
if device is not None:
raise ValueError("If recall argument is provided, device should be None")

if precision is None and recall is None and device is None:
device = torch.device("cpu")

if precision is None:
precision = Precision(
output_transform=(lambda x: x) if output_transform is None else output_transform,
average=False,
device=device,
device=cast(Union[str, torch.device], recall._device if recall else device),
)
elif precision._average:
raise ValueError("Input precision metric should have average=False")
Expand All @@ -162,7 +171,7 @@ def thresholded_output_transform(output):
recall = Recall(
output_transform=(lambda x: x) if output_transform is None else output_transform,
average=False,
device=device,
device=cast(Union[str, torch.device], precision._device if precision else device),
)
elif recall._average:
raise ValueError("Input recall metric should have average=False")
Expand Down
54 changes: 40 additions & 14 deletions tests/ignite/metrics/test_fbeta.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def test_wrong_inputs():
r = Recall(average="samples")
Fbeta(1.0, recall=r)

with pytest.raises(ValueError, match=r"If precision argument is provided, device should be None"):
p = Precision(average=False)
Fbeta(1.0, precision=p, device="cpu")

with pytest.raises(ValueError, match=r"If precision argument is provided, output_transform should be None"):
p = Precision(average=False)
Fbeta(1.0, precision=p, output_transform=lambda x: x)
Expand All @@ -38,30 +42,49 @@ def _output_transform(output):


@pytest.mark.parametrize(
"p, r, average, output_transform",
"precision_cls, recall_cls, average, output_transform",
[
(None, None, False, None),
(None, None, True, None),
(None, None, False, _output_transform),
(None, None, True, _output_transform),
(Precision(average=False), Recall(average=False), False, None),
(Precision(average=False), Recall(average=False), True, None),
(
lambda device: Precision(average=False, device=device),
lambda device: Recall(average=False, device=device),
False,
None,
),
(
lambda device: Precision(average=False, device=device),
lambda device: Recall(average=False, device=device),
True,
None,
),
],
)
def test_integration(p, r, average, output_transform):
np.random.seed(1)
def test_integration(precision_cls, recall_cls, average, output_transform, available_device):
if precision_cls is None:
p = None
else:
p = precision_cls(available_device)
assert p._device == torch.device(available_device)
if recall_cls is None:
r = None
else:
r = recall_cls(available_device)
assert r._device == torch.device(available_device)

n_iters = 10
batch_size = 10
n_classes = 10

y_true = np.arange(0, n_iters * batch_size, dtype="int64") % n_classes
y_pred = 0.2 * np.random.rand(n_iters * batch_size, n_classes)
y_true = torch.arange(n_iters * batch_size, dtype=torch.long, device=available_device) % n_classes
y_pred = 0.2 * torch.rand(n_iters * batch_size, n_classes, device=available_device)
for i in range(n_iters * batch_size):
if np.random.rand() > 0.4:
if torch.rand(1) > 0.4:
y_pred[i, y_true[i]] = 1.0
else:
j = np.random.randint(0, n_classes)
j = torch.randint(0, n_classes, size=(1,))
y_pred[i, j] = 0.7

y_true_batch_values = iter(y_true.reshape(n_iters, batch_size))
Expand All @@ -71,19 +94,22 @@ def update_fn(engine, batch):
y_true_batch = next(y_true_batch_values)
y_pred_batch = next(y_pred_batch_values)
if output_transform is not None:
return {"y_pred": torch.from_numpy(y_pred_batch), "y": torch.from_numpy(y_true_batch)}
return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
return {"y_pred": y_pred_batch, "y": y_true_batch}
return y_pred_batch, y_true_batch

evaluator = Engine(update_fn)

f2 = Fbeta(beta=2.0, average=average, precision=p, recall=r, output_transform=output_transform)
device = None if p is not None and r is not None else available_device
f2 = Fbeta(beta=2.0, average=average, precision=p, recall=r, output_transform=output_transform, device=device)

f2.attach(evaluator, "f2")

data = list(range(n_iters))
state = evaluator.run(data, max_epochs=1)

f2_true = fbeta_score(y_true, np.argmax(y_pred, axis=-1), average="macro" if average else None, beta=2.0)
np.testing.assert_allclose(np.array(f2_true), np.array(state.metrics["f2"]))
f2_true = fbeta_score(y_true, torch.argmax(y_pred, dim=-1), average="macro" if average else None, beta=2.0)
f2_true = np.float32(f2_true) if available_device == "mps" else f2_true
assert f2_true == pytest.approx(state.metrics["f2"])


def _test_distrib_integration(device):
Expand Down
Loading