1818 "LaxMulMixin" , "NumpyMulMixin" , # __mul__
1919 "LaxRMulMixin" , "NumpyRMulMixin" , # __rmul__
2020 # ---- matmul -----
21+ "LaxBothMatMulMixin" , "NumpyBothMatMulMixin" ,
2122 "LaxMatMulMixin" , "NumpyMatMulMixin" , # __matmul__
22- # "LaxRMatMulMixin", "NumpyRMatMulMixin", # __rmatmul__
23+ "LaxRMatMulMixin" , "NumpyRMatMulMixin" , # __rmatmul__
2324 # ----- truediv -----
2425 "LaxBothTrueDivMixin" , "NumpyBothTrueDivMixin" ,
2526 "LaxTrueDivMixin" , "NumpyTrueDivMixin" , # __truediv__
@@ -516,6 +517,76 @@ def __matmul__(self, other: T) -> R:
516517 return qnp .matmul (self , other )
517518
518519
520+ # -------------------------------------
521+
522+
523+ class LaxRMatMulMixin (Generic [T , R ]):
524+ """Mixin for ``__rmatmul__`` method using quaxified `jax.lax.matmul`.
525+
526+ Examples
527+ --------
528+ >>> from typing import Any
529+ >>> import jax
530+ >>> import jax.numpy as jnp
531+ >>> from jaxtyping import Array
532+ >>> from quax import ArrayValue
533+
534+ >>> class MyArray(ArrayValue, LaxRMatMulMixin[Any, Array]):
535+ ... value: Array
536+ ... def aval(self): return jax.core.ShapedArray(self.value.shape, self.value.dtype)
537+ ... def materialise(self): return self.value
538+
539+ >>> x = MyArray(jnp.array([[1, 2], [3, 4]]))
540+ >>> y = jnp.array([[5, 6], [7, 8]])
541+ >>> y @ x
542+ Array([[23, 34],
543+ [31, 46]], dtype=int32)
544+
545+ """ # noqa: E501
546+
547+ def __rmatmul__ (self , other : T ) -> R :
548+ return qlax .dot (other , self ) # TODO: is this the right operator?
549+
550+
551+ class NumpyRMatMulMixin (Generic [T , R ]):
552+ """Mixin for ``__rmatmul__`` method using quaxified `jax.numpy.matmul`.
553+
554+ Examples
555+ --------
556+ >>> from typing import Any
557+ >>> import jax
558+ >>> import jax.numpy as jnp
559+ >>> from jaxtyping import Array
560+ >>> from quax import ArrayValue
561+
562+ >>> class MyArray(ArrayValue, NumpyRMatMulMixin[Any, Array]):
563+ ... value: Array
564+ ... def aval(self): return jax.core.ShapedArray(self.value.shape, self.value.dtype)
565+ ... def materialise(self): return self.value
566+
567+ >>> x = MyArray(jnp.array([[1, 2], [3, 4]]))
568+ >>> y = jnp.array([[5, 6], [7, 8]])
569+ >>> y @ x
570+ Array([[23, 34],
571+ [31, 46]], dtype=int32)
572+
573+ """ # noqa: E501
574+
575+ def __rmatmul__ (self , other : T ) -> R :
576+ return qnp .matmul (other , self )
577+
578+
579+ # -------------------------------------
580+
581+
582+ class LaxBothMatMulMixin (LaxMatMulMixin [T , R ], LaxRMatMulMixin [T , R ]):
583+ pass
584+
585+
586+ class NumpyBothMatMulMixin (NumpyMatMulMixin [T , R ], NumpyRMatMulMixin [T , R ]):
587+ pass
588+
589+
519590# ===============================================
520591# Float Division
521592
@@ -1759,7 +1830,7 @@ class LaxMathMixin(
17591830 LaxBothAddMixin [T , R ], # __add__, __radd__
17601831 LaxBothSubMixin [T , R ], # __sub__, __rsub__
17611832 LaxBothMulMixin [T , R ], # __mul__, __rmul__
1762- LaxMatMulMixin [T , R ], # __matmul__
1833+ LaxBothMatMulMixin [T , R ], # __matmul__, __rmatmul__
17631834 LaxBothTrueDivMixin [T , R ], # __truediv__, __rtruediv__
17641835 LaxBothFloorDivMixin [T , R ], # __floordiv__, __rfloordiv__
17651836 LaxBothModMixin [T , R ], # __mod__, __rmod__
@@ -1773,7 +1844,7 @@ class NumpyMathMixin(
17731844 NumpyBothAddMixin [T , R ], # __add__, __radd__
17741845 NumpyBothSubMixin [T , R ], # __sub__, __rsub__
17751846 NumpyBothMulMixin [T , R ], # __mul__, __rmul__
1776- NumpyMatMulMixin [T , R ], # __matmul__
1847+ NumpyBothMatMulMixin [T , R ], # __matmul__, __rmatmul__
17771848 NumpyBothTrueDivMixin [T , R ], # __truediv__, __rtruediv__
17781849 NumpyBothFloorDivMixin [T , R ], # __floordiv__, __rfloordiv__
17791850 NumpyBothModMixin [T , R ], # __mod__, __rmod__
0 commit comments