Closed
Description
In a project of mine I need to interpolate a dask-backed xarray object across chunk boundaries.
I got it working with monkey patching. I'm sure it can be written better, but this is the best I could do.
I hope the code below helps you implement it properly.
from typing import Union, Tuple, Callable, Any, List
import dask.array as da
import numpy as np
import xarray as xr
import xarray.core.missing as m
def interp_func(var: Union[np.ndarray, da.Array],
                x: Tuple[xr.DataArray, ...],
                new_x: Tuple[xr.DataArray, ...],
                method: str,
                kwargs: Any) -> da.Array:
    """
    Multi-dimensional interpolation for array-like. Interpolated axes should be
    located in the last position.

    xarray's own implementation (``official_interp_func``) is tried first. If it
    raises ``NotImplementedError`` (the original note says this may happen when
    interpolating between chunks), a dask-based fallback is used instead:

    1. Every chunk of ``var`` is extended by one ghost cell along each
       interpolated axis.
    2. The destination coordinates are split among the source chunks.
    3. Each chunk is interpolated independently with ``da.map_blocks``.

    Parameters
    ----------
    var: np.ndarray or dask.array.Array
        Array to be interpolated. The trailing ``len(x)`` dimensions are the
        interpolated ones. The fallback path reads ``var.chunks``, so it
        needs a dask array.
    x: a list of 1d array.
        Original coordinates, one per interpolated dimension. Should not
        contain NaN.
    new_x: a list of 1d array
        New coordinates. Should not contain NaN.
    method: string
        {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} for
        1-dimensional interpolation.
        {'linear', 'nearest'} for multidimensional interpolation
    kwargs:
        Mapping of optional keyword arguments for the scipy interpolator.
        It is unpacked with ``**kwargs`` when the interpolator is built.

    Returns
    -------
    interpolated: array
        Interpolated array. This is whatever ``official_interp_func`` returns
        when it succeeds; otherwise it is a dask array from the chunked
        fallback.

    Note
    ----
    This requires scipy installed.

    See Also
    --------
    scipy.interpolate.interp1d
    """
    try:
        # try the official interp_func first
        res = official_interp_func(var, x, new_x, method, kwargs)
        return res
    except NotImplementedError:
        # may fail if interpolating between chunks
        pass
    # Build the interpolator through xarray's private helpers: the 1-D
    # (vectorizable) variant for a single coordinate, the N-D variant otherwise.
    if len(x) == 1:
        func, _kwargs = m._get_interpolator(method, vectorizeable_only=True,
                                            **kwargs)
    else:
        func, _kwargs = m._get_interpolator_nd(method, **kwargs)
    # reshape new_x (TODO REMOVE ?)
    # Give each destination coordinate the source coordinates' dims, in source
    # order. The chunk-size loop below assumes axis `dim` of new_x[dim] is the
    # interpolated axis (NOTE(review): assumes x[i].name is the dim name).
    current_dims = [_x.name for _x in x]
    new_x = tuple([_x.set_dims(current_dims) for _x in new_x])
    # number of non interpolated dimensions
    nconst = var.ndim - len(x)
    # duplicate the ghost cells of the array
    # Overlap depth is 1 on interpolated axes and 0 on the others. Boundary
    # "none" means no padding is added at the outer edges of the array.
    bnd = {i: "none" for i in range(len(var.shape))}
    depth = {i: 0 if i < nconst else 1 for i in range(len(var.shape))}
    var_with_ghost = da.overlap.overlap(var, depth=depth, boundary=bnd)
    # chunks x and duplicate the ghost cells of x
    # Each source coordinate is chunked like the matching axis of var and
    # given the same one-cell overlap, so its blocks line up with
    # var_with_ghost's blocks.
    x = tuple(da.from_array(_x, chunks=chunks) for _x, chunks in zip(x, var.chunks[nconst:]))
    x_with_ghost = tuple(da.overlap.overlap(_x, depth={0: 1}, boundary={0: "none"})
                         for _x in x)
    # compute final chunks
    # Index of the last element of every chunk, without / with ghost cells.
    chunks_end = [np.cumsum(sizes) - 1 for _x in x
                  for sizes in _x.chunks]
    chunks_end_with_ghost = [np.cumsum(sizes) - 1 for _x in x_with_ghost
                             for sizes in _x.chunks]
    total_chunks = []
    for dim, ce in enumerate(zip(chunks_end, chunks_end_with_ghost)):
        # Cumulative count of destination points assigned up to each source
        # chunk. The split point is the ceil of the average of two counts:
        # points <= the chunk's last coordinate, and points <= the last
        # coordinate of the ghost-extended chunk.
        l_new_x_ends: List[np.ndarray] = []
        for iend, iend_with_ghost in zip(*ce):
            # 1-D destination coordinate along `dim` (index 0 on every other
            # axis). NOTE(review): loop-invariant, recomputed per chunk.
            arr = np.moveaxis(new_x[dim].data, dim, -1)
            arr = arr[tuple([0] * (len(arr.shape) - 1))]
            # NOTE(review): x[dim] is a dask array, so these counts appear to
            # be lazy dask values; np.array(...) below then materializes them
            # one by one. Confirm the cost on many chunks.
            n_no_ghost = (arr <= x[dim][iend]).sum()
            n_ghost = (arr <= x_with_ghost[dim][iend_with_ghost]).sum()
            equil = np.ceil(0.5 * (n_no_ghost + n_ghost)).astype(int)
            l_new_x_ends.append(equil)
        # NOTE(review): destination points greater than the last source
        # coordinate are never counted, so these sizes can sum to less than
        # len(new_x[dim]). Confirm such points never reach this path.
        new_x_ends = np.array(l_new_x_ends)
        # Per-chunk sizes from the cumulative ends.
        chunks = new_x_ends[0], *(new_x_ends[1:] - new_x_ends[:-1])
        total_chunks.append(tuple(chunks))
    # Output chunks: var's chunks on the non-interpolated axes, the computed
    # sizes on the interpolated ones.
    final_chunks = var.chunks[:-len(x)] + tuple(total_chunks)
    # chunks new_x
    new_x = tuple(da.from_array(_x, chunks=total_chunks) for _x in new_x)
    # reshape x_with_ghost (TODO REMOVE ?)
    # Expand the ghosted 1-D coordinates to N-D grids so they can be passed
    # block-wise next to var_with_ghost. _myinterpnd reduces them back to 1-D.
    x_with_ghost = da.meshgrid(*x_with_ghost, indexing='ij')
    # compute on chunks (TODO use drop_axis and new_axis ?)
    # func, _kwargs and the number of coordinates are non-dask arguments and
    # are passed unchanged to every _myinterpnd call.
    res = da.map_blocks(_myinterpnd, var_with_ghost, func, _kwargs, len(x_with_ghost), *x_with_ghost, *new_x,
                        dtype=var.dtype, chunks=final_chunks)
    # reshape res and remove empty chunks (TODO REMOVE ?)
    # NOTE(review): squeeze() drops every length-1 axis, including a genuine
    # length-1 non-interpolated axis of var. Confirm this is intended.
    res = res.squeeze()
    # Drop zero-size chunks (source chunks that received no destination points).
    new_chunks = tuple([tuple([chunk for chunk in chunks if chunk > 0]) for chunks in res.chunks])
    res = res.rechunk(new_chunks)
    return res
def _myinterpnd(var: da.Array,
                func: Callable[..., Any],
                kwargs: Any,
                nx: int,
                *arrs: da.Array) -> da.Array:
    """Interpolate one block of ``var``; meant to be used via ``da.map_blocks``.

    ``arrs`` holds ``nx`` source-coordinate grids followed by the destination
    coordinates for this block. In ``interp_func`` the grids are the
    meshgrid-expanded ghosted coordinates.

    Each grid is reduced to the 1-D coordinate along its own axis. That axis
    is moved last and index 0 is taken on every other axis. The destination
    arrays are wrapped in ``xr.DataArray``. Then ``func`` and ``kwargs`` are
    handed to xarray's private ``_interpnd``.
    """
    source_grids = arrs[:nx]
    target_arrays = arrs[nx:]
    # reshape x (TODO REMOVE ?): recover the 1-D coordinate from each grid
    coords_1d = []
    for axis, grid in enumerate(source_grids):
        leading_zeros = (0,) * (len(grid.shape) - 1)
        coords_1d.append(np.moveaxis(grid, axis, -1)[leading_zeros])
    targets = [xr.DataArray(block) for block in target_arrays]
    return m._interpnd(var, tuple(coords_1d), tuple(targets), func, kwargs)
# Monkey patch xarray: keep a handle on xarray's own interp_func (interp_func
# tries it first) and install the chunk-aware interp_func in its place.
# The untouched original is stashed on the xarray module only the first time.
# Re-executing this code therefore never captures an already-installed patch
# as the "official" function. With importlib.reload, which keeps this
# module's globals, that capture would make interp_func call itself until
# RecursionError.
if not hasattr(m, "_unpatched_interp_func"):
    m._unpatched_interp_func = m.interp_func
official_interp_func = m._unpatched_interp_func
m.interp_func = interp_func
Metadata
Metadata
Assignees
Labels
No labels