From d2c749590d57cc23376e34260deb64c753b68a78 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Sat, 7 Jun 2025 21:44:55 +0800
Subject: [PATCH 1/2] Rewrite scalar solve to division

---
 pytensor/tensor/rewriting/linalg.py   | 46 +++++++++++++++++++++++++++
 tests/tensor/rewriting/test_linalg.py | 43 ++++++++++++++++++++-----
 2 files changed, 81 insertions(+), 8 deletions(-)

diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index af42bee236..ecdbe6e7ed 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -47,8 +47,10 @@
 from pytensor.tensor.slinalg import (
     BlockDiagonal,
     Cholesky,
+    CholeskySolve,
     Solve,
     SolveBase,
+    SolveTriangular,
     _bilinear_solve_discrete_lyapunov,
     block_diag,
     cholesky,
@@ -908,6 +910,11 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
         return None
 
     [input] = node.inputs
+
+    # Check if input is a (1, 1) matrix
+    if all(input.type.broadcastable[-2:]):
+        return [pt.sqrt(input)]
+
     # Check for use of pt.diag first
     if (
         input.owner
@@ -1020,3 +1027,42 @@ def slogdet_specialization(fgraph, node):
         k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
     }
     return replacements
+
+
+@register_stabilize
+@register_canonicalize
+@node_rewriter([Blockwise])
+def scalar_solve_to_division(fgraph, node):
+    """
+    Replace solve(a, b) with b / a if a is a (1, 1) matrix
+    """
+
+    core_op = node.op.core_op
+    if not isinstance(core_op, SolveBase):
+        return None
+
+    a, b = node.inputs
+    old_out = node.outputs[0]
+    if not all(a.broadcastable[-2:]):
+        return None
+
+    # Special handling for different types of solve
+    match core_op:
+        case SolveTriangular():
+            # Corner case: if the user asked for a triangular solve with a unit diagonal, a is taken to be 1
+            new_out = b / a if not core_op.unit_diagonal else b
+        case CholeskySolve():
+            new_out = b / a**2
+        case Solve():
+            new_out = b / a
+        case _:
+            raise NotImplementedError(
+                f"Unsupported core_op type: {type(core_op)} in scalar_solve_to_division"
+            )
+
+    if core_op.b_ndim == 1:
+        new_out = new_out.squeeze(-1)
+
+    copy_stack_trace(old_out, new_out)
+
+    return [new_out]
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 50e48ce95d..539951c1d6 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -29,6 +29,7 @@
 from pytensor.tensor.slinalg import (
     BlockDiagonal,
     Cholesky,
+    CholeskySolve,
     Solve,
     SolveBase,
     SolveTriangular,
@@ -920,14 +921,6 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
     nodes = f_rewritten.maker.fgraph.apply_nodes
     assert any(isinstance(node.op, Cholesky) for node in nodes)
 
-    # Case 2 : eye is degenerate
-    x = pt.scalar("x")
-    y = pt.eye(1) * x
-    z_cholesky = pt.linalg.cholesky(y)
-    f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
-    nodes = f_rewritten.maker.fgraph.apply_nodes
-    assert any(isinstance(node.op, Cholesky) for node in nodes)
-
 
 def test_slogdet_specialization():
     x, a = pt.dmatrix("x"), np.random.rand(20, 20)
@@ -993,3 +986,37 @@ def test_slogdet_specialization():
     f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN")
     nodes = f.maker.fgraph.apply_nodes
     assert not any(isinstance(node.op, SLogDet) for node in nodes)
+
+
+@pytest.mark.parametrize(
+    "Op, fn",
+    [
+        (Solve, pt.linalg.solve),
+        (SolveTriangular, pt.linalg.solve_triangular),
+        (CholeskySolve, pt.linalg.cho_solve),
+    ],
+)
+def test_scalar_solve_to_division_rewrite(Op, fn):
+    rng = np.random.default_rng(sum(map(ord, "scalar_solve_to_division_rewrite")))
+
+    a = pt.dmatrix("a", shape=(1, 1))
+    b = pt.dvector("b")
+
+    if Op is CholeskySolve:
+        # cho_solve expects a tuple (c, lower) as the first input
+        c = fn((pt.linalg.cholesky(a), True), b, b_ndim=1)
+    else:
+        c = fn(a, b, b_ndim=1)
+
+    f = function([a, b], c, mode="FAST_RUN")
+    nodes = f.maker.fgraph.apply_nodes
+
+    assert not any(isinstance(node.op, Op) for node in nodes)
+
+    a_val = rng.normal(size=(1, 1)).astype(pytensor.config.floatX)
+    b_val = rng.normal(size=(1,)).astype(pytensor.config.floatX)
+
+    c_val = np.linalg.solve(a_val, b_val)
+    np.testing.assert_allclose(
+        f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
+    )

From 433cc6cb68872d7f1a8873dfe1046d844cf06699 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Mon, 23 Jun 2025 12:35:23 +0200
Subject: [PATCH 2/2] Rewrite `sqrt(sqr(x)) -> |x|` and `sqr(sqrt(x)) -> x`

---
 pytensor/tensor/rewriting/math.py   | 31 +++++++++++++++++++++++
 tests/tensor/rewriting/test_math.py | 39 +++++++++++++++++++++++++++++
 2 files changed, 70 insertions(+)

diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index aef363655e..d126502bde 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -400,6 +400,37 @@ def local_exp_log(fgraph, node):
         return [exp(x)]
 
 
+@register_canonicalize
+@register_specialize
+@node_rewriter([sqrt, sqr])
+def local_sqrt_sqr(fgraph, node):
+    x = node.inputs[0]
+
+    if not (x.owner and isinstance(x.owner.op, Elemwise)):
+        return
+
+    prev_op = x.owner.op.scalar_op
+    node_op = node.op.scalar_op
+
+    # Case for sqrt(sqr(x)) -> |x|
+    if isinstance(prev_op, ps.Sqr) and isinstance(node_op, ps.Sqrt):
+        new_out = pt_abs(x.owner.inputs[0])
+        old_out = node.outputs[0]
+
+        # Handle potential integer to float cast by sqrt
+        if new_out.dtype != old_out.dtype:
+            new_out = cast(new_out, old_out.dtype)
+        return [new_out]
+
+    # Case for sqr(sqrt(x)) -> x
+    if isinstance(prev_op, ps.Sqrt) and isinstance(node_op, ps.Sqr):
+        x = x.owner.inputs[0]
+        old_out = node.outputs[0]
+        new_out = switch(ge(x, 0), x, np.asarray(np.nan, old_out.dtype))
+
+        return [new_out]
+
+
 @register_specialize
 @node_rewriter([exp, expm1])
 def local_exp_log_nan_switch(fgraph, node):
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index c4999fcd33..3699a3fcff 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -2031,6 +2031,45 @@ def test_exp_log_nested(self, nested_expression, expected_switches):
         assert len(ops_graph) == expected_switches
 
 
+class TestSqrSqrt:
+    def setup_method(self):
+        mode = get_default_mode()
+        self.mode = mode.including(
+            "local_sqrt_sqr",
+        ).excluding("fusion")
+        self.rng = np.random.default_rng()
+
+    def test_sqr_sqrt(self):
+        # sqrt(x) ** 2 -> x, guarded by nan for x < 0
+        x = pt.tensor("x", shape=(None, None))
+        out = sqr(sqrt(x))
+        out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])
+
+        expected = switch(
+            ge(x, np.zeros((1, 1), dtype="int8")),
+            x,
+            np.full((1, 1), np.nan, dtype=x.type.dtype),
+        )
+
+        assert equal_computations([out], [expected])
+
+    def test_sqrt_sqr(self):
+        x = pt.tensor("x", shape=(None, None))
+        out = sqrt(sqr(x))
+        out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])
+
+        assert equal_computations([out], [pt_abs(x)])
+
+    def test_sqr_sqrt_integer_upcast(self):
+        x = ivector("x")
+        out = sqrt(sqr(x))
+        dtype = out.type.dtype
+        out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])
+
+        expected = pt.cast(pt_abs(x), dtype=dtype)
+        assert equal_computations([out], [expected])
+
+
 class TestLocalSwitchSink:
     def setup_method(self):
         # condition values
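
For reference, a minimal sketch (not part of the patches) of how the two rewrites behave from user code once both commits are applied; the (1, 1) shape and b_ndim usage mirror the tests above:

import numpy as np
import pytensor
import pytensor.tensor as pt

# Patch 1: a solve against a (1, 1) matrix is rewritten to division, so
# the compiled graph computes b / a instead of calling a LAPACK solver.
a = pt.dmatrix("a", shape=(1, 1))
b = pt.dvector("b")
f = pytensor.function([a, b], pt.linalg.solve(a, b, b_ndim=1), mode="FAST_RUN")
np.testing.assert_allclose(f([[2.0]], [3.0]), [1.5])

# Patch 2: sqrt(sqr(x)) simplifies to |x|, so negative inputs come back
# as absolute values rather than going through pow twice.
x = pt.dvector("x")
g = pytensor.function([x], pt.sqrt(pt.sqr(x)), mode="FAST_RUN")
np.testing.assert_allclose(g([-3.0, 4.0]), [3.0, 4.0])

# ... while sqr(sqrt(x)) becomes x guarded by a nan switch for x < 0,
# preserving the nan that sqrt would have produced.
h = pytensor.function([x], pt.sqr(pt.sqrt(x)), mode="FAST_RUN")
res = h([-1.0, 4.0])
assert np.isnan(res[0]) and res[1] == 4.0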