diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 284a0ae..2e7b50e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.3.0 + rev: v5.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -8,18 +8,13 @@ repos: - id: check-yaml - id: double-quote-string-fixer - - repo: https://github.com/ambv/black - rev: 20.8b1 + - repo: https://github.com/psf/black + rev: 25.1.0 hooks: - id: black args: ["--line-length", "100", "--skip-string-normalization"] - - repo: https://gitlab.com/PyCQA/flake8 - rev: 3.8.4 - hooks: - - id: flake8 - - repo: https://github.com/PyCQA/isort - rev: 5.6.4 + rev: 6.0.1 hooks: - id: isort diff --git a/src/xoak/__init__.py b/src/xoak/__init__.py index a91b358..48ff239 100644 --- a/src/xoak/__init__.py +++ b/src/xoak/__init__.py @@ -1,10 +1,26 @@ from pkg_resources import DistributionNotFound, get_distribution -from .accessor import XoakAccessor -from .index import IndexAdapter, IndexRegistry +from xoak.accessor import XoakAccessor +from xoak.index import IndexAdapter, IndexRegistry +from xoak.tree_adapters import ( + S2PointTreeAdapter, + SklearnBallTreeAdapter, + SklearnGeoBallTreeAdapter, + SklearnKDTreeAdapter, +) try: __version__ = get_distribution(__name__).version except DistributionNotFound: # pragma: no cover # package is not installed pass + +__all__ = [ + 'IndexAdapter', + 'IndexRegistry', + 'SklearnBallTreeAdapter', + 'SklearnGeoBallTreeAdapter', + 'SklearnKDTreeAdapter', + 'S2PointTreeAdapter', + 'XoakAccessor', +] diff --git a/src/xoak/tree_adapters.py b/src/xoak/tree_adapters.py new file mode 100644 index 0000000..293fa66 --- /dev/null +++ b/src/xoak/tree_adapters.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any + +import numpy as np + +try: + from xarray.indexes.nd_point_index import TreeAdapter # type: ignore +except ImportError: + + class TreeAdapter: ... + + +if TYPE_CHECKING: + import pys2index + import sklearn.neighbors + + +class S2PointTreeAdapter(TreeAdapter): + """:py:class:`pys2index.S2PointIndex` adapter for :py:class:`~xarray.indexes.NDPointIndex`.""" + + _s2point_index: pys2index.S2PointIndex + + def __init__(self, points: np.ndarray, options: Mapping[str, Any]): + from pys2index import S2PointIndex + + self._s2point_index = S2PointIndex(points) + + def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + return self._s2point_index.query(points) + + def equals(self, other: S2PointTreeAdapter) -> bool: + return np.array_equal( + self._s2point_index.get_cell_ids(), other._s2point_index.get_cell_ids() + ) + + +class SklearnKDTreeAdapter(TreeAdapter): + """:py:class:`sklearn.neighbors.KDTree` adapter for :py:class:`~xarray.indexes.NDPointIndex`.""" + + _kdtree: sklearn.neighbors.KDTree + + def __init__(self, points: np.ndarray, options: Mapping[str, Any]): + from sklearn.neighbors import KDTree + + self._kdtree = KDTree(points, **options) + + def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + return self._kdtree.query(points) + + def equals(self, other: SklearnKDTreeAdapter) -> bool: + return np.array_equal(self._kdtree.data, other._kdtree.data) + + +class SklearnBallTreeAdapter(TreeAdapter): + """:py:class:`sklearn.neighbors.BallTree` adapter for :py:class:`~xarray.indexes.NDPointIndex`.""" + + _balltree: sklearn.neighbors.BallTree + + def __init__(self, points: np.ndarray, options: Mapping[str, Any]): + from sklearn.neighbors import BallTree + + self._balltree = BallTree(points, **options) + + def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + return self._balltree.query(points) + + def equals(self, other: SklearnBallTreeAdapter) -> bool: + return np.array_equal(self._balltree.data, other._balltree.data) + + +class SklearnGeoBallTreeAdapter(TreeAdapter): + """:py:class:`sklearn.neighbors.BallTree` adapter for + :py:class:`~xarray.indexes.NDPointIndex`, using the 'haversine' metric. + + It can be used for indexing a set of latitude / longitude points. + + When building the index, the coordinates must be given in the latitude, + longitude order. + + Latitude and longitude values must be given in degrees for both index and + query points (those values are converted in radians by this adapter). + + """ + + _balltree: sklearn.neighbors.BallTree + + def __init__(self, points: np.ndarray, options: Mapping[str, Any]): + from sklearn.neighbors import BallTree + + opts = dict(options) + opts.update({'metric': 'haversine'}) + + self._balltree = BallTree(np.deg2rad(points), **options) + + def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + return self._balltree.query(np.deg2rad(points)) + + def equals(self, other: SklearnGeoBallTreeAdapter) -> bool: + return np.array_equal(self._balltree.data, other._balltree.data)