Skip to content

Commit 9ec3e07

Browse files
committed
vectorize 1d interpolators
1 parent d26144d commit 9ec3e07

File tree

3 files changed

+68
-30
lines changed

3 files changed

+68
-30
lines changed

xarray/core/missing.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def __init__(
203203

204204
self.method = method
205205
self.cons_kwargs = kwargs
206+
del self.cons_kwargs["axis"]
206207
self.call_kwargs = {"nu": nu, "ext": ext}
207208

208209
if fill_value is not None:
@@ -479,7 +480,8 @@ def _get_interpolator(
479480
interp1d_methods = get_args(Interp1dOptions)
480481
valid_methods = tuple(vv for v in get_args(InterpOptions) for vv in get_args(v))
481482

482-
# prioritize scipy.interpolate
483+
# prefer numpy.interp for 1d linear interpolation. This function cannot
484+
# take higher dimensional data but scipy.interp1d can.
483485
if (
484486
method == "linear"
485487
and not kwargs.get("fill_value", None) == "extrapolate"
@@ -489,25 +491,31 @@ def _get_interpolator(
489491
interp_class = NumpyInterpolator
490492

491493
elif method in valid_methods:
494+
kwargs.update(axis=-1)
492495
if method in interp1d_methods:
493496
kwargs.update(method=method)
494497
interp_class = ScipyInterpolator
495-
elif vectorizeable_only:
496-
raise ValueError(
497-
f"{method} is not a vectorizeable interpolator. "
498-
f"Available methods are {interp1d_methods}"
499-
)
500498
elif method == "barycentric":
501499
interp_class = _import_interpolant("BarycentricInterpolator", method)
502500
elif method in ["krogh", "krog"]:
503501
interp_class = _import_interpolant("KroghInterpolator", method)
504502
elif method == "pchip":
505503
interp_class = _import_interpolant("PchipInterpolator", method)
506504
elif method == "spline":
505+
# utils.emit_user_level_warning(
506+
# "The 1d SplineInterpolator class is performing an incorrect calculation and "
507+
# "is being deprecated. Please use `method=polynomial` for 1D Spline Interpolation.",
508+
# PendingDeprecationWarning,
509+
# )
510+
if vectorizeable_only:
511+
raise ValueError(f"{method} is not a vectorizeable interpolator. ")
507512
kwargs.update(method=method)
508513
interp_class = SplineInterpolator
509514
elif method == "akima":
510515
interp_class = _import_interpolant("Akima1DInterpolator", method)
516+
elif method == "makima":
517+
kwargs.update(method="makima")
518+
interp_class = _import_interpolant("Akima1DInterpolator", method)
511519
else:
512520
raise ValueError(f"{method} is not a valid scipy interpolator")
513521
else:
@@ -525,6 +533,7 @@ def _get_interpolator_nd(method, **kwargs):
525533

526534
if method in valid_methods:
527535
kwargs.update(method=method)
536+
kwargs.setdefault("bounds_error", False)
528537
interp_class = _import_interpolant("interpn", method)
529538
else:
530539
raise ValueError(
@@ -614,9 +623,6 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs):
614623
if not indexes_coords:
615624
return var.copy()
616625

617-
# default behavior
618-
kwargs["bounds_error"] = kwargs.get("bounds_error", False)
619-
620626
result = var
621627
# decompose the interpolation into a succession of independent interpolation
622628
for indep_indexes_coords in decompose_interp(indexes_coords):
@@ -755,8 +761,9 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
755761

756762
def _interp1d(var, x, new_x, func, kwargs):
757763
# x, new_x are tuples of size 1.
758-
x, new_x = x[0], new_x[0]
759-
rslt = func(x, var, assume_sorted=True, **kwargs)(np.ravel(new_x))
764+
x, new_x = x[0].data, new_x[0].data
765+
766+
rslt = func(x, var, **kwargs)(np.ravel(new_x))
760767
if new_x.ndim > 1:
761768
return reshape(rslt, (var.shape[:-1] + new_x.shape))
762769
if new_x.ndim == 0:

xarray/core/types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,9 @@ def copy(
228228
Interp1dOptions = Literal[
229229
"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial"
230230
]
231-
InterpolantOptions = Literal["barycentric", "krogh", "pchip", "spline", "akima"]
231+
InterpolantOptions = Literal[
232+
"barycentric", "krogh", "pchip", "spline", "akima", "makima"
233+
]
232234
InterpOptions = Union[Interp1dOptions, InterpolantOptions]
233235

234236
DatetimeUnitOptions = Literal[

xarray/tests/test_interp.py

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -132,29 +132,55 @@ def func(obj, new_x):
132132
assert_allclose(actual, expected)
133133

134134

135-
@pytest.mark.parametrize("use_dask", [False, True])
136-
def test_interpolate_vectorize(use_dask: bool) -> None:
135+
@pytest.mark.parametrize(
136+
"use_dask, method",
137+
(
138+
(False, "linear"),
139+
(False, "akima"),
140+
(False, "makima"),
141+
(True, "linear"),
142+
(True, "akima"),
143+
),
144+
)
145+
def test_interpolate_vectorize(use_dask: bool, method: str) -> None:
137146
if not has_scipy:
138147
pytest.skip("scipy is not installed.")
139148

140149
if not has_dask and use_dask:
141150
pytest.skip("dask is not installed in the environment.")
142151

143152
# scipy interpolation for the reference
144-
def func(obj, dim, new_x):
153+
def func(obj, dim, new_x, method):
154+
scipy_kwargs = {}
155+
interpolant_options = {
156+
"barycentric": "BarycentricInterpolator",
157+
"krogh": "KroghInterpolator",
158+
"pchip": "PchipInterpolator",
159+
"akima": "Akima1DInterpolator",
160+
"makima": "Akima1DInterpolator",
161+
}
162+
145163
shape = [s for i, s in enumerate(obj.shape) if i != obj.get_axis_num(dim)]
146164
for s in new_x.shape[::-1]:
147165
shape.insert(obj.get_axis_num(dim), s)
148166

149-
return scipy.interpolate.interp1d(
150-
da[dim],
151-
obj.data,
152-
axis=obj.get_axis_num(dim),
153-
bounds_error=False,
154-
fill_value=np.nan,
155-
)(new_x).reshape(shape)
167+
if method in interpolant_options:
168+
interpolant = getattr(scipy.interpolate, interpolant_options[method])
169+
if method == "makima":
170+
scipy_kwargs["method"] = method
171+
return interpolant(
172+
da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs
173+
)(new_x).reshape(shape)
174+
else:
175+
scipy_kwargs["kind"] = method
176+
scipy_kwargs["bounds_error"] = False
177+
scipy_kwargs["fill_value"] = np.nan
178+
return scipy.interpolate.interp1d(
179+
da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs
180+
)(new_x).reshape(shape)
156181

157182
da = get_example_data(0)
183+
158184
if use_dask:
159185
da = da.chunk({"y": 5})
160186

@@ -165,17 +191,17 @@ def func(obj, dim, new_x):
165191
coords={"z": np.random.randn(30), "z2": ("z", np.random.randn(30))},
166192
)
167193

168-
actual = da.interp(x=xdest, method="linear")
194+
actual = da.interp(x=xdest, method=method)
169195

170196
expected = xr.DataArray(
171-
func(da, "x", xdest),
197+
func(da, "x", xdest, method),
172198
dims=["z", "y"],
173199
coords={
174200
"z": xdest["z"],
175201
"z2": xdest["z2"],
176202
"y": da["y"],
177203
"x": ("z", xdest.values),
178-
"x2": ("z", func(da["x2"], "x", xdest)),
204+
"x2": ("z", func(da["x2"], "x", xdest, method)),
179205
},
180206
)
181207
assert_allclose(actual, expected.transpose("z", "y", transpose_coords=True))
@@ -191,18 +217,18 @@ def func(obj, dim, new_x):
191217
},
192218
)
193219

194-
actual = da.interp(x=xdest, method="linear")
220+
actual = da.interp(x=xdest, method=method)
195221

196222
expected = xr.DataArray(
197-
func(da, "x", xdest),
223+
func(da, "x", xdest, method),
198224
dims=["z", "w", "y"],
199225
coords={
200226
"z": xdest["z"],
201227
"w": xdest["w"],
202228
"z2": xdest["z2"],
203229
"y": da["y"],
204230
"x": (("z", "w"), xdest.data),
205-
"x2": (("z", "w"), func(da["x2"], "x", xdest)),
231+
"x2": (("z", "w"), func(da["x2"], "x", xdest, method)),
206232
},
207233
)
208234
assert_allclose(actual, expected.transpose("z", "w", "y", transpose_coords=True))
@@ -404,7 +430,7 @@ def test_errors(use_dask: bool) -> None:
404430
pytest.skip("dask is not installed in the environment.")
405431
da = da.chunk()
406432

407-
for method in ["akima", "spline"]:
433+
for method in ["spline"]:
408434
with pytest.raises(ValueError):
409435
da.interp(x=[0.5, 1.5], method=method) # type: ignore[arg-type]
410436

@@ -922,7 +948,10 @@ def test_interp1d_bounds_error() -> None:
922948
(("x", np.array([0, 0.5, 1, 2]), dict(unit="s")), False),
923949
],
924950
)
925-
def test_coord_attrs(x, expect_same_attrs: bool) -> None:
951+
def test_coord_attrs(
952+
x,
953+
expect_same_attrs: bool,
954+
) -> None:
926955
base_attrs = dict(foo="bar")
927956
ds = xr.Dataset(
928957
data_vars=dict(a=2 * np.arange(5)),

0 commit comments

Comments
 (0)