Skip to content

Commit ccd39b8

Browse files
authored
✨ feat(arrayish): rmatmul (#122)
Signed-off-by: Nathaniel Starkman <[email protected]>
1 parent eb24868 commit ccd39b8

File tree

2 files changed

+80
-4
lines changed

2 files changed

+80
-4
lines changed

src/quaxed/experimental/_arrayish/binary.py

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
"LaxMulMixin", "NumpyMulMixin", # __mul__
1919
"LaxRMulMixin", "NumpyRMulMixin", # __rmul__
2020
# ---- matmul -----
21+
"LaxBothMatMulMixin", "NumpyBothMatMulMixin",
2122
"LaxMatMulMixin", "NumpyMatMulMixin", # __matmul__
22-
# "LaxRMatMulMixin", "NumpyRMatMulMixin", # __rmatmul__
23+
"LaxRMatMulMixin", "NumpyRMatMulMixin", # __rmatmul__
2324
# ----- truediv -----
2425
"LaxBothTrueDivMixin", "NumpyBothTrueDivMixin",
2526
"LaxTrueDivMixin", "NumpyTrueDivMixin", # __truediv__
@@ -516,6 +517,76 @@ def __matmul__(self, other: T) -> R:
516517
return qnp.matmul(self, other)
517518

518519

520+
# -------------------------------------
521+
522+
523+
class LaxRMatMulMixin(Generic[T, R]):
524+
"""Mixin for ``__rmatmul__`` method using quaxified `jax.lax.matmul`.
525+
526+
Examples
527+
--------
528+
>>> from typing import Any
529+
>>> import jax
530+
>>> import jax.numpy as jnp
531+
>>> from jaxtyping import Array
532+
>>> from quax import ArrayValue
533+
534+
>>> class MyArray(ArrayValue, LaxRMatMulMixin[Any, Array]):
535+
... value: Array
536+
... def aval(self): return jax.core.ShapedArray(self.value.shape, self.value.dtype)
537+
... def materialise(self): return self.value
538+
539+
>>> x = MyArray(jnp.array([[1, 2], [3, 4]]))
540+
>>> y = jnp.array([[5, 6], [7, 8]])
541+
>>> y @ x
542+
Array([[23, 34],
543+
[31, 46]], dtype=int32)
544+
545+
""" # noqa: E501
546+
547+
def __rmatmul__(self, other: T) -> R:
548+
return qlax.dot(other, self) # TODO: is this the right operator?
549+
550+
551+
class NumpyRMatMulMixin(Generic[T, R]):
552+
"""Mixin for ``__rmatmul__`` method using quaxified `jax.numpy.matmul`.
553+
554+
Examples
555+
--------
556+
>>> from typing import Any
557+
>>> import jax
558+
>>> import jax.numpy as jnp
559+
>>> from jaxtyping import Array
560+
>>> from quax import ArrayValue
561+
562+
>>> class MyArray(ArrayValue, NumpyRMatMulMixin[Any, Array]):
563+
... value: Array
564+
... def aval(self): return jax.core.ShapedArray(self.value.shape, self.value.dtype)
565+
... def materialise(self): return self.value
566+
567+
>>> x = MyArray(jnp.array([[1, 2], [3, 4]]))
568+
>>> y = jnp.array([[5, 6], [7, 8]])
569+
>>> y @ x
570+
Array([[23, 34],
571+
[31, 46]], dtype=int32)
572+
573+
""" # noqa: E501
574+
575+
def __rmatmul__(self, other: T) -> R:
576+
return qnp.matmul(other, self)
577+
578+
579+
# -------------------------------------
580+
581+
582+
class LaxBothMatMulMixin(LaxMatMulMixin[T, R], LaxRMatMulMixin[T, R]):
583+
pass
584+
585+
586+
class NumpyBothMatMulMixin(NumpyMatMulMixin[T, R], NumpyRMatMulMixin[T, R]):
587+
pass
588+
589+
519590
# ===============================================
520591
# Float Division
521592

@@ -1759,7 +1830,7 @@ class LaxMathMixin(
17591830
LaxBothAddMixin[T, R], # __add__, __radd__
17601831
LaxBothSubMixin[T, R], # __sub__, __rsub__
17611832
LaxBothMulMixin[T, R], # __mul__, __rmul__
1762-
LaxMatMulMixin[T, R], # __matmul__
1833+
LaxBothMatMulMixin[T, R], # __matmul__, __rmatmul__
17631834
LaxBothTrueDivMixin[T, R], # __truediv__, __rtruediv__
17641835
LaxBothFloorDivMixin[T, R], # __floordiv__, __rfloordiv__
17651836
LaxBothModMixin[T, R], # __mod__, __rmod__
@@ -1773,7 +1844,7 @@ class NumpyMathMixin(
17731844
NumpyBothAddMixin[T, R], # __add__, __radd__
17741845
NumpyBothSubMixin[T, R], # __sub__, __rsub__
17751846
NumpyBothMulMixin[T, R], # __mul__, __rmul__
1776-
NumpyMatMulMixin[T, R], # __matmul__
1847+
NumpyBothMatMulMixin[T, R], # __matmul__, __rmatmul__
17771848
NumpyBothTrueDivMixin[T, R], # __truediv__, __rtruediv__
17781849
NumpyBothFloorDivMixin[T, R], # __floordiv__, __rfloordiv__
17791850
NumpyBothModMixin[T, R], # __mod__, __rmod__

src/quaxed/experimental/arrayish.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@
3333
"LaxMulMixin", "NumpyMulMixin", # __mul__
3434
"LaxRMulMixin", "NumpyRMulMixin", # __rmul__
3535
# ---- matmul -----
36+
"LaxBothMatMulMixin", "NumpyBothMatMulMixin",
3637
"LaxMatMulMixin", "NumpyMatMulMixin", # __matmul__
37-
# "LaxRMatMulMixin", "NumpyRMatMulMixin", # __rmatmul__
38+
"LaxRMatMulMixin", "NumpyRMatMulMixin", # __rmatmul__
3839
# ----- truediv -----
3940
"LaxBothTrueDivMixin", "NumpyBothTrueDivMixin",
4041
"LaxTrueDivMixin", "NumpyTrueDivMixin", # __truediv__
@@ -109,6 +110,7 @@
109110
LaxBothAndMixin,
110111
LaxBothFloorDivMixin,
111112
LaxBothLShiftMixin,
113+
LaxBothMatMulMixin,
112114
LaxBothModMixin,
113115
LaxBothMulMixin,
114116
LaxBothOrMixin,
@@ -129,6 +131,7 @@
129131
LaxRAndMixin,
130132
LaxRFloorDivMixin,
131133
LaxRLShiftMixin,
134+
LaxRMatMulMixin,
132135
LaxRModMixin,
133136
LaxRMulMixin,
134137
LaxROrMixin,
@@ -150,6 +153,7 @@
150153
NumpyBothDivModMixin,
151154
NumpyBothFloorDivMixin,
152155
NumpyBothLShiftMixin,
156+
NumpyBothMatMulMixin,
153157
NumpyBothModMixin,
154158
NumpyBothMulMixin,
155159
NumpyBothOrMixin,
@@ -172,6 +176,7 @@
172176
NumpyRDivModMixin,
173177
NumpyRFloorDivMixin,
174178
NumpyRLShiftMixin,
179+
NumpyRMatMulMixin,
175180
NumpyRModMixin,
176181
NumpyRMulMixin,
177182
NumpyROrMixin,

0 commit comments

Comments
 (0)