Skip to content

[ENH] best_on_top addition in plot_pairwise_scatter #2655

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion aeon/visualisation/results/_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def plot_pairwise_scatter(
title=None,
figsize=(8, 8),
color_palette="tab10",
best_on_top=True,
):
"""Plot a scatter that compares datasets' results achieved by two methods.

Expand All @@ -66,6 +67,9 @@ def plot_pairwise_scatter(
Size of the figure.
color_palette : str, default = "tab10"
Color palette to be used for the plot.
best_on_top : bool, default=True
If True, the estimator with better performance is placed on the y-axis (top).
If False, the ordering is reversed.

Returns
-------
Expand Down Expand Up @@ -129,7 +133,7 @@ def plot_pairwise_scatter(
x, y = [min_value, max_value], [min_value, max_value]
ax.plot(x, y, color="black", alpha=0.5, zorder=1)

# Choose the appropriate order for the methods. Best method is shown in the y-axis.
# better estimator on top (y-axis)
if (results_a.mean() <= results_b.mean() and not lower_better) or (
results_a.mean() >= results_b.mean() and lower_better
):
Expand All @@ -143,6 +147,11 @@ def plot_pairwise_scatter(
second = results_b
second_method = method_b

# if best_on_top is False, swap the ordering
if not best_on_top:
first, second = second, first
first_method, second_method = second_method, first_method

differences = [
0 if i - j == 0 else (1 if i - j > 0 else -1) for i, j in zip(first, second)
]
Expand Down
13 changes: 13 additions & 0 deletions aeon/visualisation/results/tests/test_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,19 @@ def test_plot_pairwise_scatter():

assert isinstance(fig, plt.Figure) and isinstance(ax, plt.Axes)

# best_on_top = False (reversed ordering)
fig_false, ax_false = plot_pairwise_scatter(
res[0],
res[1],
cls[0],
cls[1],
metric="accuracy",
title="Test Plot best_on_top False",
best_on_top=False,
)
plt.gcf().canvas.draw_idle()
assert isinstance(fig_false, plt.Figure) and isinstance(ax_false, plt.Axes)

# Test error handling for metrics
with pytest.raises(ValueError):
plot_pairwise_scatter(
Expand Down
Loading