Skip to content

Commit a2701b8

Browse files
committed
support for Dataset coord in polyval
1 parent 2a6a633 commit a2701b8

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

xarray/core/computation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1883,6 +1883,7 @@ def polyval(
18831883
xarray.DataArray.polyfit
18841884
numpy.polynomial.polynomial.polyval
18851885
"""
1886+
from .dataset import Dataset
18861887

18871888
deg_coord = coeffs[degree_dim]
18881889

@@ -1896,6 +1897,9 @@ def polyval(
18961897
.broadcast_like(coord)
18971898
.copy(deep=True)
18981899
)
1900+
if isinstance(coord, Dataset) and not isinstance(res, Dataset):
1901+
res = Dataset({var: res for var in coord})
1902+
18991903
deg_idx = len(deg_coord) - 2
19001904
for deg in range(max_deg - 1, -1, -1):
19011905
res *= coord

xarray/tests/test_computation.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1967,42 +1967,59 @@ def test_polyval_compat(use_dask, use_datetime) -> None:
19671967

19681968

19691969
@pytest.mark.parametrize(
1970-
["coeffs", "expected"],
1970+
["x", "coeffs", "expected"],
19711971
[
19721972
pytest.param(
1973-
xr.DataArray([0, 1], dims="degree"),
19741973
xr.DataArray([1, 2, 3], dims="x"),
1974+
xr.DataArray([2, 3, 4], dims="degree"),
1975+
xr.DataArray([9, 2 + 6 + 16, 2 + 9 + 36], dims="x"),
19751976
id="simple",
19761977
),
19771978
pytest.param(
1979+
xr.DataArray([1, 2, 3], dims="x"),
19781980
xr.DataArray([[0, 1], [0, 1]], dims=("y", "degree")),
19791981
xr.DataArray([[1, 1], [2, 2], [3, 3]], dims=("x", "y")),
19801982
id="broadcast-x",
19811983
),
19821984
pytest.param(
1985+
xr.DataArray([1, 2, 3], dims="x"),
19831986
xr.DataArray([[0, 1], [1, 0], [1, 1]], dims=("x", "degree")),
19841987
xr.DataArray([1, 1, 1 + 3], dims="x"),
19851988
id="shared-dim",
19861989
),
19871990
pytest.param(
1991+
xr.DataArray([1, 2, 3], dims="x"),
19881992
xr.DataArray([1, 0, 0], dims="degree", coords={"degree": [2, 1, 0]}),
1989-
xr.DataArray([1, 2**2, 3**2], dims="x"),
1993+
xr.DataArray([1, 2 ** 2, 3 ** 2], dims="x"),
19901994
id="reordered-index",
19911995
),
19921996
pytest.param(
1997+
xr.DataArray([1, 2, 3], dims="x"),
19931998
xr.DataArray([5], dims="degree", coords={"degree": [3]}),
1994-
xr.DataArray([5, 5 * 2**3, 5 * 3**3], dims="x"),
1999+
xr.DataArray([5, 5 * 2 ** 3, 5 * 3 ** 3], dims="x"),
19952000
id="sparse-index",
19962001
),
19972002
pytest.param(
2003+
xr.DataArray([1, 2, 3], dims="x"),
19982004
xr.Dataset({"a": ("degree", [0, 1]), "b": ("degree", [1, 0])}),
19992005
xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [1, 1, 1])}),
2000-
id="dataset",
2006+
id="array-dataset",
2007+
),
2008+
pytest.param(
2009+
xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [2, 3, 4])}),
2010+
xr.DataArray([1, 1], dims="degree"),
2011+
xr.Dataset({"a": ("x", [2, 3, 4]), "b": ("x", [3, 4, 5])}),
2012+
id="dataset-array",
2013+
),
2014+
pytest.param(
2015+
xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [2, 3, 4])}),
2016+
xr.Dataset({"a": ("degree", [0, 1]), "b": ("degree", [1, 1])}),
2017+
xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [3, 4, 5])}),
2018+
id="dataset-dataset",
20012019
),
20022020
],
20032021
)
2004-
def test_polyval(coeffs, expected) -> None:
2005-
x = xr.DataArray([1, 2, 3], dims="x")
2022+
def test_polyval(x, coeffs, expected) -> None:
20062023
actual = xr.polyval(x, coeffs)
20072024
xr.testing.assert_allclose(actual, expected)
20082025

0 commit comments

Comments
 (0)