Skip to content

Commit 425874c

Browse files
committed
vectorize 1d interpolators
review changes
1 parent 52f13d4 commit 425874c

File tree

6 files changed

+102
-47
lines changed

6 files changed

+102
-47
lines changed

xarray/core/dataarray.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2224,12 +2224,12 @@ def interp(
22242224
22252225
Performs univariate or multivariate interpolation of a DataArray onto
22262226
new coordinates using scipy's interpolation routines. If interpolating
2227-
along an existing dimension, :py:class:`scipy.interpolate.interp1d` is
2228-
called. When interpolating along multiple existing dimensions, an
2227+
along an existing dimension, either :py:class:`scipy.interpolate.interp1d`
2228+
or a 1-dimensional scipy interpolator (e.g. :py:class:`scipy.interpolate.KroghInterpolator`)
2229+
is called. When interpolating along multiple existing dimensions, an
22292230
attempt is made to decompose the interpolation into multiple
2230-
1-dimensional interpolations. If this is possible,
2231-
:py:class:`scipy.interpolate.interp1d` is called. Otherwise,
2232-
:py:func:`scipy.interpolate.interpn` is called.
2231+
1-dimensional interpolations. If this is possible, the 1-dimensional interpolator is called.
2232+
Otherwise, :py:func:`scipy.interpolate.interpn` is called.
22332233
22342234
Parameters
22352235
----------

xarray/core/dataset.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3885,12 +3885,12 @@ def interp(
38853885
38863886
Performs univariate or multivariate interpolation of a Dataset onto
38873887
new coordinates using scipy's interpolation routines. If interpolating
3888-
along an existing dimension, :py:class:`scipy.interpolate.interp1d` is
3889-
called. When interpolating along multiple existing dimensions, an
3888+
along an existing dimension, either :py:class:`scipy.interpolate.interp1d`
3889+
or a 1-dimensional scipy interpolator (e.g. :py:class:`scipy.interpolate.KroghInterpolator`)
3890+
is called. When interpolating along multiple existing dimensions, an
38903891
attempt is made to decompose the interpolation into multiple
3891-
1-dimensional interpolations. If this is possible,
3892-
:py:class:`scipy.interpolate.interp1d` is called. Otherwise,
3893-
:py:func:`scipy.interpolate.interpn` is called.
3892+
1-dimensional interpolations. If this is possible, the 1-dimensional interpolator
3893+
is called. Otherwise, :py:func:`scipy.interpolate.interpn` is called.
38943894
38953895
Parameters
38963896
----------

xarray/core/missing.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def __init__(
138138
copy=False,
139139
bounds_error=False,
140140
order=None,
141+
axis=-1,
141142
**kwargs,
142143
):
143144
from scipy.interpolate import interp1d
@@ -173,6 +174,7 @@ def __init__(
173174
bounds_error=bounds_error,
174175
assume_sorted=assume_sorted,
175176
copy=copy,
177+
axis=axis,
176178
**self.cons_kwargs,
177179
)
178180

@@ -203,6 +205,7 @@ def __init__(
203205

204206
self.method = method
205207
self.cons_kwargs = kwargs
208+
del self.cons_kwargs["axis"]
206209
self.call_kwargs = {"nu": nu, "ext": ext}
207210

208211
if fill_value is not None:
@@ -479,7 +482,8 @@ def _get_interpolator(
479482
interp1d_methods = get_args(Interp1dOptions)
480483
valid_methods = tuple(vv for v in get_args(InterpOptions) for vv in get_args(v))
481484

482-
# prioritize scipy.interpolate
485+
# prefer numpy.interp for 1d linear interpolation. This function cannot
486+
# take higher dimensional data but scipy.interp1d can.
483487
if (
484488
method == "linear"
485489
and not kwargs.get("fill_value", None) == "extrapolate"
@@ -489,24 +493,37 @@ def _get_interpolator(
489493
interp_class = NumpyInterpolator
490494

491495
elif method in valid_methods:
496+
kwargs.update(axis=-1)
492497
if method in interp1d_methods:
493498
kwargs.update(method=method)
494499
interp_class = ScipyInterpolator
495-
elif vectorizeable_only:
496-
raise ValueError(
497-
f"{method} is not a vectorizeable interpolator. "
498-
f"Available methods are {interp1d_methods}"
499-
)
500500
elif method == "barycentric":
501+
kwargs.update(axis=-1)
501502
interp_class = _import_interpolant("BarycentricInterpolator", method)
502503
elif method in ["krogh", "krog"]:
504+
kwargs.update(axis=-1)
503505
interp_class = _import_interpolant("KroghInterpolator", method)
504506
elif method == "pchip":
507+
kwargs.update(axis=-1)
505508
interp_class = _import_interpolant("PchipInterpolator", method)
506509
elif method == "spline":
510+
# utils.emit_user_level_warning(
511+
# "The 1d SplineInterpolator class is performing an incorrect calculation and "
512+
# "is being deprecated. Please use `method=polynomial` for 1D Spline Interpolation.",
513+
# PendingDeprecationWarning,
514+
# )
515+
if vectorizeable_only:
516+
raise ValueError(f"{method} is not a vectorizeable interpolator. ")
507517
kwargs.update(method=method)
508518
interp_class = SplineInterpolator
509519
elif method == "akima":
520+
kwargs.update(axis=-1)
521+
interp_class = _import_interpolant("Akima1DInterpolator", method)
522+
elif method == "makima":
523+
kwargs.update(method="makima", axis=-1)
524+
interp_class = _import_interpolant("Akima1DInterpolator", method)
525+
elif method == "makima":
526+
kwargs.update(method="makima")
510527
interp_class = _import_interpolant("Akima1DInterpolator", method)
511528
else:
512529
raise ValueError(f"{method} is not a valid scipy interpolator")
@@ -525,6 +542,7 @@ def _get_interpolator_nd(method, **kwargs):
525542

526543
if method in valid_methods:
527544
kwargs.update(method=method)
545+
kwargs.setdefault("bounds_error", False)
528546
interp_class = _import_interpolant("interpn", method)
529547
else:
530548
raise ValueError(
@@ -614,9 +632,6 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs):
614632
if not indexes_coords:
615633
return var.copy()
616634

617-
# default behavior
618-
kwargs["bounds_error"] = kwargs.get("bounds_error", False)
619-
620635
result = var
621636
# decompose the interpolation into a succession of independent interpolation
622637
for indep_indexes_coords in decompose_interp(indexes_coords):
@@ -663,8 +678,8 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
663678
new_x : a list of 1d array
664679
New coordinates. Should not contain NaN.
665680
method : string
666-
{'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} for
667-
1-dimensional interpolation.
681+
{'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'pchip', 'akima',
682+
'makima', 'barycentric', 'krogh'} for 1-dimensional interpolation.
668683
{'linear', 'nearest'} for multidimensional interpolation
669684
**kwargs
670685
Optional keyword arguments to be passed to scipy.interpolator
@@ -756,7 +771,7 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
756771
def _interp1d(var, x, new_x, func, kwargs):
757772
# x, new_x are tuples of size 1.
758773
x, new_x = x[0], new_x[0]
759-
rslt = func(x, var, assume_sorted=True, **kwargs)(np.ravel(new_x))
774+
rslt = func(x, var, **kwargs)(np.ravel(new_x))
760775
if new_x.ndim > 1:
761776
return reshape(rslt, (var.shape[:-1] + new_x.shape))
762777
if new_x.ndim == 0:

xarray/core/types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,9 @@ def copy(
228228
Interp1dOptions = Literal[
229229
"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial"
230230
]
231-
InterpolantOptions = Literal["barycentric", "krogh", "pchip", "spline", "akima"]
231+
InterpolantOptions = Literal[
232+
"barycentric", "krogh", "pchip", "spline", "akima", "makima"
233+
]
232234
InterpOptions = Union[Interp1dOptions, InterpolantOptions]
233235

234236
DatetimeUnitOptions = Literal[

xarray/tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def _importorskip(
8787

8888
has_matplotlib, requires_matplotlib = _importorskip("matplotlib")
8989
has_scipy, requires_scipy = _importorskip("scipy")
90+
has_scipy_ge_1_13, requires_scipy_ge_1_13 = _importorskip("scipy", "1.13")
9091
with warnings.catch_warnings():
9192
warnings.filterwarnings(
9293
"ignore",

xarray/tests/test_interp.py

Lines changed: 61 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
assert_identical,
1717
has_dask,
1818
has_scipy,
19+
has_scipy_ge_1_13,
1920
requires_cftime,
2021
requires_dask,
2122
requires_scipy,
@@ -132,29 +133,62 @@ def func(obj, new_x):
132133
assert_allclose(actual, expected)
133134

134135

135-
@pytest.mark.parametrize("use_dask", [False, True])
136-
def test_interpolate_vectorize(use_dask: bool) -> None:
137-
if not has_scipy:
138-
pytest.skip("scipy is not installed.")
139-
140-
if not has_dask and use_dask:
141-
pytest.skip("dask is not installed in the environment.")
142-
136+
@requires_scipy
137+
@pytest.mark.parametrize(
138+
"use_dask, method",
139+
(
140+
(False, "linear"),
141+
(False, "akima"),
142+
pytest.param(
143+
False,
144+
"makima",
145+
marks=pytest.mark.skipif(not has_scipy_ge_1_13, reason="scipy too old"),
146+
),
147+
pytest.param(
148+
True,
149+
"linear",
150+
marks=pytest.mark.skipif(not has_dask, reason="dask not available"),
151+
),
152+
(True, "akima"),
153+
),
154+
)
155+
def test_interpolate_vectorize(use_dask: bool, method: str) -> None:
143156
# scipy interpolation for the reference
144-
def func(obj, dim, new_x):
157+
def func(obj, dim, new_x, method):
158+
scipy_kwargs = {}
159+
interpolant_options = {
160+
"barycentric": scipy.interpolate.BarycentricInterpolator,
161+
"krogh": scipy.interpolate.KroghInterpolator,
162+
"pchip": scipy.interpolate.PchipInterpolator,
163+
"akima": scipy.interpolate.Akima1DInterpolator,
164+
"makima": scipy.interpolate.Akima1DInterpolator,
165+
}
166+
145167
shape = [s for i, s in enumerate(obj.shape) if i != obj.get_axis_num(dim)]
146168
for s in new_x.shape[::-1]:
147169
shape.insert(obj.get_axis_num(dim), s)
148170

149-
return scipy.interpolate.interp1d(
150-
da[dim],
151-
obj.data,
152-
axis=obj.get_axis_num(dim),
153-
bounds_error=False,
154-
fill_value=np.nan,
155-
)(new_x).reshape(shape)
171+
if method in interpolant_options:
172+
interpolant = interpolant_options[method]
173+
if method == "makima":
174+
scipy_kwargs["method"] = method
175+
return interpolant(
176+
da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs
177+
)(new_x).reshape(shape)
178+
else:
179+
180+
return scipy.interpolate.interp1d(
181+
da[dim],
182+
obj.data,
183+
axis=obj.get_axis_num(dim),
184+
kind=method,
185+
bounds_error=False,
186+
fill_value=np.nan,
187+
**scipy_kwargs,
188+
)(new_x).reshape(shape)
156189

157190
da = get_example_data(0)
191+
158192
if use_dask:
159193
da = da.chunk({"y": 5})
160194

@@ -165,17 +199,17 @@ def func(obj, dim, new_x):
165199
coords={"z": np.random.randn(30), "z2": ("z", np.random.randn(30))},
166200
)
167201

168-
actual = da.interp(x=xdest, method="linear")
202+
actual = da.interp(x=xdest, method=method)
169203

170204
expected = xr.DataArray(
171-
func(da, "x", xdest),
205+
func(da, "x", xdest, method),
172206
dims=["z", "y"],
173207
coords={
174208
"z": xdest["z"],
175209
"z2": xdest["z2"],
176210
"y": da["y"],
177211
"x": ("z", xdest.values),
178-
"x2": ("z", func(da["x2"], "x", xdest)),
212+
"x2": ("z", func(da["x2"], "x", xdest, method)),
179213
},
180214
)
181215
assert_allclose(actual, expected.transpose("z", "y", transpose_coords=True))
@@ -191,18 +225,18 @@ def func(obj, dim, new_x):
191225
},
192226
)
193227

194-
actual = da.interp(x=xdest, method="linear")
228+
actual = da.interp(x=xdest, method=method)
195229

196230
expected = xr.DataArray(
197-
func(da, "x", xdest),
231+
func(da, "x", xdest, method),
198232
dims=["z", "w", "y"],
199233
coords={
200234
"z": xdest["z"],
201235
"w": xdest["w"],
202236
"z2": xdest["z2"],
203237
"y": da["y"],
204238
"x": (("z", "w"), xdest.data),
205-
"x2": (("z", "w"), func(da["x2"], "x", xdest)),
239+
"x2": (("z", "w"), func(da["x2"], "x", xdest, method)),
206240
},
207241
)
208242
assert_allclose(actual, expected.transpose("z", "w", "y", transpose_coords=True))
@@ -404,7 +438,7 @@ def test_errors(use_dask: bool) -> None:
404438
pytest.skip("dask is not installed in the environment.")
405439
da = da.chunk()
406440

407-
for method in ["akima", "spline"]:
441+
for method in ["spline"]:
408442
with pytest.raises(ValueError):
409443
da.interp(x=[0.5, 1.5], method=method) # type: ignore[arg-type]
410444

@@ -922,7 +956,10 @@ def test_interp1d_bounds_error() -> None:
922956
(("x", np.array([0, 0.5, 1, 2]), dict(unit="s")), False),
923957
],
924958
)
925-
def test_coord_attrs(x, expect_same_attrs: bool) -> None:
959+
def test_coord_attrs(
960+
x,
961+
expect_same_attrs: bool,
962+
) -> None:
926963
base_attrs = dict(foo="bar")
927964
ds = xr.Dataset(
928965
data_vars=dict(a=2 * np.arange(5)),

0 commit comments

Comments
 (0)