@@ -251,3 +251,53 @@ def test_decomposition_reused_preserves_check_finite(assume_a, counter):
251
251
assert fn_opt (A_valid , b1_valid * np .nan , b2_valid )
252
252
with pytest .raises (ValueError , match = "array must not contain infs or NaNs" ):
253
253
assert fn_opt (A_valid * np .nan , b1_valid , b2_valid )
254
+
255
+
256
+ @pytest .mark .parametrize (
257
+ "lower_first" , [True , False ], ids = ["lower_first" , "upper_first" ]
258
+ )
259
+ def test_cho_solve_handles_lower_flags (lower_first ):
260
+ rewrite_name = reuse_decomposition_multiple_solves .__name__
261
+ A = tensor ("A" , shape = (5 , None ))
262
+ b = tensor ("b" , shape = (5 ,))
263
+
264
+ x1 = solve (A , b , assume_a = "pos" , lower = lower_first , check_finite = False )
265
+ x2 = solve (A .mT , b , assume_a = "pos" , lower = not lower_first , check_finite = False )
266
+
267
+ dx1_dA = grad (x1 .sum (), A )
268
+ dx2_dA = grad (x2 .sum (), A )
269
+
270
+ fn = function ([A , b ], [x1 , dx1_dA , x2 , dx2_dA ])
271
+ fn_no_rewrite = function (
272
+ [A , b ],
273
+ [x1 , dx1_dA , x2 , dx2_dA ],
274
+ mode = get_default_mode ().excluding (rewrite_name ),
275
+ )
276
+
277
+ rng = np .random .default_rng ()
278
+ L_values = rng .normal (size = (5 , 5 )).astype (config .floatX )
279
+ A_values = L_values @ L_values .T # Ensure A is positive definite
280
+
281
+ if lower_first :
282
+ A_values [np .triu_indices (5 , k = 1 )] = np .nan
283
+ else :
284
+ A_values [np .tril_indices (5 , k = - 1 )] = np .nan
285
+
286
+ b_values = rng .normal (size = (5 ,)).astype (config .floatX )
287
+
288
+ # This computation should not raise an error, and none of them should be NaN
289
+ res = fn (A_values , b_values )
290
+ expected_res = fn_no_rewrite (A_values , b_values )
291
+
292
+ for x , expected_x in zip (res , expected_res ):
293
+ assert np .isfinite (x ).all ()
294
+ np .testing .assert_allclose (
295
+ x ,
296
+ expected_x ,
297
+ atol = 1e-6 if config .floatX == "float64" else 1e-3 ,
298
+ rtol = 1e-6 if config .floatX == "float64" else 1e-3 ,
299
+ )
300
+
301
+ # If we put the NaN in the wrong place, it should raise an error
302
+ with pytest .raises (np .linalg .LinAlgError ):
303
+ fn (A_values .T , b_values )
0 commit comments