Skip to content

Commit f70fa86

Browse files
authored
Subclasses for Arrayish Objects (#120)
* ➕ dep-add(optype): for precise type definitions * ✨ feat(arrayish): mixins for dunder methods, to support lax and numpy operations on arrayish objects * 📝 docs(arrayish): add module to docs Signed-off-by: Nathaniel Starkman <[email protected]>
1 parent 7d8f599 commit f70fa86

File tree

15 files changed

+3199
-34
lines changed

15 files changed

+3199
-34
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ repos:
8181
- id: mypy
8282
files: src
8383
additional_dependencies:
84-
- pytest
84+
- optype
8585
exclude: |
8686
(?x)^(
8787
src/quaxed/lax/__init__.py|
@@ -111,4 +111,4 @@ repos:
111111
name: Disallow improper capitalization
112112
language: pygrep
113113
entry: PyBind|Numpy|Cmake|CCache|Github|PyTest
114-
exclude: .pre-commit-config.yaml
114+
exclude: .pre-commit-config.yaml|src/quaxed/experimental/arrayish.py|src/quaxed/experimental/_arrayish

docs/api/arrayish.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# quaxed.experimental.arrayish
2+
3+
::: quaxed.experimental.arrayish

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,4 @@ nav:
102102
- "api/scipy.md"
103103
- "api/array_api.md"
104104
- "api/lax.md"
105+
- "api/arrayish.md"

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@
2626
"jax>=0.4.3",
2727
"jaxlib>=0.4.3",
2828
"jaxtyping>=0.2.34",
29+
"optype>=0.8.0",
2930
"plum-dispatch>=2.5.2",
3031
"quax>=0.0.5",
31-
]
32+
]
3233

3334
[project.urls]
3435
Homepage = "https://github.com/GalacticDynamics/quaxed"

src/quaxed/__init__.py

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,40 @@
1-
"""Quaxified `jax.scipy`.
1+
"""Pre-`quaxify`ed jax and related libraries.
22
3-
This module wraps the functions in `jax.lax` with `quax.quaxify`. The wrapping
4-
happens dynamically through a module-level ``__dir__`` and ``__getattr__``. The
5-
list of available functions is in ``__all__`` and documented in the `jax.lax`
6-
library.
7-
8-
In addition the following modules are supported:
9-
10-
- `quaxed.lax.linalg`
11-
12-
The contents of these modules are likewise dynamically wrapped with
13-
`quax.quaxify` and their contents is listed in their respective ``__all__`` and
14-
documented in their respective libraries.
15-
16-
If a function is missing, please file an Issue.
3+
`quax` is JAX + multiple dispatch + custom array-ish objects. `quaxed` is a
4+
drop-in replacement for many JAX and related libraries that applies
5+
`quax.quaxify` to the original JAX functions, enabling custom array-ish objects
6+
to be used with those functions, not only jax arrays.
177
188
"""
19-
# pylint: disable=C0415,W0621
20-
21-
from __future__ import annotations
229

23-
from typing import TYPE_CHECKING
24-
25-
from . import _jax, lax, numpy, scipy
26-
from ._jax import *
10+
__all__ = [
11+
# Modules
12+
"lax",
13+
"numpy",
14+
"scipy",
15+
"experimental",
16+
# Jax functions
17+
"device_put",
18+
"grad",
19+
"hessian",
20+
"jacfwd",
21+
"jacrev",
22+
"value_and_grad",
23+
]
24+
25+
from typing import TYPE_CHECKING, Any
26+
27+
from . import experimental, lax, numpy, scipy
28+
from ._jax import device_put, grad, hessian, jacfwd, jacrev, value_and_grad
2729
from ._setup import JAX_VERSION
2830
from ._version import version as __version__ # noqa: F401
2931

30-
__all__ = ["lax", "numpy", "scipy"]
31-
__all__ += _jax.__all__
32-
3332
if JAX_VERSION < (0, 4, 32):
3433
from . import array_api
3534

3635
__all__ += ["array_api"]
3736

3837

39-
if TYPE_CHECKING:
40-
from typing import Any
41-
42-
4338
def __getattr__(name: str) -> Any: # TODO: fuller annotation
4439
"""Forward all other attribute accesses to Quaxified JAX."""
4540
import sys
@@ -59,4 +54,4 @@ def __getattr__(name: str) -> Any: # TODO: fuller annotation
5954

6055

6156
# Clean up the namespace
62-
del TYPE_CHECKING
57+
del TYPE_CHECKING, Any
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Experimental modules."""
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
"""Arrayish."""
2+
3+
__all__: list[str] = []

0 commit comments

Comments
 (0)