1- """Quaxified `jax.scipy` .
1+ """Pre-`quaxify`ed jax and related libraries .
22
3- This module wraps the functions in `jax.lax` with `quax.quaxify`. The wrapping
4- happens dynamically through a module-level ``__dir__`` and ``__getattr__``. The
5- list of available functions is in ``__all__`` and documented in the `jax.lax`
6- library.
7-
8- In addition the following modules are supported:
9-
10- - `quaxed.lax.linalg`
11-
12- The contents of these modules are likewise dynamically wrapped with
13- `quax.quaxify` and their contents is listed in their respective ``__all__`` and
14- documented in their respective libraries.
15-
16- If a function is missing, please file an Issue.
3+ `quax` is JAX + multiple dispatch + custom array-ish objects. `quaxed` is a
4+ drop-in replacement for many JAX and related libraries that applies
5+ `quax.quaxify` to the original JAX functions, enabling custom array-ish objects
6+ to be used with those functions, not only jax arrays.
177
188"""
19- # pylint: disable=C0415,W0621
20-
21- from __future__ import annotations
229
23- from typing import TYPE_CHECKING
24-
25- from . import _jax , lax , numpy , scipy
26- from ._jax import *
10+ __all__ = [
11+ # Modules
12+ "lax" ,
13+ "numpy" ,
14+ "scipy" ,
15+ "experimental" ,
16+ # Jax functions
17+ "device_put" ,
18+ "grad" ,
19+ "hessian" ,
20+ "jacfwd" ,
21+ "jacrev" ,
22+ "value_and_grad" ,
23+ ]
24+
25+ from typing import TYPE_CHECKING , Any
26+
27+ from . import experimental , lax , numpy , scipy
28+ from ._jax import device_put , grad , hessian , jacfwd , jacrev , value_and_grad
2729from ._setup import JAX_VERSION
2830from ._version import version as __version__ # noqa: F401
2931
30- __all__ = ["lax" , "numpy" , "scipy" ]
31- __all__ += _jax .__all__
32-
3332if JAX_VERSION < (0 , 4 , 32 ):
3433 from . import array_api
3534
3635 __all__ += ["array_api" ]
3736
3837
39- if TYPE_CHECKING :
40- from typing import Any
41-
42-
4338def __getattr__ (name : str ) -> Any : # TODO: fuller annotation
4439 """Forward all other attribute accesses to Quaxified JAX."""
4540 import sys
@@ -59,4 +54,4 @@ def __getattr__(name: str) -> Any: # TODO: fuller annotation
5954
6055
6156# Clean up the namespace
62- del TYPE_CHECKING
57+ del TYPE_CHECKING , Any
0 commit comments