Skip to content

Expose indexers #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ properties listed below. Proper use of this accessor should be like:

Dataset.xoak.set_index
Dataset.xoak.sel
Dataset.xoak.query

DataArray.xoak
--------------
Expand All @@ -58,6 +59,7 @@ The accessor above is also registered for :py:class:`xarray.DataArray`.

DataArray.xoak.set_index
DataArray.xoak.sel
DataArray.xoak.query

Indexes
-------
Expand Down
1 change: 1 addition & 0 deletions doc/examples/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ Examples
introduction
dask_support
custom_indexes
query
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pooch is missing in environment.yml to load the xarray tutorial dataset in this example notebook.

382 changes: 382 additions & 0 deletions doc/examples/query.ipynb

Large diffs are not rendered by default.

33 changes: 27 additions & 6 deletions src/xoak/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,16 +224,13 @@ def _get_pos_indexers(self, indices, indexers):

return pos_indexers

def sel(
def query(
self, indexers: Mapping[Hashable, Any] = None, **indexers_kwargs: Any
) -> Union[xr.Dataset, xr.DataArray]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the return type is rather Dict[Hashable, Variable].

It would be nice if query would eventually return the distances too. Maybe returning Tuple[DataArray, DataArray] would make sense instead of trying to return both indexers and distances in the same xarray object. This is closer to what the underlying indexes usually return. Also DataArray (even without any coordinate) is more common than Variable. We can save this for later, though.

"""Selection based on a ball tree index.
"""Directly query the underlying tree index.

The index must have been already built using `xoak.set_index()`.

It behaves mostly like :meth:`xarray.Dataset.sel` and
:meth:`xarray.DataArray.sel` methods, with some limitations:

- Orthogonal indexing is not supported
- For vectorized (point-wise) indexing, you need to supply xarray
objects
Expand All @@ -253,11 +250,35 @@ def sel(
indices = self._query(indexers)

if not isinstance(indices, np.ndarray):
# TODO: remove (see todo below)
# TODO: remove (see TODO in self.sel below)
indices = indices.compute()

pos_indexers = self._get_pos_indexers(indices, indexers)

return xr.Dataset(pos_indexers)

def sel(
self, indexers: Mapping[Hashable, Any] = None, **indexers_kwargs: Any
) -> Union[xr.Dataset, xr.DataArray]:
"""Selection based on an tree index.

The index must have been already built using `xoak.set_index()`.

It behaves mostly like :meth:`xarray.Dataset.sel` and
:meth:`xarray.DataArray.sel` methods, with some limitations:

- Orthogonal indexing is not supported
- For vectorized (point-wise) indexing, you need to supply xarray
objects
- Use it for nearest neighbor lookup only (it implicitly
assumes method="nearest")

This triggers :func:`dask.compute` if the given indexers and/or the index
coordinates are chunked.

"""
pos_indexers = self.query(indexers, **indexers_kwargs)

# TODO: issue in xarray. 1-dimensional xarray.Variables are always considered
# as OuterIndexer, while we want here VectorizedIndexer
# This would also allow lazy selection
Expand Down
57 changes: 54 additions & 3 deletions src/xoak/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,23 @@ def indexer_array_lib(request):


@pytest.fixture(
params=[(('d1',), (200,)), (('d1', 'd2'), (20, 10)), (('d1', 'd2', 'd3'), (4, 10, 5))],
params=[
(('d1',), (200,)),
(('d1', 'd2'), (20, 10)),
(('d1', 'd2', 'd3'), (4, 10, 5)),
],
scope='session',
)
def dataset_dims_shape(request):
return request.param


@pytest.fixture(
params=[(('i1',), (100,)), (('i1', 'i2'), (10, 10)), (('i1', 'i2', 'i3'), (2, 10, 5))],
params=[
(('i1',), (100,)),
(('i1', 'i2'), (10, 10)),
(('i1', 'i2', 'i3'), (2, 10, 5)),
],
scope='session',
)
def indexer_dims_shape(request):
Expand Down Expand Up @@ -64,6 +72,34 @@ def query_brute_force(dataset, dataset_dims_shape, indexer, indexer_dims_shape,
return dataset.isel(indexers=pos_indexers)


def query_brute_force_indexers(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could avoid duplicating this function and have query_brute_force simply return pos_indexers.

dataset, dataset_dims_shape, indexer, indexer_dims_shape, metric='euclidean'
):
"""Find indexers for nearest neighbors using brute-force approach."""

# for lat/lon coordinate, assume they are ordered lat, lon!!
X = np.stack([np.ravel(c) for c in indexer.coords.values()]).T
Y = np.stack([np.ravel(c) for c in dataset.coords.values()]).T

if metric == 'haversine':
X = np.deg2rad(X)
Y = np.deg2rad(Y)

positions, _ = pairwise_distances_argmin_min(X, Y, metric=metric)

dataset_dims, dataset_shape = dataset_dims_shape
indexer_dims, indexer_shape = indexer_dims_shape

u_positions = list(np.unravel_index(positions.ravel(), dataset_shape))

pos_indexers = {
dim: xr.Variable(indexer_dims, ind.reshape(indexer_shape))
for dim, ind in zip(dataset_dims, u_positions)
}

return xr.Dataset(pos_indexers)


@pytest.fixture(scope='session')
def geo_dataset(dataset_dims_shape, dataset_array_lib):
"""Dataset with coords lon and lat on a grid of different shapes."""
Expand Down Expand Up @@ -93,7 +129,22 @@ def geo_indexer(indexer_dims_shape, indexer_array_lib):
@pytest.fixture(scope='session')
def geo_expected(geo_dataset, dataset_dims_shape, geo_indexer, indexer_dims_shape):
return query_brute_force(
geo_dataset, dataset_dims_shape, geo_indexer, indexer_dims_shape, metric='haversine'
geo_dataset,
dataset_dims_shape,
geo_indexer,
indexer_dims_shape,
metric='haversine',
)


@pytest.fixture(scope='session')
def geo_expected_indexers(geo_dataset, dataset_dims_shape, geo_indexer, indexer_dims_shape):
return query_brute_force_indexers(
geo_dataset,
dataset_dims_shape,
geo_indexer,
indexer_dims_shape,
metric='haversine',
)


Expand Down
8 changes: 8 additions & 0 deletions src/xoak/tests/test_s2_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ def test_s2point(geo_dataset, geo_indexer, geo_expected):
xr.testing.assert_equal(ds_sel.load(), geo_expected.load())


def test_s2point_via_query(geo_dataset, geo_indexer, geo_expected):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This unnecessarily adds a lot of tests via the parametrized fixtures, but it's probably fine for now.

geo_dataset.xoak.set_index(['lat', 'lon'], 's2point')
ds_indexer = geo_dataset.xoak.query(lat=geo_indexer.latitude, lon=geo_indexer.longitude)
ds_sel = geo_dataset.isel(ds_indexer)

xr.testing.assert_equal(ds_sel.load(), geo_expected.load())


def test_s2point_sizeof():
ds = xr.Dataset(coords={'lat': ('points', [0.0, 10.0]), 'lon': ('points', [-5.0, 5.0])})
points = np.array([[0.0, -5.0], [10.0, 5.0]])
Expand Down
9 changes: 9 additions & 0 deletions src/xoak/tests/test_scipy_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ def test_scipy_kdtree(xyz_dataset, xyz_indexer, xyz_expected):
xr.testing.assert_equal(ds_sel.load(), xyz_expected.load())


def test_scipy_kdtree_via_indexer(xyz_dataset, xyz_indexer, xyz_expected):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_scipy_kdtree_via_indexer(xyz_dataset, xyz_indexer, xyz_expected):
def test_scipy_kdtree_via_query(xyz_dataset, xyz_indexer, xyz_expected):

xyz_dataset.xoak.set_index(['x', 'y', 'z'], 'scipy_kdtree')

indexers = xyz_dataset.xoak.query(x=xyz_indexer.xx, y=xyz_indexer.yy, z=xyz_indexer.zz)
ds_sel = xyz_dataset.isel(indexers)

xr.testing.assert_equal(ds_sel.load(), xyz_expected.load())


def test_scipy_kdtree_options():
ds = xr.Dataset(coords={'x': ('points', [1, 2]), 'y': ('points', [1, 2])})

Expand Down
12 changes: 9 additions & 3 deletions src/xoak/tests/test_sklearn_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,13 @@ def test_sklearn_balltree_options():
assert ds.xoak._index._index_adapter._index_options == {'leaf_size': 10}


def test_sklearn_geo_balltree(geo_dataset, geo_indexer, geo_expected):
def test_sklearn_geo_balltree(geo_dataset, geo_indexer, geo_expected, geo_expected_indexers):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we split this into two functions like for the other indexes?

geo_dataset.xoak.set_index(['lat', 'lon'], 'sklearn_geo_balltree')
ds_sel = geo_dataset.xoak.sel(lat=geo_indexer.latitude, lon=geo_indexer.longitude)

ds_indexers = geo_dataset.xoak.query(lat=geo_indexer.latitude, lon=geo_indexer.longitude)
xr.testing.assert_equal(ds_indexers.load(), geo_expected_indexers.load())

ds_sel = geo_dataset.xoak.sel(lat=geo_indexer.latitude, lon=geo_indexer.longitude)
xr.testing.assert_equal(ds_sel.load(), geo_expected.load())


Expand All @@ -52,4 +55,7 @@ def test_sklearn_geo_balltree_options():

# sklearn tree classes init options are not exposed as class properties
# user-defined metric should be ignored
assert ds.xoak._index._index_adapter._index_options == {'leaf_size': 10, 'metric': 'haversine'}
assert ds.xoak._index._index_adapter._index_options == {
'leaf_size': 10,
'metric': 'haversine',
}