142 changes: 82 additions & 60 deletions tests/ignite/metrics/test_confusion_matrix.py
@@ -73,9 +73,10 @@ def test_data(request):


@pytest.mark.parametrize("n_times", range(5))
-def test_multiclass_input(n_times, test_data):
+def test_multiclass_input(n_times, test_data, available_device):
    y_pred, y, num_classes, batch_size = test_data
-    cm = ConfusionMatrix(num_classes=num_classes)
+    cm = ConfusionMatrix(num_classes=num_classes, device=available_device)
+    assert cm._device == torch.device(available_device)
    cm.reset()
    if batch_size > 1:
        n_iters = y.shape[0] // batch_size + 1
@@ -85,21 +86,22 @@ def test_multiclass_input(n_times, test_data):
    else:
        cm.update((y_pred, y))

-    np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
-    np_y = y.numpy().ravel()
-    assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())
+    np_y_pred = y_pred.cpu().numpy().argmax(axis=1).ravel()
+    np_y = y.cpu().numpy().ravel()
+    assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().cpu().numpy())


-def test_ignored_out_of_num_classes_indices():
+def test_ignored_out_of_num_classes_indices(available_device):
    num_classes = 21
-    cm = ConfusionMatrix(num_classes=num_classes)
+    cm = ConfusionMatrix(num_classes=num_classes, device=available_device)
+    assert cm._device == torch.device(available_device)

    y_pred = torch.rand(4, num_classes, 12, 10)
    y = torch.randint(0, 255, size=(4, 12, 10)).long()
    cm.update((y_pred, y))
-    np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
-    np_y = y.numpy().ravel()
-    assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())
+    np_y_pred = y_pred.cpu().numpy().argmax(axis=1).ravel()
+    np_y = y.cpu().numpy().ravel()
+    assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().cpu().numpy())


def get_y_true_y_pred():
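A note on the `.cpu()` calls that recur before every `.numpy()` in this diff: `Tensor.numpy()` is only defined for CPU tensors, so once the metric may accumulate on an accelerator, the result has to be moved to host memory before comparing against sklearn. A minimal illustration, assuming a machine where CUDA is available:

import torch

t = torch.ones(3, device="cuda")
# t.numpy() would raise:
#   TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() ...
t.cpu().numpy()  # array([1., 1., 1.], dtype=float32)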
@@ -127,9 +129,10 @@ def compute_th_y_true_y_logits(y_true, y_pred):
    return th_y_true, th_y_logits


-def test_multiclass_images():
+def test_multiclass_images(available_device):
    num_classes = 3
-    cm = ConfusionMatrix(num_classes=num_classes)
+    cm = ConfusionMatrix(num_classes=num_classes, device=available_device)
+    assert cm._device == torch.device(available_device)

    y_true, y_pred = get_y_true_y_pred()

@@ -142,13 +145,14 @@ def test_multiclass_images():
    output = (th_y_logits, th_y_true)
    cm.update(output)

-    res = cm.compute().numpy()
+    res = cm.compute().cpu().numpy()

    assert np.all(true_res == res)

    # Another test on batch of 2 images
    num_classes = 3
-    cm = ConfusionMatrix(num_classes=num_classes)
+    cm = ConfusionMatrix(num_classes=num_classes, device=available_device)
+    assert cm._device == torch.device(available_device)

    # Create a batch of two images:
    th_y_true1 = torch.from_numpy(y_true).reshape(1, 30, 30)
@@ -173,10 +177,12 @@ def test_multiclass_images():
    # Update metric & compute
    output = (th_y_logits, th_y_true)
    cm.update(output)
-    res = cm.compute().numpy()
+    res = cm.compute().cpu().numpy()

    # Compute confusion matrix with sklearn
-    true_res = confusion_matrix(th_y_true.numpy().reshape(-1), np.argmax(th_y_logits.numpy(), axis=1).reshape(-1))
+    true_res = confusion_matrix(
+        th_y_true.cpu().numpy().reshape(-1), np.argmax(th_y_logits.cpu().numpy(), axis=1).reshape(-1)
+    )

    assert np.all(true_res == res)

@@ -200,7 +206,7 @@ def test_iou_wrong_input():


@pytest.mark.parametrize("average", [None, "samples"])
-def test_iou(average):
+def test_iou(average, available_device):
    y_true, y_pred = get_y_true_y_pred()
    th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred)

@@ -212,24 +218,26 @@ def test_iou(average):
        union = bin_y_true | bin_y_pred
        true_res[index] = intersection.sum() / union.sum()

-    cm = ConfusionMatrix(num_classes=3, average=average)
+    cm = ConfusionMatrix(num_classes=3, average=average, device=available_device)
+    assert cm._device == torch.device(available_device)
    iou_metric = IoU(cm)

    # Update metric
    output = (th_y_logits, th_y_true)
    cm.update(output)

-    res = iou_metric.compute().numpy()
+    res = iou_metric.compute().cpu().numpy()

    assert np.all(res == true_res)

    for ignore_index in range(3):
-        cm = ConfusionMatrix(num_classes=3)
+        cm = ConfusionMatrix(num_classes=3, device=available_device)
+        assert cm._device == torch.device(available_device)
        iou_metric = IoU(cm, ignore_index=ignore_index)
        # Update metric
        output = (th_y_logits, th_y_true)
        cm.update(output)
-        res = iou_metric.compute().numpy()
+        res = iou_metric.compute().cpu().numpy()
        true_res_ = true_res[:ignore_index] + true_res[ignore_index + 1 :]
        assert np.all(res == true_res_), f"{ignore_index}: {res} vs {true_res_}"

@@ -238,7 +246,7 @@ def test_iou(average):
        IoU(cm)


-def test_miou():
+def test_miou(available_device):
    y_true, y_pred = get_y_true_y_pred()
    th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred)

@@ -252,136 +260,146 @@ def test_miou():

    true_res_ = np.mean(true_res)

-    cm = ConfusionMatrix(num_classes=3)
+    cm = ConfusionMatrix(num_classes=3, device=available_device)
+    assert cm._device == torch.device(available_device)
    iou_metric = mIoU(cm)

    # Update metric
    output = (th_y_logits, th_y_true)
    cm.update(output)

-    res = iou_metric.compute().numpy()
+    res = iou_metric.compute().cpu().numpy()

    assert pytest.approx(res) == true_res_

    for ignore_index in range(3):
-        cm = ConfusionMatrix(num_classes=3)
+        cm = ConfusionMatrix(num_classes=3, device=available_device)
+        assert cm._device == torch.device(available_device)
        iou_metric = mIoU(cm, ignore_index=ignore_index)
        # Update metric
        output = (th_y_logits, th_y_true)
        cm.update(output)
-        res = iou_metric.compute().numpy()
+        res = iou_metric.compute().cpu().numpy()
        true_res_ = np.mean(true_res[:ignore_index] + true_res[ignore_index + 1 :])
        assert pytest.approx(res) == true_res_, f"{ignore_index}: {res} vs {true_res_}"


-def test_cm_accuracy():
+def test_cm_accuracy(available_device):
    y_true, y_pred = get_y_true_y_pred()
    th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred)

    true_acc = accuracy_score(y_true.reshape(-1), y_pred.reshape(-1))

-    cm = ConfusionMatrix(num_classes=3)
+    cm = ConfusionMatrix(num_classes=3, device=available_device)
+    assert cm._device == torch.device(available_device)
    acc_metric = cmAccuracy(cm)

    # Update metric
    output = (th_y_logits, th_y_true)
    cm.update(output)

-    res = acc_metric.compute().numpy()
+    res = acc_metric.compute().cpu().numpy()

    assert pytest.approx(res) == true_acc


-def test_cm_precision():
+def test_cm_precision(available_device):
    y_true, y_pred = np.random.randint(0, 10, size=(1000,)), np.random.randint(0, 10, size=(1000,))
    th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred)

    true_pr = precision_score(y_true.reshape(-1), y_pred.reshape(-1), average="macro")

-    cm = ConfusionMatrix(num_classes=10)
+    cm = ConfusionMatrix(num_classes=10, device=available_device)
+    assert cm._device == torch.device(available_device)
    pr_metric = cmPrecision(cm, average=True)

    # Update metric
    output = (th_y_logits, th_y_true)
    cm.update(output)

-    res = pr_metric.compute().numpy()
+    res = pr_metric.compute().cpu().numpy()

    assert pytest.approx(res) == true_pr

    true_pr = precision_score(y_true.reshape(-1), y_pred.reshape(-1), average=None)
-    cm = ConfusionMatrix(num_classes=10)
+    cm = ConfusionMatrix(num_classes=10, device=available_device)
+    assert cm._device == torch.device(available_device)
    pr_metric = cmPrecision(cm, average=False)

    # Update metric
    output = (th_y_logits, th_y_true)
    cm.update(output)

-    res = pr_metric.compute().numpy()
+    res = pr_metric.compute().cpu().numpy()

    assert np.all(res == true_pr)


-def test_cm_recall():
+def test_cm_recall(available_device):
    y_true, y_pred = np.random.randint(0, 10, size=(1000,)), np.random.randint(0, 10, size=(1000,))
    th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred)

    true_re = recall_score(y_true.reshape(-1), y_pred.reshape(-1), average="macro")

-    cm = ConfusionMatrix(num_classes=10)
+    cm = ConfusionMatrix(num_classes=10, device=available_device)
+    assert cm._device == torch.device(available_device)
    re_metric = cmRecall(cm, average=True)

    # Update metric
    output = (th_y_logits, th_y_true)
    cm.update(output)

-    res = re_metric.compute().numpy()
+    res = re_metric.compute().cpu().numpy()

    assert pytest.approx(res) == true_re

    true_re = recall_score(y_true.reshape(-1), y_pred.reshape(-1), average=None)
-    cm = ConfusionMatrix(num_classes=10)
+    cm = ConfusionMatrix(num_classes=10, device=available_device)
+    assert cm._device == torch.device(available_device)
    re_metric = cmRecall(cm, average=False)

    # Update metric
    output = (th_y_logits, th_y_true)
    cm.update(output)

-    res = re_metric.compute().numpy()
+    res = re_metric.compute().cpu().numpy()

    assert np.all(res == true_re)


-def test_cm_with_average():
+def test_cm_with_average(available_device):
    num_classes = 5
    y_pred = torch.rand(40, num_classes)
    y = torch.randint(0, num_classes, size=(40,)).long()
-    np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
-    np_y = y.numpy().ravel()
+    np_y_pred = y_pred.cpu().numpy().argmax(axis=1).ravel()
+    np_y = y.cpu().numpy().ravel()

-    cm = ConfusionMatrix(num_classes=num_classes, average="samples")
+    cm = ConfusionMatrix(num_classes=num_classes, average="samples", device=available_device)
+    assert cm._device == torch.device(available_device)
    cm.update((y_pred, y))
    true_res = confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) * 1.0 / len(np_y)
-    res = cm.compute().numpy()
+    res = cm.compute().cpu().numpy()
    np.testing.assert_almost_equal(true_res, res)

-    cm = ConfusionMatrix(num_classes=num_classes, average="recall")
+    cm = ConfusionMatrix(num_classes=num_classes, average="recall", device=available_device)
+    assert cm._device == torch.device(available_device)
    cm.update((y_pred, y))
    true_re = recall_score(np_y, np_y_pred, average=None, labels=list(range(num_classes)))
-    res = cm.compute().numpy().diagonal()
+    res = cm.compute().cpu().numpy().diagonal()
    np.testing.assert_almost_equal(true_re, res)

-    res = cm.compute().numpy()
+    res = cm.compute().cpu().numpy()
    true_res = confusion_matrix(np_y, np_y_pred, normalize="true")
    np.testing.assert_almost_equal(true_res, res)

-    cm = ConfusionMatrix(num_classes=num_classes, average="precision")
+    cm = ConfusionMatrix(num_classes=num_classes, average="precision", device=available_device)
+    assert cm._device == torch.device(available_device)
    cm.update((y_pred, y))
    true_pr = precision_score(np_y, np_y_pred, average=None, labels=list(range(num_classes)))
-    res = cm.compute().numpy().diagonal()
+    res = cm.compute().cpu().numpy().diagonal()
    np.testing.assert_almost_equal(true_pr, res)

-    res = cm.compute().numpy()
+    res = cm.compute().cpu().numpy()
    true_res = confusion_matrix(np_y, np_y_pred, normalize="pred")
    np.testing.assert_almost_equal(true_res, res)

@@ -404,7 +422,7 @@ def test_dice_coefficient_wrong_input():
        DiceCoefficient(cm, ignore_index=11)


-def test_dice_coefficient():
+def test_dice_coefficient(available_device):
    y_true, y_pred = get_y_true_y_pred()
    th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred)

@@ -418,23 +436,25 @@ def test_dice_coefficient():
        union = bin_y_true | bin_y_pred
        true_res[index] = 2.0 * intersection.sum() / (union.sum() + intersection.sum())

-    cm = ConfusionMatrix(num_classes=3)
+    cm = ConfusionMatrix(num_classes=3, device=available_device)
+    assert cm._device == torch.device(available_device)
    dice_metric = DiceCoefficient(cm)

    # Update metric
    output = (th_y_logits, th_y_true)
    cm.update(output)

-    res = dice_metric.compute().numpy()
+    res = dice_metric.compute().cpu().numpy()
    np.testing.assert_allclose(res, true_res)

    for ignore_index in range(3):
-        cm = ConfusionMatrix(num_classes=3)
+        cm = ConfusionMatrix(num_classes=3, device=available_device)
+        assert cm._device == torch.device(available_device)
        dice_metric = DiceCoefficient(cm, ignore_index=ignore_index)
        # Update metric
        output = (th_y_logits, th_y_true)
        cm.update(output)
-        res = dice_metric.compute().numpy()
+        res = dice_metric.compute().cpu().numpy()
        true_res_ = true_res[:ignore_index] + true_res[ignore_index + 1 :]
        assert np.all(res == true_res_), f"{ignore_index}: {res} vs {true_res_}"

@@ -529,7 +549,7 @@ def _test_distrib_accumulator_device(device):


@pytest.mark.parametrize("average", [None, "samples"])
-def test_jaccard_index(average):
+def test_jaccard_index(average, available_device):
    y_true, y_pred = get_y_true_y_pred()
    th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred)

@@ -541,24 +561,26 @@ def test_jaccard_index(average):
        union = bin_y_true | bin_y_pred
        true_res[index] = intersection.sum() / union.sum()

-    cm = ConfusionMatrix(num_classes=3, average=average)
+    cm = ConfusionMatrix(num_classes=3, average=average, device=available_device)
+    assert cm._device == torch.device(available_device)
    jaccard_index = JaccardIndex(cm)

    # Update metric
    output = (th_y_logits, th_y_true)
    cm.update(output)

-    res = jaccard_index.compute().numpy()
+    res = jaccard_index.compute().cpu().numpy()

    assert np.all(res == true_res)

    for ignore_index in range(3):
-        cm = ConfusionMatrix(num_classes=3)
+        cm = ConfusionMatrix(num_classes=3, device=available_device)
+        assert cm._device == torch.device(available_device)
        jaccard_index_metric = JaccardIndex(cm, ignore_index=ignore_index)
        # Update metric
        output = (th_y_logits, th_y_true)
        cm.update(output)
-        res = jaccard_index_metric.compute().numpy()
+        res = jaccard_index_metric.compute().cpu().numpy()
        true_res_ = true_res[:ignore_index] + true_res[ignore_index + 1 :]
        assert np.all(res == true_res_), f"{ignore_index}: {res} vs {true_res_}"

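For reference, the tests above consume an `available_device` pytest fixture defined outside this diff (in the suite's conftest.py). A minimal sketch of what such a fixture could look like; the device-detection logic below is an assumption for illustration, not the repository's actual implementation:

import pytest
import torch


def _detect_devices():
    # Always exercise CPU; add accelerators only when PyTorch reports them.
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    if torch.backends.mps.is_available():
        devices.append("mps")
    return devices


@pytest.fixture(params=_detect_devices())
def available_device(request):
    # Every test requesting this fixture runs once per detected device,
    # which is why each test asserts cm._device against it.
    return request.param

With a fixture shaped like this, `ConfusionMatrix(..., device=available_device)` accumulates its matrix on each backend in turn, and the `.cpu()` calls added throughout keep the numpy comparisons device-agnostic.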