@@ -132,29 +132,55 @@ def func(obj, new_x):
132
132
assert_allclose (actual , expected )
133
133
134
134
135
- @pytest .mark .parametrize ("use_dask" , [False , True ])
136
- def test_interpolate_vectorize (use_dask : bool ) -> None :
135
+ @pytest .mark .parametrize (
136
+ "use_dask, method" ,
137
+ (
138
+ (False , "linear" ),
139
+ (False , "akima" ),
140
+ (False , "makima" ),
141
+ (True , "linear" ),
142
+ (True , "akima" ),
143
+ ),
144
+ )
145
+ def test_interpolate_vectorize (use_dask : bool , method : str ) -> None :
137
146
if not has_scipy :
138
147
pytest .skip ("scipy is not installed." )
139
148
140
149
if not has_dask and use_dask :
141
150
pytest .skip ("dask is not installed in the environment." )
142
151
143
152
# scipy interpolation for the reference
144
- def func (obj , dim , new_x ):
153
+ def func (obj , dim , new_x , method ):
154
+ scipy_kwargs = {}
155
+ interpolant_options = {
156
+ "barycentric" : "BarycentricInterpolator" ,
157
+ "krogh" : "KroghInterpolator" ,
158
+ "pchip" : "PchipInterpolator" ,
159
+ "akima" : "Akima1DInterpolator" ,
160
+ "makima" : "Akima1DInterpolator" ,
161
+ }
162
+
145
163
shape = [s for i , s in enumerate (obj .shape ) if i != obj .get_axis_num (dim )]
146
164
for s in new_x .shape [::- 1 ]:
147
165
shape .insert (obj .get_axis_num (dim ), s )
148
166
149
- return scipy .interpolate .interp1d (
150
- da [dim ],
151
- obj .data ,
152
- axis = obj .get_axis_num (dim ),
153
- bounds_error = False ,
154
- fill_value = np .nan ,
155
- )(new_x ).reshape (shape )
167
+ if method in interpolant_options :
168
+ interpolant = getattr (scipy .interpolate , interpolant_options [method ])
169
+ if method == "makima" :
170
+ scipy_kwargs ["method" ] = method
171
+ return interpolant (
172
+ da [dim ], obj .data , axis = obj .get_axis_num (dim ), ** scipy_kwargs
173
+ )(new_x ).reshape (shape )
174
+ else :
175
+ scipy_kwargs ["kind" ] = method
176
+ scipy_kwargs ["bounds_error" ] = False
177
+ scipy_kwargs ["fill_value" ] = np .nan
178
+ return scipy .interpolate .interp1d (
179
+ da [dim ], obj .data , axis = obj .get_axis_num (dim ), ** scipy_kwargs
180
+ )(new_x ).reshape (shape )
156
181
157
182
da = get_example_data (0 )
183
+
158
184
if use_dask :
159
185
da = da .chunk ({"y" : 5 })
160
186
@@ -165,17 +191,17 @@ def func(obj, dim, new_x):
165
191
coords = {"z" : np .random .randn (30 ), "z2" : ("z" , np .random .randn (30 ))},
166
192
)
167
193
168
- actual = da .interp (x = xdest , method = "linear" )
194
+ actual = da .interp (x = xdest , method = method )
169
195
170
196
expected = xr .DataArray (
171
- func (da , "x" , xdest ),
197
+ func (da , "x" , xdest , method ),
172
198
dims = ["z" , "y" ],
173
199
coords = {
174
200
"z" : xdest ["z" ],
175
201
"z2" : xdest ["z2" ],
176
202
"y" : da ["y" ],
177
203
"x" : ("z" , xdest .values ),
178
- "x2" : ("z" , func (da ["x2" ], "x" , xdest )),
204
+ "x2" : ("z" , func (da ["x2" ], "x" , xdest , method )),
179
205
},
180
206
)
181
207
assert_allclose (actual , expected .transpose ("z" , "y" , transpose_coords = True ))
@@ -191,18 +217,18 @@ def func(obj, dim, new_x):
191
217
},
192
218
)
193
219
194
- actual = da .interp (x = xdest , method = "linear" )
220
+ actual = da .interp (x = xdest , method = method )
195
221
196
222
expected = xr .DataArray (
197
- func (da , "x" , xdest ),
223
+ func (da , "x" , xdest , method ),
198
224
dims = ["z" , "w" , "y" ],
199
225
coords = {
200
226
"z" : xdest ["z" ],
201
227
"w" : xdest ["w" ],
202
228
"z2" : xdest ["z2" ],
203
229
"y" : da ["y" ],
204
230
"x" : (("z" , "w" ), xdest .data ),
205
- "x2" : (("z" , "w" ), func (da ["x2" ], "x" , xdest )),
231
+ "x2" : (("z" , "w" ), func (da ["x2" ], "x" , xdest , method )),
206
232
},
207
233
)
208
234
assert_allclose (actual , expected .transpose ("z" , "w" , "y" , transpose_coords = True ))
@@ -404,7 +430,7 @@ def test_errors(use_dask: bool) -> None:
404
430
pytest .skip ("dask is not installed in the environment." )
405
431
da = da .chunk ()
406
432
407
- for method in ["akima" , " spline" ]:
433
+ for method in ["spline" ]:
408
434
with pytest .raises (ValueError ):
409
435
da .interp (x = [0.5 , 1.5 ], method = method ) # type: ignore[arg-type]
410
436
@@ -922,7 +948,10 @@ def test_interp1d_bounds_error() -> None:
922
948
(("x" , np .array ([0 , 0.5 , 1 , 2 ]), dict (unit = "s" )), False ),
923
949
],
924
950
)
925
- def test_coord_attrs (x , expect_same_attrs : bool ) -> None :
951
+ def test_coord_attrs (
952
+ x ,
953
+ expect_same_attrs : bool ,
954
+ ) -> None :
926
955
base_attrs = dict (foo = "bar" )
927
956
ds = xr .Dataset (
928
957
data_vars = dict (a = 2 * np .arange (5 )),
0 commit comments