Skip to content

Commit f70c881

Browse files
phausamanndcherian
andauthored
Fix incorrect legend labels for Dataset.plot.scatter (#4411)
* Fix incorrect legend labels for Dataset.plot.scatter Closes #4126 * Update xarray/tests/test_plot.py Co-authored-by: Deepak Cherian <[email protected]>
1 parent bb4c7b4 commit f70c881

File tree

3 files changed

+10
-5
lines changed

3 files changed

+10
-5
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ Bug fixes
8484
- Fix `KeyError` when doing linear interpolation to an nd `DataArray`
8585
that contains NaNs (:pull:`4233`).
8686
By `Jens Svensmark <https://github.com/jenssss>`_
87+
- Fix incorrect legend labels for :py:meth:`Dataset.plot.scatter` (:issue:`4126`).
88+
By `Peter Hausamann <https://github.com/phausamann>`_.
8789

8890
Documentation
8991
~~~~~~~~~~~~~

xarray/plot/dataset_plot.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -339,11 +339,7 @@ def newplotfunc(
339339
ax.set_ylabel(meta_data.get("ylabel"))
340340

341341
if meta_data["add_legend"]:
342-
ax.legend(
343-
handles=primitive,
344-
labels=list(meta_data["hue"].values),
345-
title=meta_data.get("hue_label", None),
346-
)
342+
ax.legend(handles=primitive, title=meta_data.get("hue_label", None))
347343
if meta_data["add_colorbar"]:
348344
cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs
349345
if "label" not in cbar_kwargs:

xarray/tests/test_plot.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2262,6 +2262,13 @@ def test_non_numeric_legend(self):
22622262
with pytest.raises(ValueError):
22632263
ds2.plot.scatter(x="A", y="B", hue="hue", hue_style="continuous")
22642264

2265+
def test_legend_labels(self):
2266+
# regression test for #4126: incorrect legend labels
2267+
ds2 = self.ds.copy()
2268+
ds2["hue"] = ["a", "a", "b", "b"]
2269+
lines = ds2.plot.scatter(x="A", y="B", hue="hue")
2270+
assert [t.get_text() for t in lines[0].axes.get_legend().texts] == ["a", "b"]
2271+
22652272
def test_add_legend_by_default(self):
22662273
sc = self.ds.plot.scatter(x="A", y="B", hue="hue")
22672274
assert len(sc.figure.axes) == 2

0 commit comments

Comments
 (0)