Skip to content

Integration with Xarray NDPointIndex #44

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.3.0
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-docstring-first
- id: check-yaml
- id: double-quote-string-fixer

- repo: https://github.com/ambv/black
rev: 20.8b1
- repo: https://github.com/psf/black
rev: 25.1.0
hooks:
- id: black
args: ["--line-length", "100", "--skip-string-normalization"]

- repo: https://gitlab.com/PyCQA/flake8
rev: 3.8.4
hooks:
- id: flake8

- repo: https://github.com/PyCQA/isort
rev: 5.6.4
rev: 6.0.1
hooks:
- id: isort
20 changes: 18 additions & 2 deletions src/xoak/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
from pkg_resources import DistributionNotFound, get_distribution

from .accessor import XoakAccessor
from .index import IndexAdapter, IndexRegistry
from xoak.accessor import XoakAccessor
from xoak.index import IndexAdapter, IndexRegistry
from xoak.tree_adapters import (
S2PointTreeAdapter,
SklearnBallTreeAdapter,
SklearnGeoBallTreeAdapter,
SklearnKDTreeAdapter,
)

try:
__version__ = get_distribution(__name__).version
except DistributionNotFound: # pragma: no cover
# package is not installed
pass

__all__ = [
'IndexAdapter',
'IndexRegistry',
'SklearnBallTreeAdapter',
'SklearnGeoBallTreeAdapter',
'SklearnKDTreeAdapter',
'S2PointTreeAdapter',
'XoakAccessor',
]
101 changes: 101 additions & 0 deletions src/xoak/tree_adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import TYPE_CHECKING, Any

import numpy as np

try:
from xarray.indexes.nd_point_index import TreeAdapter # type: ignore
except ImportError:

class TreeAdapter: ...


if TYPE_CHECKING:
import pys2index
import sklearn.neighbors


class S2PointTreeAdapter(TreeAdapter):
""":py:class:`pys2index.S2PointIndex` adapter for :py:class:`~xarray.indexes.NDPointIndex`."""

_s2point_index: pys2index.S2PointIndex

def __init__(self, points: np.ndarray, options: Mapping[str, Any]):
from pys2index import S2PointIndex

self._s2point_index = S2PointIndex(points)

def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
return self._s2point_index.query(points)

def equals(self, other: S2PointTreeAdapter) -> bool:
return np.array_equal(
self._s2point_index.get_cell_ids(), other._s2point_index.get_cell_ids()
)


class SklearnKDTreeAdapter(TreeAdapter):
""":py:class:`sklearn.neighbors.KDTree` adapter for :py:class:`~xarray.indexes.NDPointIndex`."""

_kdtree: sklearn.neighbors.KDTree

def __init__(self, points: np.ndarray, options: Mapping[str, Any]):
from sklearn.neighbors import KDTree

self._kdtree = KDTree(points, **options)

def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
return self._kdtree.query(points)

def equals(self, other: SklearnKDTreeAdapter) -> bool:
return np.array_equal(self._kdtree.data, other._kdtree.data)


class SklearnBallTreeAdapter(TreeAdapter):
""":py:class:`sklearn.neighbors.BallTree` adapter for :py:class:`~xarray.indexes.NDPointIndex`."""

_balltree: sklearn.neighbors.BallTree

def __init__(self, points: np.ndarray, options: Mapping[str, Any]):
from sklearn.neighbors import BallTree

self._balltree = BallTree(points, **options)

def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
return self._balltree.query(points)

def equals(self, other: SklearnBallTreeAdapter) -> bool:
return np.array_equal(self._balltree.data, other._balltree.data)


class SklearnGeoBallTreeAdapter(TreeAdapter):
""":py:class:`sklearn.neighbors.BallTree` adapter for
:py:class:`~xarray.indexes.NDPointIndex`, using the 'haversine' metric.

It can be used for indexing a set of latitude / longitude points.

When building the index, the coordinates must be given in the latitude,
longitude order.

Latitude and longitude values must be given in degrees for both index and
query points (those values are converted in radians by this adapter).

"""

_balltree: sklearn.neighbors.BallTree

def __init__(self, points: np.ndarray, options: Mapping[str, Any]):
from sklearn.neighbors import BallTree

opts = dict(options)
opts.update({'metric': 'haversine'})

self._balltree = BallTree(np.deg2rad(points), **options)

def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
return self._balltree.query(np.deg2rad(points))

def equals(self, other: SklearnGeoBallTreeAdapter) -> bool:
return np.array_equal(self._balltree.data, other._balltree.data)
Loading