|
18 | 18 | xround = jnp.array([[1.1, 2.2], [3.3, 4.4]]) |
19 | 19 | conv_kernel = jnp.array([[[[1.0, 0.0], [0.0, -1.0]]]], dtype=float) |
20 | 20 | xcomp = jnp.array([[5, 2], [7, 2]], dtype=float) |
| 21 | +xconv = jnp.arange(1, 17, dtype=float).reshape((1, 1, 4, 4)) |
21 | 22 |
|
22 | 23 |
|
23 | 24 | @pytest.mark.parametrize( |
|
27 | 28 | ("acos", (xtrig,), {}), |
28 | 29 | ("acosh", (x,), {}), |
29 | 30 | ("add", (x, y), {}), |
30 | | - pytest.param("after_all", (), {}, marks=pytest.mark.skip), |
| 31 | + pytest.param("after_all", (), {}, marks=mark_todo), |
31 | 32 | ("approx_max_k", (x, 2), {}), |
32 | 33 | ("approx_min_k", (x, 2), {}), |
33 | 34 | ("argmax", (x,), {"axis": 0, "index_dtype": int}), |
|
57 | 58 | ("broadcast_in_dim", (x, (1, 1, 2, 2), (2, 3)), {}), |
58 | 59 | ("broadcast_shapes", ((2, 3), (1, 3)), {}), |
59 | 60 | ("broadcast_to_rank", (x,), {"rank": 3}), |
60 | | - pytest.param("broadcasted_iota", (), {}, marks=pytest.mark.skip), |
| 61 | + pytest.param("broadcasted_iota", (), {}, marks=mark_todo), |
61 | 62 | ("cbrt", (x,), {}), |
62 | 63 | ("ceil", (xround,), {}), |
63 | 64 | ("clamp", (2.0, x, 3.0), {}), |
64 | 65 | ("clz", (xbit,), {}), |
65 | 66 | ("collapse", (x, 1), {}), |
66 | 67 | ("concatenate", ((x, y), 0), {}), |
67 | 68 | ("conj", (xcomplex,), {}), |
68 | | - ( |
69 | | - "conv", |
70 | | - (jnp.arange(1, 17, dtype=float).reshape((1, 1, 4, 4)), conv_kernel), |
71 | | - {"window_strides": (1, 1), "padding": "SAME"}, |
72 | | - ), |
| 69 | + ("conv", (xconv, conv_kernel), {"window_strides": (1, 1), "padding": "SAME"}), |
73 | 70 | ("convert_element_type", (x, jnp.int32), {}), |
74 | 71 | ( |
75 | 72 | "conv_dimension_numbers", |
|
78 | 75 | ), |
79 | 76 | ( |
80 | 77 | "conv_general_dilated", |
81 | | - (jnp.arange(1, 17, dtype=float).reshape((1, 1, 4, 4)), conv_kernel), |
| 78 | + (xconv, conv_kernel), |
82 | 79 | {"window_strides": (1, 1), "padding": "SAME"}, |
83 | 80 | ), |
84 | | - pytest.param("conv_general_dilated_local", (), {}, marks=pytest.mark.skip), |
| 81 | + pytest.param("conv_general_dilated_local", (), {}, marks=mark_todo), |
85 | 82 | ( |
86 | 83 | "conv_general_dilated_patches", |
87 | | - (jnp.arange(1, 17, dtype=float).reshape((1, 1, 4, 4)),), |
| 84 | + (xconv,), |
88 | 85 | {"filter_shape": (2, 2), "window_strides": (1, 1), "padding": "VALID"}, |
89 | 86 | ), |
90 | 87 | ( |
91 | 88 | "conv_transpose", |
92 | | - (jnp.arange(1, 17, dtype=float).reshape((1, 1, 4, 4)), conv_kernel), |
| 89 | + (xconv, conv_kernel), |
93 | 90 | { |
94 | 91 | "strides": (2, 2), |
95 | 92 | "padding": "SAME", |
96 | 93 | "dimension_numbers": ("NCHW", "OIHW", "NCHW"), |
97 | 94 | }, |
98 | 95 | ), |
99 | | - pytest.param("conv_with_general_padding", (), {}, marks=pytest.mark.skip), |
| 96 | + pytest.param("conv_with_general_padding", (), {}, marks=mark_todo), |
100 | 97 | ("cos", (x,), {}), |
101 | 98 | ("cosh", (x,), {}), |
102 | 99 | ("cumlogsumexp", (x,), {"axis": 0}), |
|
107 | 104 | ("digamma", (xtrig,), {}), |
108 | 105 | ("div", (x, y), {}), |
109 | 106 | ("dot", (x, y), {}), |
110 | | - pytest.param("dot_general", (), {}, marks=pytest.mark.skip), |
111 | | - pytest.param("dynamic_index_in_dim", (), {}, marks=pytest.mark.skip), |
| 107 | + pytest.param("dot_general", (), {}, marks=mark_todo), |
| 108 | + pytest.param("dynamic_index_in_dim", (), {}, marks=mark_todo), |
112 | 109 | ("dynamic_slice", (x, (0, 0), (2, 2)), {}), |
113 | | - pytest.param("dynamic_slice_in_dim", (), {}, marks=pytest.mark.skip), |
114 | | - pytest.param("dynamic_update_index_in_dim", (), {}, marks=pytest.mark.skip), |
| 110 | + pytest.param("dynamic_slice_in_dim", (), {}, marks=mark_todo), |
| 111 | + pytest.param("dynamic_update_index_in_dim", (), {}, marks=mark_todo), |
115 | 112 | ("dynamic_update_slice", (x, y, (0, 0)), {}), |
116 | 113 | ("dynamic_update_slice_in_dim", (x, y, 0, 0), {}), |
117 | 114 | ("eq", (x, x), {}), |
|
126 | 123 | ("floor", (xround,), {}), |
127 | 124 | ("full", ((2, 2), 1.0), {}), |
128 | 125 | ("full_like", (x, 1.0), {}), |
129 | | - pytest.param("gather", (), {}, marks=pytest.mark.skip), |
| 126 | + pytest.param("gather", (), {}, marks=mark_todo), |
130 | 127 | ("ge", (x, xcomp), {}), |
131 | 128 | ("gt", (x, xcomp), {}), |
132 | 129 | ("igamma", (1.0, xtrig), {}), |
133 | 130 | ("igammac", (1.0, xtrig), {}), |
134 | 131 | ("imag", (xcomplex,), {}), |
135 | 132 | ("index_in_dim", (x, 0, 0), {}), |
136 | | - pytest.param("index_take", (), {}, marks=pytest.mark.skip), |
| 133 | + pytest.param("index_take", (), {}, marks=mark_todo), |
137 | 134 | ("integer_pow", (x, 2), {}), |
138 | | - pytest.param("iota", (), {}, marks=pytest.mark.skip), |
| 135 | + pytest.param("iota", (), {}, marks=mark_todo), |
139 | 136 | ("is_finite", (x,), {}), |
140 | 137 | ("le", (x, xcomp), {}), |
141 | 138 | ("lgamma", (x,), {}), |
|
149 | 146 | ("ne", (x, xcomp), {}), |
150 | 147 | ("neg", (x,), {}), |
151 | 148 | ("nextafter", (x, y), {}), |
152 | | - pytest.param("pad", (), {}, marks=pytest.mark.skip), |
| 149 | + pytest.param("pad", (), {}, marks=mark_todo), |
153 | 150 | ("polygamma", (1.0, xtrig), {}), |
154 | 151 | ("population_count", (xbit,), {}), |
155 | 152 | ("pow", (x, y), {}), |
156 | 153 | pytest.param("random_gamma_grad", (1.0, x), {}, marks=mark_todo), |
157 | 154 | ("real", (xcomplex,), {}), |
158 | 155 | ("reciprocal", (x,), {}), |
159 | | - pytest.param("reduce", (), {}, marks=pytest.mark.skip), |
160 | | - pytest.param("reduce_precision", (), {}, marks=pytest.mark.skip), |
161 | | - pytest.param("reduce_window", (), {}, marks=pytest.mark.skip), |
| 156 | + pytest.param("reduce", (), {}, marks=mark_todo), |
| 157 | + pytest.param("reduce_precision", (), {}, marks=mark_todo), |
| 158 | + pytest.param("reduce_window", (), {}, marks=mark_todo), |
162 | 159 | ("rem", (x, y), {}), |
163 | 160 | ("reshape", (x, (1, 4)), {}), |
164 | 161 | ("rev", (x,), {"dimensions": (0,)}), |
165 | | - pytest.param("rng_bit_generator", (), {}, marks=pytest.mark.skip), |
| 162 | + pytest.param("rng_bit_generator", (), {}, marks=mark_todo), |
166 | 163 | ("rng_uniform", (0, 1, (2, 3)), {}), |
167 | 164 | ("round", (xround,), {}), |
168 | 165 | ("rsqrt", (x,), {}), |
169 | | - pytest.param("scatter", (), {}, marks=pytest.mark.skip), |
170 | | - pytest.param("scatter_apply", (), {}, marks=pytest.mark.skip), |
171 | | - pytest.param("scatter_max", (), {}, marks=pytest.mark.skip), |
172 | | - pytest.param("scatter_min", (), {}, marks=pytest.mark.skip), |
173 | | - pytest.param("scatter_mul", (), {}, marks=pytest.mark.skip), |
| 166 | + pytest.param("scatter", (), {}, marks=mark_todo), |
| 167 | + pytest.param("scatter_apply", (), {}, marks=mark_todo), |
| 168 | + pytest.param("scatter_max", (), {}, marks=mark_todo), |
| 169 | + pytest.param("scatter_min", (), {}, marks=mark_todo), |
| 170 | + pytest.param("scatter_mul", (), {}, marks=mark_todo), |
174 | 171 | ("shift_left", (xbit, 1), {}), |
175 | 172 | ("shift_right_arithmetic", (xbit, 1), {}), |
176 | 173 | ("shift_right_logical", (xbit, 1), {}), |
|
180 | 177 | ("slice", (x, (0, 0), (2, 2)), {}), |
181 | 178 | ("slice_in_dim", (x, 0, 0, 2), {}), |
182 | 179 | ("sort", (x,), {}), |
183 | | - pytest.param("sort_key_val", (), {}, marks=pytest.mark.skip), |
| 180 | + pytest.param("sort_key_val", (), {}, marks=mark_todo), |
184 | 181 | ("sqrt", (x,), {}), |
185 | 182 | ("square", (x,), {}), |
186 | 183 | ("sub", (x, y), {}), |
|
190 | 187 | ("transpose", (x, (1, 0)), {}), |
191 | 188 | ("zeros_like_array", (x,), {}), |
192 | 189 | ("zeta", (x, 2.0), {}), |
193 | | - pytest.param("associative_scan", (), {}, marks=pytest.mark.skip), |
| 190 | + pytest.param("associative_scan", (), {}, marks=mark_todo), |
194 | 191 | ("cond", (True, lambda: x, lambda: y), {}), |
195 | | - pytest.param("fori_loop", (), {}, marks=pytest.mark.skip), |
| 192 | + pytest.param("fori_loop", (), {}, marks=mark_todo), |
196 | 193 | ("map", (lambda x: x + 1, x), {}), |
197 | | - pytest.param("scan", (), {}, marks=pytest.mark.skip), |
198 | | - ( |
199 | | - "select", |
200 | | - (jnp.array([[True, False], [True, False]], dtype=bool), x, y), |
201 | | - {}, |
202 | | - ), |
203 | | - pytest.param("select_n", (), {}, marks=pytest.mark.skip), |
204 | | - pytest.param("switch", (), {}, marks=pytest.mark.skip), |
| 194 | + pytest.param("scan", (), {}, marks=mark_todo), |
| 195 | + ("select", (jnp.array([[True, False], [True, False]], dtype=bool), x, y), {}), |
| 196 | + pytest.param("select_n", (), {}, marks=mark_todo), |
| 197 | + pytest.param("switch", (), {}, marks=mark_todo), |
205 | 198 | ("while_loop", (lambda x: jnp.all(x < 10), lambda x: x + 1, x), {}), |
206 | 199 | ("stop_gradient", (x,), {}), |
207 | | - pytest.param("custom_linear_solve", (), {}, marks=pytest.mark.skip), |
208 | | - pytest.param("custom_root", (), {}, marks=pytest.mark.skip), |
209 | | - pytest.param("all_gather", (), {}, marks=pytest.mark.skip), |
210 | | - pytest.param("all_to_all", (), {}, marks=pytest.mark.skip), |
211 | | - pytest.param("psum", (), {}, marks=pytest.mark.skip), |
212 | | - pytest.param("psum_scatter", (), {}, marks=pytest.mark.skip), |
213 | | - pytest.param("pmax", (), {}, marks=pytest.mark.skip), |
214 | | - pytest.param("pmin", (), {}, marks=pytest.mark.skip), |
215 | | - pytest.param("pmean", (), {}, marks=pytest.mark.skip), |
216 | | - pytest.param("ppermute", (), {}, marks=pytest.mark.skip), |
217 | | - pytest.param("pshuffle", (), {}, marks=pytest.mark.skip), |
218 | | - pytest.param("pswapaxes", (), {}, marks=pytest.mark.skip), |
219 | | - pytest.param("axis_index", (), {}, marks=pytest.mark.skip), |
| 200 | + pytest.param("custom_linear_solve", (), {}, marks=mark_todo), |
| 201 | + pytest.param("custom_root", (), {}, marks=mark_todo), |
| 202 | + pytest.param("all_gather", (), {}, marks=mark_todo), |
| 203 | + pytest.param("all_to_all", (), {}, marks=mark_todo), |
| 204 | + pytest.param("psum", (), {}, marks=mark_todo), |
| 205 | + pytest.param("psum_scatter", (), {}, marks=mark_todo), |
| 206 | + pytest.param("pmax", (), {}, marks=mark_todo), |
| 207 | + pytest.param("pmin", (), {}, marks=mark_todo), |
| 208 | + pytest.param("pmean", (), {}, marks=mark_todo), |
| 209 | + pytest.param("ppermute", (), {}, marks=mark_todo), |
| 210 | + pytest.param("pshuffle", (), {}, marks=mark_todo), |
| 211 | + pytest.param("pswapaxes", (), {}, marks=mark_todo), |
| 212 | + pytest.param("axis_index", (), {}, marks=mark_todo), |
220 | 213 | # --- Sharding-related operators --- |
221 | | - pytest.param("with_sharding_constraint", (), {}, marks=pytest.mark.skip), |
| 214 | + pytest.param("with_sharding_constraint", (), {}, marks=mark_todo), |
222 | 215 | ], |
223 | 216 | ) |
224 | 217 | def test_lax_functions(func_name, args, kw): |
@@ -252,7 +245,7 @@ def test_lax_functions(func_name, args, kw): |
252 | 245 | ("schur", (x1225,), {}), |
253 | 246 | ("svd", (x1225,), {}), |
254 | 247 | ("tridiagonal", (x1225,), {}), |
255 | | - pytest.param("tridiagonal_solve", (), {}, marks=pytest.mark.skip), |
| 248 | + pytest.param("tridiagonal_solve", (), {}, marks=mark_todo), |
256 | 249 | ], |
257 | 250 | ) |
258 | 251 | def test_lax_linalg_functions(func_name, args, kw): |
|
0 commit comments