-
Notifications
You must be signed in to change notification settings - Fork 5
Expose indexers #32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Expose indexers #32
Changes from all commits
99173ad
ea3eed9
2cc512d
eedb5b9
4076088
10edb01
1537e29
bc55f89
6458f50
41d0255
a7e4c81
629432c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,4 @@ Examples | |
introduction | ||
dask_support | ||
custom_indexes | ||
query | ||
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -224,16 +224,13 @@ def _get_pos_indexers(self, indices, indexers): | |
|
||
return pos_indexers | ||
|
||
def sel( | ||
def query( | ||
self, indexers: Mapping[Hashable, Any] = None, **indexers_kwargs: Any | ||
) -> Union[xr.Dataset, xr.DataArray]: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think the return type is rather It would be nice if |
||
"""Selection based on a ball tree index. | ||
"""Directly query the underlying tree index. | ||
|
||
The index must have been already built using `xoak.set_index()`. | ||
|
||
It behaves mostly like :meth:`xarray.Dataset.sel` and | ||
:meth:`xarray.DataArray.sel` methods, with some limitations: | ||
|
||
- Orthogonal indexing is not supported | ||
- For vectorized (point-wise) indexing, you need to supply xarray | ||
objects | ||
|
@@ -253,11 +250,35 @@ def sel( | |
indices = self._query(indexers) | ||
|
||
if not isinstance(indices, np.ndarray): | ||
# TODO: remove (see todo below) | ||
# TODO: remove (see TODO in self.sel below) | ||
indices = indices.compute() | ||
|
||
pos_indexers = self._get_pos_indexers(indices, indexers) | ||
|
||
return xr.Dataset(pos_indexers) | ||
|
||
def sel( | ||
self, indexers: Mapping[Hashable, Any] = None, **indexers_kwargs: Any | ||
) -> Union[xr.Dataset, xr.DataArray]: | ||
"""Selection based on a tree index. | ||
|
||
The index must have been already built using `xoak.set_index()`. | ||
|
||
It behaves mostly like :meth:`xarray.Dataset.sel` and | ||
:meth:`xarray.DataArray.sel` methods, with some limitations: | ||
|
||
- Orthogonal indexing is not supported | ||
- For vectorized (point-wise) indexing, you need to supply xarray | ||
objects | ||
- Use it for nearest neighbor lookup only (it implicitly | ||
assumes method="nearest") | ||
|
||
This triggers :func:`dask.compute` if the given indexers and/or the index | ||
coordinates are chunked. | ||
|
||
""" | ||
pos_indexers = self.query(indexers, **indexers_kwargs) | ||
|
||
# TODO: issue in xarray. 1-dimensional xarray.Variables are always considered | ||
# as OuterIndexer, while we want here VectorizedIndexer | ||
# This would also allow lazy selection | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,15 +23,23 @@ def indexer_array_lib(request): | |
|
||
|
||
@pytest.fixture( | ||
params=[(('d1',), (200,)), (('d1', 'd2'), (20, 10)), (('d1', 'd2', 'd3'), (4, 10, 5))], | ||
params=[ | ||
(('d1',), (200,)), | ||
(('d1', 'd2'), (20, 10)), | ||
(('d1', 'd2', 'd3'), (4, 10, 5)), | ||
], | ||
scope='session', | ||
) | ||
def dataset_dims_shape(request): | ||
return request.param | ||
|
||
|
||
@pytest.fixture( | ||
params=[(('i1',), (100,)), (('i1', 'i2'), (10, 10)), (('i1', 'i2', 'i3'), (2, 10, 5))], | ||
params=[ | ||
(('i1',), (100,)), | ||
(('i1', 'i2'), (10, 10)), | ||
(('i1', 'i2', 'i3'), (2, 10, 5)), | ||
], | ||
scope='session', | ||
) | ||
def indexer_dims_shape(request): | ||
|
@@ -64,6 +72,34 @@ def query_brute_force(dataset, dataset_dims_shape, indexer, indexer_dims_shape, | |
return dataset.isel(indexers=pos_indexers) | ||
|
||
|
||
def query_brute_force_indexers( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We could avoid duplicating this function and have |
||
dataset, dataset_dims_shape, indexer, indexer_dims_shape, metric='euclidean' | ||
): | ||
"""Find indexers for nearest neighbors using brute-force approach.""" | ||
|
||
# for lat/lon coordinate, assume they are ordered lat, lon!! | ||
X = np.stack([np.ravel(c) for c in indexer.coords.values()]).T | ||
Y = np.stack([np.ravel(c) for c in dataset.coords.values()]).T | ||
|
||
if metric == 'haversine': | ||
X = np.deg2rad(X) | ||
Y = np.deg2rad(Y) | ||
|
||
positions, _ = pairwise_distances_argmin_min(X, Y, metric=metric) | ||
|
||
dataset_dims, dataset_shape = dataset_dims_shape | ||
indexer_dims, indexer_shape = indexer_dims_shape | ||
|
||
u_positions = list(np.unravel_index(positions.ravel(), dataset_shape)) | ||
|
||
pos_indexers = { | ||
dim: xr.Variable(indexer_dims, ind.reshape(indexer_shape)) | ||
for dim, ind in zip(dataset_dims, u_positions) | ||
} | ||
|
||
return xr.Dataset(pos_indexers) | ||
|
||
|
||
@pytest.fixture(scope='session') | ||
def geo_dataset(dataset_dims_shape, dataset_array_lib): | ||
"""Dataset with coords lon and lat on a grid of different shapes.""" | ||
|
@@ -93,7 +129,22 @@ def geo_indexer(indexer_dims_shape, indexer_array_lib): | |
@pytest.fixture(scope='session') | ||
def geo_expected(geo_dataset, dataset_dims_shape, geo_indexer, indexer_dims_shape): | ||
return query_brute_force( | ||
geo_dataset, dataset_dims_shape, geo_indexer, indexer_dims_shape, metric='haversine' | ||
geo_dataset, | ||
dataset_dims_shape, | ||
geo_indexer, | ||
indexer_dims_shape, | ||
metric='haversine', | ||
) | ||
|
||
|
||
@pytest.fixture(scope='session') | ||
def geo_expected_indexers(geo_dataset, dataset_dims_shape, geo_indexer, indexer_dims_shape): | ||
return query_brute_force_indexers( | ||
geo_dataset, | ||
dataset_dims_shape, | ||
geo_indexer, | ||
indexer_dims_shape, | ||
metric='haversine', | ||
) | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,14 @@ def test_s2point(geo_dataset, geo_indexer, geo_expected): | |
xr.testing.assert_equal(ds_sel.load(), geo_expected.load()) | ||
|
||
|
||
def test_s2point_via_query(geo_dataset, geo_indexer, geo_expected): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This unnecessarily adds a lot of tests via the parametrized fixtures, but it's probably fine for now. |
||
geo_dataset.xoak.set_index(['lat', 'lon'], 's2point') | ||
ds_indexer = geo_dataset.xoak.query(lat=geo_indexer.latitude, lon=geo_indexer.longitude) | ||
ds_sel = geo_dataset.isel(ds_indexer) | ||
|
||
xr.testing.assert_equal(ds_sel.load(), geo_expected.load()) | ||
|
||
|
||
def test_s2point_sizeof(): | ||
ds = xr.Dataset(coords={'lat': ('points', [0.0, 10.0]), 'lon': ('points', [-5.0, 5.0])}) | ||
points = np.array([[0.0, -5.0], [10.0, 5.0]]) | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -13,6 +13,15 @@ def test_scipy_kdtree(xyz_dataset, xyz_indexer, xyz_expected): | |||||
xr.testing.assert_equal(ds_sel.load(), xyz_expected.load()) | ||||||
|
||||||
|
||||||
def test_scipy_kdtree_via_indexer(xyz_dataset, xyz_indexer, xyz_expected): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
xyz_dataset.xoak.set_index(['x', 'y', 'z'], 'scipy_kdtree') | ||||||
|
||||||
indexers = xyz_dataset.xoak.query(x=xyz_indexer.xx, y=xyz_indexer.yy, z=xyz_indexer.zz) | ||||||
ds_sel = xyz_dataset.isel(indexers) | ||||||
|
||||||
xr.testing.assert_equal(ds_sel.load(), xyz_expected.load()) | ||||||
|
||||||
|
||||||
def test_scipy_kdtree_options(): | ||||||
ds = xr.Dataset(coords={'x': ('points', [1, 2]), 'y': ('points', [1, 2])}) | ||||||
|
||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,10 +38,13 @@ def test_sklearn_balltree_options(): | |
assert ds.xoak._index._index_adapter._index_options == {'leaf_size': 10} | ||
|
||
|
||
def test_sklearn_geo_balltree(geo_dataset, geo_indexer, geo_expected): | ||
def test_sklearn_geo_balltree(geo_dataset, geo_indexer, geo_expected, geo_expected_indexers): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we split this into two functions like for the other indexes? |
||
geo_dataset.xoak.set_index(['lat', 'lon'], 'sklearn_geo_balltree') | ||
ds_sel = geo_dataset.xoak.sel(lat=geo_indexer.latitude, lon=geo_indexer.longitude) | ||
|
||
ds_indexers = geo_dataset.xoak.query(lat=geo_indexer.latitude, lon=geo_indexer.longitude) | ||
xr.testing.assert_equal(ds_indexers.load(), geo_expected_indexers.load()) | ||
|
||
ds_sel = geo_dataset.xoak.sel(lat=geo_indexer.latitude, lon=geo_indexer.longitude) | ||
xr.testing.assert_equal(ds_sel.load(), geo_expected.load()) | ||
|
||
|
||
|
@@ -52,4 +55,7 @@ def test_sklearn_geo_balltree_options(): | |
|
||
# sklearn tree classes init options are not exposed as class properties | ||
# user-defined metric should be ignored | ||
assert ds.xoak._index._index_adapter._index_options == {'leaf_size': 10, 'metric': 'haversine'} | ||
assert ds.xoak._index._index_adapter._index_options == { | ||
'leaf_size': 10, | ||
'metric': 'haversine', | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`pooch` is missing in `environment.yml`; it is needed to load the xarray tutorial dataset in this example notebook.