-
-
Notifications
You must be signed in to change notification settings - Fork 6
fix: jax v0.7.2+ LiteralArray / TypedNdArray #74
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
patrick-kidger
merged 15 commits into
patrick-kidger:main
from
nstarman:deps-bump-jaxtyping
Oct 12, 2025
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
3ecbfb3
deps: bump jaxtyping v0.3.3+
nstarman a3d4bd2
deps: bump equinox to v0.12.2+
nstarman f0b751b
types: annotate core
nstarman aff0973
style: remove extraneous import
nstarman 3ab4b05
refactor: consolidate value default check
nstarman 726b583
refactor: new rule definition
nstarman b892033
deps: add ipykernel to dev
nstarman 09ab2b5
refactor: use `_src.core` import
nstarman c4dd13e
types: annotate full_lower
nstarman 7d1376e
types: annotate _is_value
nstarman a5163f9
types: annotate tracer
nstarman 8699aeb
style: improve func name for error tracing
nstarman b9c4cc8
types: annotate `__get__`
nstarman d32f368
deps: bump equinox
nstarman c509a6a
deps: protect jax
nstarman File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,12 +2,11 @@ | |
| import functools as ft | ||
| import itertools as it | ||
| from collections.abc import Callable, Sequence | ||
| from typing import Any, cast, Generic, TypeGuard, TypeVar, Union | ||
| from typing import Any, cast, Generic, overload, TypeGuard, TypeVar, Union | ||
|
|
||
| import equinox as eqx | ||
| import jax | ||
| import jax._src | ||
| import jax.core as core | ||
| import jax._src.core as core | ||
| import jax.extend.core as jexc | ||
| import jax.extend.linear_util as lu | ||
| import jax.numpy as jnp | ||
|
|
@@ -19,6 +18,7 @@ | |
| from ._compat import jit_p | ||
|
|
||
|
|
||
| T = TypeVar("T") | ||
| CT = TypeVar("CT", bound=Callable) | ||
|
|
||
| # | ||
|
|
@@ -65,12 +65,12 @@ def _register(rule: CT) -> CT: | |
| existing_rule = _rules[primitive] # pyright: ignore | ||
| except KeyError: | ||
|
|
||
| def existing_rule(): | ||
| assert False | ||
| def new_rule(): | ||
| raise NotImplementedError("Abstract primitive") # pragma: no cover | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| existing_rule.__name__ = f"{primitive}_dispatcher" | ||
| existing_rule.__qualname__ = f"{primitive}_dispatcher" | ||
| existing_rule = plum.Dispatcher().abstract(existing_rule) | ||
| new_rule.__name__ = f"{primitive}_dispatcher" | ||
| new_rule.__qualname__ = f"{primitive}_dispatcher" | ||
| existing_rule = plum.Dispatcher().abstract(new_rule) | ||
|
|
||
| _rules[primitive] = existing_rule | ||
| existing_rule.dispatch(rule, precedence=precedence) | ||
|
|
@@ -93,10 +93,10 @@ def __init__(self, trace: "_QuaxTrace", value: "Value") -> None: | |
| self.value = value | ||
|
|
||
| @property | ||
| def aval(self): | ||
| def aval(self) -> core.AbstractValue: | ||
| return self.value.aval() | ||
|
|
||
| def full_lower(self): | ||
| def full_lower(self) -> Union[ArrayLike, "_QuaxTracer"]: | ||
| if isinstance(self.value, _DenseArrayValue): | ||
| return core.full_lower(self.value.array) # pyright: ignore[reportAttributeAccessIssue] | ||
| else: | ||
|
|
@@ -110,9 +110,7 @@ def _default_process( | |
| for x in values: | ||
| if isinstance(x, Value): | ||
| x_default = type(x).default | ||
| if x_default is Value.default: | ||
| pass | ||
| else: | ||
| if x_default is not Value.default: | ||
| defaults.add(x_default) | ||
| elif eqx.is_array_like(x): | ||
| # Ignore any unwrapped _DenseArrayValues | ||
|
|
@@ -286,11 +284,15 @@ def _custom_jvp_jvp_wrap(tag, in_treedef, *in_primals_and_tangents): | |
| # | ||
|
|
||
|
|
||
| def _wrap_tracer(trace: _QuaxTrace, x): | ||
| if _is_value(x): | ||
| return _QuaxTracer(trace, x) | ||
| else: | ||
| return x | ||
| # Any -> Any so overloads carry the public types. mypy can’t prove the else | ||
| # branch is T (since T may be Value). To type the body, use Union[Value, T] | ||
| # + cast(T, x), or constrain T to exclude Value. | ||
| @overload | ||
| def _wrap_tracer(trace: _QuaxTrace, x: "Value") -> _QuaxTracer: ... | ||
| @overload | ||
| def _wrap_tracer(trace: _QuaxTrace, x: T) -> T: ... | ||
| def _wrap_tracer(trace: _QuaxTrace, x: Any) -> Any: | ||
| return _QuaxTracer(trace, x) if _is_value(x) else x | ||
|
|
||
|
|
||
| def _unwrap_tracer(trace, x): | ||
|
|
@@ -332,9 +334,13 @@ def __call__(self, *args, **kwargs): | |
| out = jtu.tree_map(ft.partial(_unwrap_tracer, trace), out) | ||
| return out | ||
|
|
||
| def __get__(self, instance: object | None, owner: Any): | ||
| def __get__( | ||
| self, instance: object | None, owner: Any | ||
| ) -> Union["_Quaxify[CT]", eqx.Partial["_Quaxify[CT]"]]: | ||
| # Getting from a class | ||
| if instance is None: | ||
| return self | ||
| # Getting from an instance | ||
| return eqx.Partial(self, instance) | ||
|
|
||
|
|
||
|
|
@@ -487,7 +493,7 @@ def materialise(self) -> Any: | |
| """ | ||
|
|
||
|
|
||
| def _is_value(x) -> TypeGuard[Value]: | ||
| def _is_value(x: object) -> TypeGuard[Value]: | ||
| return isinstance(x, Value) | ||
|
|
||
|
|
||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, why did this sneak in as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess dev on Jupyter.