Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ classifiers = [
]
urls = {repository = "https://github.com/patrick-kidger/quax" }
dependencies = [
"jax>=0.5.3",
"jaxtyping>=0.3.1",
"equinox>=0.12.1",
"jax>=0.5.3,!=0.7.0,!=0.7.1",
"jaxtyping>=0.3.3",
"equinox>=0.13.2",
"plum-dispatch>=2.2.1",
]

Expand All @@ -36,6 +36,7 @@ dev = [
"beartype>=0.20.2",
"pytest>=8.3.5",
"pytest-env>=1.1.5",
"ipykernel>=6.30.1",
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, why did this sneak in as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess dev on Jupyter.

]
docs = [
"hippogriffe==0.2.0",
Expand Down
46 changes: 26 additions & 20 deletions quax/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import functools as ft
import itertools as it
from collections.abc import Callable, Sequence
from typing import Any, cast, Generic, TypeGuard, TypeVar, Union
from typing import Any, cast, Generic, overload, TypeGuard, TypeVar, Union

import equinox as eqx
import jax
import jax._src
import jax.core as core
import jax._src.core as core
import jax.extend.core as jexc
import jax.extend.linear_util as lu
import jax.numpy as jnp
Expand All @@ -19,6 +18,7 @@
from ._compat import jit_p


T = TypeVar("T")
CT = TypeVar("CT", bound=Callable)

#
Expand Down Expand Up @@ -65,12 +65,12 @@ def _register(rule: CT) -> CT:
existing_rule = _rules[primitive] # pyright: ignore
except KeyError:

def existing_rule():
assert False
def new_rule():
raise NotImplementedError("Abstract primitive") # pragma: no cover
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pragma?


existing_rule.__name__ = f"{primitive}_dispatcher"
existing_rule.__qualname__ = f"{primitive}_dispatcher"
existing_rule = plum.Dispatcher().abstract(existing_rule)
new_rule.__name__ = f"{primitive}_dispatcher"
new_rule.__qualname__ = f"{primitive}_dispatcher"
existing_rule = plum.Dispatcher().abstract(new_rule)

_rules[primitive] = existing_rule
existing_rule.dispatch(rule, precedence=precedence)
Expand All @@ -93,10 +93,10 @@ def __init__(self, trace: "_QuaxTrace", value: "Value") -> None:
self.value = value

@property
def aval(self):
def aval(self) -> core.AbstractValue:
return self.value.aval()

def full_lower(self):
def full_lower(self) -> Union[ArrayLike, "_QuaxTracer"]:
if isinstance(self.value, _DenseArrayValue):
return core.full_lower(self.value.array) # pyright: ignore[reportAttributeAccessIssue]
else:
Expand All @@ -110,9 +110,7 @@ def _default_process(
for x in values:
if isinstance(x, Value):
x_default = type(x).default
if x_default is Value.default:
pass
else:
if x_default is not Value.default:
defaults.add(x_default)
elif eqx.is_array_like(x):
# Ignore any unwrapped _DenseArrayValues
Expand Down Expand Up @@ -286,11 +284,15 @@ def _custom_jvp_jvp_wrap(tag, in_treedef, *in_primals_and_tangents):
#


def _wrap_tracer(trace: _QuaxTrace, x):
if _is_value(x):
return _QuaxTracer(trace, x)
else:
return x
# Any -> Any so overloads carry the public types. mypy can’t prove the else
# branch is T (since T may be Value). To type the body, use Union[Value, T]
# + cast(T, x), or constrain T to exclude Value.
@overload
def _wrap_tracer(trace: _QuaxTrace, x: "Value") -> _QuaxTracer: ...
@overload
def _wrap_tracer(trace: _QuaxTrace, x: T) -> T: ...
def _wrap_tracer(trace: _QuaxTrace, x: Any) -> Any:
return _QuaxTracer(trace, x) if _is_value(x) else x


def _unwrap_tracer(trace, x):
Expand Down Expand Up @@ -332,9 +334,13 @@ def __call__(self, *args, **kwargs):
out = jtu.tree_map(ft.partial(_unwrap_tracer, trace), out)
return out

def __get__(self, instance: object | None, owner: Any):
def __get__(
self, instance: object | None, owner: Any
) -> Union["_Quaxify[CT]", eqx.Partial["_Quaxify[CT]"]]:
# Getting from a class
if instance is None:
return self
# Getting from an instance
return eqx.Partial(self, instance)


Expand Down Expand Up @@ -487,7 +493,7 @@ def materialise(self) -> Any:
"""


def _is_value(x) -> TypeGuard[Value]:
def _is_value(x: object) -> TypeGuard[Value]:
return isinstance(x, Value)


Expand Down
4 changes: 3 additions & 1 deletion quax/examples/zero/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def _(value: Zero, *, broadcast_dimensions, shape, sharding=None) -> Zero:


@quax.register(lax.convert_element_type_p)
def _(value: Zero, *, new_dtype, weak_type, sharding=None) -> Zero:
def convert_element_type_zero(
value: Zero, *, new_dtype, weak_type, sharding=None
) -> Zero:
# sharding was added around JAX 0.4.31, it seems.
del weak_type, sharding
return Zero(value.shape, new_dtype)
Expand Down
7 changes: 2 additions & 5 deletions tests/myarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,11 +368,8 @@ def conv_general_dilated_p(


@register(lax.convert_element_type_p)
def convert_element_type_p(operand: MyArray, **kw: Any) -> MyArray:
return replace(
operand,
array=lax.convert_element_type_p.bind(operand.array, **kw),
)
def convert_element_type_myarray(operand: MyArray, **kw: Any) -> MyArray:
return replace(operand, array=lax.convert_element_type_p.bind(operand.array, **kw))


# ==============================================================================
Expand Down