Skip to content

Commit 89b4e78

Browse files
authored
🐛 fix: vectorize, cond (#141)
* ✅ test: mark_todo * ✅ test: consolidate xconv * ✅ test: restructure lax test * ⬆️ dep-bump(pytest-github-actions-annotate-failures): v0.3.0+ * ⬆️ dep-bump(constraints): update constraint * 🐛 fix: cond_quax * 🐛 fix: vectorize Signed-off-by: Nathaniel Starkman <[email protected]>
1 parent b12f731 commit 89b4e78

File tree

6 files changed

+588
-486
lines changed

6 files changed

+588
-486
lines changed

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
"pytest >= 8.3",
6969
"pytest-cov >= 3",
7070
"pytest-env>=1.1.5",
71-
"pytest-github-actions-annotate-failures", # only applies to GH Actions
71+
"pytest-github-actions-annotate-failures>=0.3.0", # only applies to GH Actions
7272
"sybil >= 7.1.0",
7373
]
7474

@@ -200,6 +200,9 @@ constraint-dependencies = [
200200
"decorator>=5.1.1",
201201
"matplotlib>=3.7.1",
202202
"matplotlib-inline>=0.1.6",
203+
"opt-einsum>=3.2.1",
204+
"pickleshare>=0.7.5",
205+
"psutil>=5.9.0",
203206
"pyparsing>=3.0.0",
204207
"pyzmq>=25.0",
205208
]

src/quaxed/lax/_patch.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010

1111
from typing import Any
1212

13+
import jax
14+
import jax.extend.core as jexc
15+
import jax.tree_util as jtu
1316
import quax
1417
from jax import lax
1518
from jaxtyping import Array, ArrayLike
@@ -29,3 +32,49 @@ def regularized_incomplete_beta_p(
2932
def scan_p(*args: ArrayLike, **kw: Any) -> Array:
3033
"""Patched implementation of lax.map."""
3134
return lax.scan_p.bind(*args, **kw)
35+
36+
37+
# =========================================================
38+
# https://github.com/patrick-kidger/quax/pull/64
39+
40+
_sentinel = object()
41+
42+
43+
@quax.register(lax.cond_p)  # type: ignore[misc]
def cond_quax(
    index: ArrayLike,
    *args: quax.ArrayValue | ArrayLike,
    branches: tuple[Any, ...],
    linear: Any = _sentinel,
    branches_platforms: Any = _sentinel,
) -> quax.ArrayValue:
    """Quax rule for ``lax.cond_p`` that re-traces every branch under quaxify.

    The operands are flattened to leaves. Each branch jaxpr is then turned
    into a new jaxpr by tracing a function that rebuilds the operands from
    those leaves, runs the quaxified branch, and flattens its result. The
    primitive is finally bound on the leaves with the re-traced branches, and
    the flat output is rebuilt using the output tree of the first branch.

    Parameters
    ----------
    index
        Forwarded unchanged as the first positional argument of
        ``cond_p.bind``.
    *args
        Branch operands; flattened with ``jax.tree_util.tree_flatten``.
    branches
        The branch jaxprs; each one is wrapped with
        ``jax.extend.core.jaxpr_as_fun`` and ``quax.quaxify``.
    linear, branches_platforms
        Forwarded to ``cond_p.bind`` only when the caller supplied them
        (i.e. they are not the module-level ``_sentinel``).

    Raises
    ------
    TypeError
        If the branches do not all produce the same output tree structure.

    """
    leaves, operands_treedef = jtu.tree_flatten(args)

    # Filled in as a side effect of tracing: one output treedef per trace.
    result_treedefs: list[Any] = []

    def _retrace(branch_jaxpr: Any) -> Any:
        """Trace one branch over the flat leaves and return the new jaxpr."""

        def run_on_leaves(flat_operands: Any) -> Any:
            operands = jtu.tree_unflatten(operands_treedef, flat_operands)
            branch_fn = quax.quaxify(jexc.jaxpr_as_fun(branch_jaxpr))
            result = branch_fn(*operands)
            flat_result, result_treedef = jtu.tree_flatten(result)
            result_treedefs.append(result_treedef)
            return flat_result

        return jax.make_jaxpr(run_on_leaves)(leaves)

    retraced_branches = tuple(_retrace(branch_jaxpr) for branch_jaxpr in branches)

    # Every branch must agree with the first one on output structure.
    for treedef in result_treedefs[1:]:
        if treedef != result_treedefs[0]:
            msg = "all branches output must have the same pytree."
            raise TypeError(msg)

    # Only pass through the optional parameters that were actually given.
    optional_params = {"linear": linear, "branches_platforms": branches_platforms}
    forwarded = {
        key: value for key, value in optional_params.items() if value is not _sentinel
    }

    flat_out = jax.lax.cond_p.bind(
        index, *leaves, branches=retraced_branches, **forwarded
    )
    return jtu.tree_unflatten(result_treedefs[0], flat_out)

src/quaxed/numpy/_higher_order.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
__all__ = ["vectorize"]
44

55
import functools
6+
import warnings
67
from collections.abc import Callable, Collection
78
from typing import Any, TypeVar
89

910
import equinox as eqx
1011
import jax
12+
from jax._src import config
1113
from jax._src.numpy.vectorize import (
1214
_apply_excluded,
1315
_check_output_dims,
@@ -26,7 +28,7 @@ def expand_dims(a: T, axis: int | tuple[int, ...]) -> T:
2628
return eqx.combine(expanded_dynamic, static)
2729

2830

29-
def vectorize( # noqa: C901
31+
def vectorize( # noqa: C901, PLR0915
3032
pyfunc: Callable[..., Any],
3133
*,
3234
excluded: Collection[int | str] = frozenset(),
@@ -119,7 +121,7 @@ def vectorize( # noqa: C901
119121
raise ValueError(msg)
120122

121123
@functools.wraps(pyfunc)
122-
def wrapped(*args: Any, **kwargs: Any) -> Any:
124+
def wrapped(*args: Any, **kwargs: Any) -> Any: # noqa: C901, PLR0912, PLR0915
123125
error_context = (
124126
f"on vectorized function with excluded={excluded!r} and "
125127
f"signature={signature!r}"
@@ -148,9 +150,32 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
148150
args, input_core_dims, error_context
149151
)
150152

151-
checked_func = _check_output_dims(
152-
excluded_func, dim_sizes, output_core_dims, error_context
153-
)
153+
if output_core_dims is None:
154+
checked_func = excluded_func
155+
else:
156+
checked_func = _check_output_dims(
157+
excluded_func, dim_sizes, output_core_dims, error_context
158+
)
159+
160+
# Detect implicit rank promotion:
161+
if config.numpy_rank_promotion.value != "allow":
162+
ranks = [
163+
arg.ndim - len(core_dims)
164+
for arg, core_dims in zip(args, input_core_dims, strict=False)
165+
if arg.ndim != 0
166+
]
167+
if len(set(ranks)) > 1:
168+
msg = (
169+
f"operands with shapes {[arg.shape for arg in args]} require rank"
170+
f" promotion for jnp.vectorize function with signature {signature}."
171+
" Set the jax_numpy_rank_promotion config option to 'allow' to"
172+
" disable this message; for more information, see"
173+
" https://docs.jax.dev/en/latest/rank_promotion_warning.html."
174+
)
175+
if config.numpy_rank_promotion.value == "warn":
176+
warnings.warn(msg, stacklevel=1)
177+
elif config.numpy_rank_promotion.value == "raise":
178+
raise ValueError(msg)
154179

155180
# Rather than broadcasting all arguments to full broadcast shapes, prefer
156181
# expanding dimensions using vmap. By pushing broadcasting

tests/test_lax/test_jax.py

Lines changed: 48 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
xround = jnp.array([[1.1, 2.2], [3.3, 4.4]])
1919
conv_kernel = jnp.array([[[[1.0, 0.0], [0.0, -1.0]]]], dtype=float)
2020
xcomp = jnp.array([[5, 2], [7, 2]], dtype=float)
21+
xconv = jnp.arange(1, 17, dtype=float).reshape((1, 1, 4, 4))
2122

2223

2324
@pytest.mark.parametrize(
@@ -27,7 +28,7 @@
2728
("acos", (xtrig,), {}),
2829
("acosh", (x,), {}),
2930
("add", (x, y), {}),
30-
pytest.param("after_all", (), {}, marks=pytest.mark.skip),
31+
pytest.param("after_all", (), {}, marks=mark_todo),
3132
("approx_max_k", (x, 2), {}),
3233
("approx_min_k", (x, 2), {}),
3334
("argmax", (x,), {"axis": 0, "index_dtype": int}),
@@ -57,19 +58,15 @@
5758
("broadcast_in_dim", (x, (1, 1, 2, 2), (2, 3)), {}),
5859
("broadcast_shapes", ((2, 3), (1, 3)), {}),
5960
("broadcast_to_rank", (x,), {"rank": 3}),
60-
pytest.param("broadcasted_iota", (), {}, marks=pytest.mark.skip),
61+
pytest.param("broadcasted_iota", (), {}, marks=mark_todo),
6162
("cbrt", (x,), {}),
6263
("ceil", (xround,), {}),
6364
("clamp", (2.0, x, 3.0), {}),
6465
("clz", (xbit,), {}),
6566
("collapse", (x, 1), {}),
6667
("concatenate", ((x, y), 0), {}),
6768
("conj", (xcomplex,), {}),
68-
(
69-
"conv",
70-
(jnp.arange(1, 17, dtype=float).reshape((1, 1, 4, 4)), conv_kernel),
71-
{"window_strides": (1, 1), "padding": "SAME"},
72-
),
69+
("conv", (xconv, conv_kernel), {"window_strides": (1, 1), "padding": "SAME"}),
7370
("convert_element_type", (x, jnp.int32), {}),
7471
(
7572
"conv_dimension_numbers",
@@ -78,25 +75,25 @@
7875
),
7976
(
8077
"conv_general_dilated",
81-
(jnp.arange(1, 17, dtype=float).reshape((1, 1, 4, 4)), conv_kernel),
78+
(xconv, conv_kernel),
8279
{"window_strides": (1, 1), "padding": "SAME"},
8380
),
84-
pytest.param("conv_general_dilated_local", (), {}, marks=pytest.mark.skip),
81+
pytest.param("conv_general_dilated_local", (), {}, marks=mark_todo),
8582
(
8683
"conv_general_dilated_patches",
87-
(jnp.arange(1, 17, dtype=float).reshape((1, 1, 4, 4)),),
84+
(xconv,),
8885
{"filter_shape": (2, 2), "window_strides": (1, 1), "padding": "VALID"},
8986
),
9087
(
9188
"conv_transpose",
92-
(jnp.arange(1, 17, dtype=float).reshape((1, 1, 4, 4)), conv_kernel),
89+
(xconv, conv_kernel),
9390
{
9491
"strides": (2, 2),
9592
"padding": "SAME",
9693
"dimension_numbers": ("NCHW", "OIHW", "NCHW"),
9794
},
9895
),
99-
pytest.param("conv_with_general_padding", (), {}, marks=pytest.mark.skip),
96+
pytest.param("conv_with_general_padding", (), {}, marks=mark_todo),
10097
("cos", (x,), {}),
10198
("cosh", (x,), {}),
10299
("cumlogsumexp", (x,), {"axis": 0}),
@@ -107,11 +104,11 @@
107104
("digamma", (xtrig,), {}),
108105
("div", (x, y), {}),
109106
("dot", (x, y), {}),
110-
pytest.param("dot_general", (), {}, marks=pytest.mark.skip),
111-
pytest.param("dynamic_index_in_dim", (), {}, marks=pytest.mark.skip),
107+
pytest.param("dot_general", (), {}, marks=mark_todo),
108+
pytest.param("dynamic_index_in_dim", (), {}, marks=mark_todo),
112109
("dynamic_slice", (x, (0, 0), (2, 2)), {}),
113-
pytest.param("dynamic_slice_in_dim", (), {}, marks=pytest.mark.skip),
114-
pytest.param("dynamic_update_index_in_dim", (), {}, marks=pytest.mark.skip),
110+
pytest.param("dynamic_slice_in_dim", (), {}, marks=mark_todo),
111+
pytest.param("dynamic_update_index_in_dim", (), {}, marks=mark_todo),
115112
("dynamic_update_slice", (x, y, (0, 0)), {}),
116113
("dynamic_update_slice_in_dim", (x, y, 0, 0), {}),
117114
("eq", (x, x), {}),
@@ -126,16 +123,16 @@
126123
("floor", (xround,), {}),
127124
("full", ((2, 2), 1.0), {}),
128125
("full_like", (x, 1.0), {}),
129-
pytest.param("gather", (), {}, marks=pytest.mark.skip),
126+
pytest.param("gather", (), {}, marks=mark_todo),
130127
("ge", (x, xcomp), {}),
131128
("gt", (x, xcomp), {}),
132129
("igamma", (1.0, xtrig), {}),
133130
("igammac", (1.0, xtrig), {}),
134131
("imag", (xcomplex,), {}),
135132
("index_in_dim", (x, 0, 0), {}),
136-
pytest.param("index_take", (), {}, marks=pytest.mark.skip),
133+
pytest.param("index_take", (), {}, marks=mark_todo),
137134
("integer_pow", (x, 2), {}),
138-
pytest.param("iota", (), {}, marks=pytest.mark.skip),
135+
pytest.param("iota", (), {}, marks=mark_todo),
139136
("is_finite", (x,), {}),
140137
("le", (x, xcomp), {}),
141138
("lgamma", (x,), {}),
@@ -149,28 +146,28 @@
149146
("ne", (x, xcomp), {}),
150147
("neg", (x,), {}),
151148
("nextafter", (x, y), {}),
152-
pytest.param("pad", (), {}, marks=pytest.mark.skip),
149+
pytest.param("pad", (), {}, marks=mark_todo),
153150
("polygamma", (1.0, xtrig), {}),
154151
("population_count", (xbit,), {}),
155152
("pow", (x, y), {}),
156153
pytest.param("random_gamma_grad", (1.0, x), {}, marks=mark_todo),
157154
("real", (xcomplex,), {}),
158155
("reciprocal", (x,), {}),
159-
pytest.param("reduce", (), {}, marks=pytest.mark.skip),
160-
pytest.param("reduce_precision", (), {}, marks=pytest.mark.skip),
161-
pytest.param("reduce_window", (), {}, marks=pytest.mark.skip),
156+
pytest.param("reduce", (), {}, marks=mark_todo),
157+
pytest.param("reduce_precision", (), {}, marks=mark_todo),
158+
pytest.param("reduce_window", (), {}, marks=mark_todo),
162159
("rem", (x, y), {}),
163160
("reshape", (x, (1, 4)), {}),
164161
("rev", (x,), {"dimensions": (0,)}),
165-
pytest.param("rng_bit_generator", (), {}, marks=pytest.mark.skip),
162+
pytest.param("rng_bit_generator", (), {}, marks=mark_todo),
166163
("rng_uniform", (0, 1, (2, 3)), {}),
167164
("round", (xround,), {}),
168165
("rsqrt", (x,), {}),
169-
pytest.param("scatter", (), {}, marks=pytest.mark.skip),
170-
pytest.param("scatter_apply", (), {}, marks=pytest.mark.skip),
171-
pytest.param("scatter_max", (), {}, marks=pytest.mark.skip),
172-
pytest.param("scatter_min", (), {}, marks=pytest.mark.skip),
173-
pytest.param("scatter_mul", (), {}, marks=pytest.mark.skip),
166+
pytest.param("scatter", (), {}, marks=mark_todo),
167+
pytest.param("scatter_apply", (), {}, marks=mark_todo),
168+
pytest.param("scatter_max", (), {}, marks=mark_todo),
169+
pytest.param("scatter_min", (), {}, marks=mark_todo),
170+
pytest.param("scatter_mul", (), {}, marks=mark_todo),
174171
("shift_left", (xbit, 1), {}),
175172
("shift_right_arithmetic", (xbit, 1), {}),
176173
("shift_right_logical", (xbit, 1), {}),
@@ -180,7 +177,7 @@
180177
("slice", (x, (0, 0), (2, 2)), {}),
181178
("slice_in_dim", (x, 0, 0, 2), {}),
182179
("sort", (x,), {}),
183-
pytest.param("sort_key_val", (), {}, marks=pytest.mark.skip),
180+
pytest.param("sort_key_val", (), {}, marks=mark_todo),
184181
("sqrt", (x,), {}),
185182
("square", (x,), {}),
186183
("sub", (x, y), {}),
@@ -190,35 +187,31 @@
190187
("transpose", (x, (1, 0)), {}),
191188
("zeros_like_array", (x,), {}),
192189
("zeta", (x, 2.0), {}),
193-
pytest.param("associative_scan", (), {}, marks=pytest.mark.skip),
190+
pytest.param("associative_scan", (), {}, marks=mark_todo),
194191
("cond", (True, lambda: x, lambda: y), {}),
195-
pytest.param("fori_loop", (), {}, marks=pytest.mark.skip),
192+
pytest.param("fori_loop", (), {}, marks=mark_todo),
196193
("map", (lambda x: x + 1, x), {}),
197-
pytest.param("scan", (), {}, marks=pytest.mark.skip),
198-
(
199-
"select",
200-
(jnp.array([[True, False], [True, False]], dtype=bool), x, y),
201-
{},
202-
),
203-
pytest.param("select_n", (), {}, marks=pytest.mark.skip),
204-
pytest.param("switch", (), {}, marks=pytest.mark.skip),
194+
pytest.param("scan", (), {}, marks=mark_todo),
195+
("select", (jnp.array([[True, False], [True, False]], dtype=bool), x, y), {}),
196+
pytest.param("select_n", (), {}, marks=mark_todo),
197+
pytest.param("switch", (), {}, marks=mark_todo),
205198
("while_loop", (lambda x: jnp.all(x < 10), lambda x: x + 1, x), {}),
206199
("stop_gradient", (x,), {}),
207-
pytest.param("custom_linear_solve", (), {}, marks=pytest.mark.skip),
208-
pytest.param("custom_root", (), {}, marks=pytest.mark.skip),
209-
pytest.param("all_gather", (), {}, marks=pytest.mark.skip),
210-
pytest.param("all_to_all", (), {}, marks=pytest.mark.skip),
211-
pytest.param("psum", (), {}, marks=pytest.mark.skip),
212-
pytest.param("psum_scatter", (), {}, marks=pytest.mark.skip),
213-
pytest.param("pmax", (), {}, marks=pytest.mark.skip),
214-
pytest.param("pmin", (), {}, marks=pytest.mark.skip),
215-
pytest.param("pmean", (), {}, marks=pytest.mark.skip),
216-
pytest.param("ppermute", (), {}, marks=pytest.mark.skip),
217-
pytest.param("pshuffle", (), {}, marks=pytest.mark.skip),
218-
pytest.param("pswapaxes", (), {}, marks=pytest.mark.skip),
219-
pytest.param("axis_index", (), {}, marks=pytest.mark.skip),
200+
pytest.param("custom_linear_solve", (), {}, marks=mark_todo),
201+
pytest.param("custom_root", (), {}, marks=mark_todo),
202+
pytest.param("all_gather", (), {}, marks=mark_todo),
203+
pytest.param("all_to_all", (), {}, marks=mark_todo),
204+
pytest.param("psum", (), {}, marks=mark_todo),
205+
pytest.param("psum_scatter", (), {}, marks=mark_todo),
206+
pytest.param("pmax", (), {}, marks=mark_todo),
207+
pytest.param("pmin", (), {}, marks=mark_todo),
208+
pytest.param("pmean", (), {}, marks=mark_todo),
209+
pytest.param("ppermute", (), {}, marks=mark_todo),
210+
pytest.param("pshuffle", (), {}, marks=mark_todo),
211+
pytest.param("pswapaxes", (), {}, marks=mark_todo),
212+
pytest.param("axis_index", (), {}, marks=mark_todo),
220213
# --- Sharding-related operators ---
221-
pytest.param("with_sharding_constraint", (), {}, marks=pytest.mark.skip),
214+
pytest.param("with_sharding_constraint", (), {}, marks=mark_todo),
222215
],
223216
)
224217
def test_lax_functions(func_name, args, kw):
@@ -252,7 +245,7 @@ def test_lax_functions(func_name, args, kw):
252245
("schur", (x1225,), {}),
253246
("svd", (x1225,), {}),
254247
("tridiagonal", (x1225,), {}),
255-
pytest.param("tridiagonal_solve", (), {}, marks=pytest.mark.skip),
248+
pytest.param("tridiagonal_solve", (), {}, marks=mark_todo),
256249
],
257250
)
258251
def test_lax_linalg_functions(func_name, args, kw):

0 commit comments

Comments
(0)