Rewrite scalar solve to division #1453

Open · wants to merge 2 commits into main
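For context, a minimal sketch of what this PR enables (illustrative, not part of the diff; it assumes the rewrite is registered in the default FAST_RUN rewrite pipeline):

import numpy as np
import pytensor
import pytensor.tensor as pt

# Solving a (1, 1) system is just elementwise division: a @ x = b  =>  x = b / a
a = pt.dmatrix("a", shape=(1, 1))
b = pt.dvector("b")
x = pt.linalg.solve(a, b, b_ndim=1)

# With this PR, the compiled graph contains a division instead of a LAPACK solve
f = pytensor.function([a, b], x, mode="FAST_RUN")
print(f(np.array([[2.0]]), np.array([6.0])))  # -> [3.]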
46 changes: 46 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
@@ -47,8 +47,10 @@
from pytensor.tensor.slinalg import (
    BlockDiagonal,
    Cholesky,
    CholeskySolve,
    Solve,
    SolveBase,
    SolveTriangular,
    _bilinear_solve_discrete_lyapunov,
    block_diag,
    cholesky,
@@ -908,6 +910,11 @@
        return None

    [input] = node.inputs

    # Check if input is a (1, 1) matrix
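    # (the Cholesky factor of a scalar matrix [[v]] is just [[sqrt(v)]])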
    if all(input.type.broadcastable[-2:]):
        return [pt.sqrt(input)]
Comment on lines +914 to +916 (Member):
WHAT AN OPTIMIZATION!!!!!!!!!!!!!!!!!!!


    # Check for use of pt.diag first
    if (
        input.owner
@@ -1020,3 +1027,42 @@
        k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
    }
    return replacements


@register_stabilize
@register_canonicalize
@node_rewriter([Blockwise])
def scalar_solve_to_division(fgraph, node):
"""
Replace solve(a, b) with b / a if a is a (1, 1) matrix
"""

    core_op = node.op.core_op
    if not isinstance(core_op, SolveBase):
        return None

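    # a has a (1, 1) core matrix iff both of its trailing (core) dims are broadcastable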
    a, b = node.inputs
    old_out = node.outputs[0]
    if not all(a.broadcastable[-2:]):
        return None

    # Special handling for different types of solve
    match core_op:
        case SolveTriangular():
            # Corner case: if the user asked for a triangular solve with a unit
            # diagonal, the single entry of a is taken to be 1
            new_out = b / a if not core_op.unit_diagonal else b
        case CholeskySolve():
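            # cho_solve receives the Cholesky factor c of a (a = c @ c.T), so
            # for a (1, 1) system the solution is b / c**2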
            new_out = b / a**2
        case Solve():
            new_out = b / a
        case _:
            raise NotImplementedError(
                f"Unsupported core_op type: {type(core_op)} in scalar_solve_to_division"
            )

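    # With b_ndim == 1, dividing the core vector b by the (1, 1) matrix a
    # broadcasts in an extra core dimension that the solve would not have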
    if core_op.b_ndim == 1:
        new_out = new_out.squeeze(-1)

    copy_stack_trace(old_out, new_out)

    return [new_out]
33 changes: 33 additions & 0 deletions pytensor/tensor/rewriting/math.py
@@ -400,6 +400,39 @@
        return [exp(x)]


@register_canonicalize
@register_specialize
@node_rewriter([sqrt, sqr])
def local_sqrt_sqr(fgraph, node):
    x = node.inputs[0]

    if not (x.owner and isinstance(x.owner.op, Elemwise)):
        return

    prev_op = x.owner.op.scalar_op
    node_op = node.op.scalar_op

    # Case for sqrt(sqr(x)) -> |x|
    if isinstance(prev_op, ps.Sqr) and isinstance(node_op, ps.Sqrt):
        new_out = pt_abs(x.owner.inputs[0])
        old_out = node.outputs[0]

        # Handle potential integer to float cast by sqrt
        if new_out.dtype != old_out.dtype:
            new_out = cast(new_out, old_out.dtype)
        return [new_out]

    # Case for sqr(sqrt(x)) -> x
    if isinstance(prev_op, ps.Sqrt) and isinstance(node_op, ps.Sqr):
        new_out = x.owner.inputs[0]
        old_out = node.outputs[0]

        # Handle potential integer to float cast by sqrt
        if new_out.dtype != old_out.dtype:
            new_out = cast(new_out, old_out.dtype)
        return [new_out]
Comment on lines +425 to +433 (Member):
We should introduce a nan switch like we do for exp(log(x)) rewrites

Member:

And we register the nan rewrite only in stabilize/specialize, so if you have something like sqrt(sqr(sqrt(x))) you end up with sqrt(x) instead of sqrt(switch(x <= 0, nan, x)), because the sqr(sqrt(x)) rewrites can act without rush.
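A minimal sketch of what the suggested nan switch could look like, modeled on the exp/log nan rewrites below (the name, registration, and dtype handling are illustrative, not part of this PR):

@register_stabilize
@register_specialize
@node_rewriter([sqr])
def local_sqr_sqrt_nan_switch(fgraph, node):
    # Hypothetical: sqr(sqrt(x)) -> switch(x < 0, nan, x), so negative inputs
    # keep producing nan after the rewrite
    x = node.inputs[0]
    if not (x.owner and isinstance(x.owner.op, Elemwise)):
        return
    if isinstance(x.owner.op.scalar_op, ps.Sqrt):
        inner = x.owner.inputs[0]
        old_out = node.outputs[0]
        new_out = switch(inner < 0, np.asarray(np.nan, dtype=old_out.dtype), inner)
        if new_out.dtype != old_out.dtype:
            new_out = cast(new_out, old_out.dtype)
        return [new_out]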



@register_specialize
@node_rewriter([exp, expm1])
def local_exp_log_nan_switch(fgraph, node):
43 changes: 35 additions & 8 deletions tests/tensor/rewriting/test_linalg.py
@@ -29,6 +29,7 @@
from pytensor.tensor.slinalg import (
    BlockDiagonal,
    Cholesky,
    CholeskySolve,
    Solve,
    SolveBase,
    SolveTriangular,
@@ -920,14 +921,6 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
    nodes = f_rewritten.maker.fgraph.apply_nodes
    assert any(isinstance(node.op, Cholesky) for node in nodes)

    # Case 2: eye is degenerate
    x = pt.scalar("x")
    y = pt.eye(1) * x
    z_cholesky = pt.linalg.cholesky(y)
    f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
    nodes = f_rewritten.maker.fgraph.apply_nodes
    assert any(isinstance(node.op, Cholesky) for node in nodes)


def test_slogdet_specialization():
    x, a = pt.dmatrix("x"), np.random.rand(20, 20)
@@ -993,3 +986,37 @@
    f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN")
    nodes = f.maker.fgraph.apply_nodes
    assert not any(isinstance(node.op, SLogDet) for node in nodes)


@pytest.mark.parametrize(
    "Op, fn",
    [
        (Solve, pt.linalg.solve),
        (SolveTriangular, pt.linalg.solve_triangular),
        (CholeskySolve, pt.linalg.cho_solve),
    ],
)
def test_scalar_solve_to_division_rewrite(Op, fn):
    rng = np.random.default_rng(sum(map(ord, "scalar_solve_to_division_rewrite")))

    a = pt.dmatrix("a", shape=(1, 1))
    b = pt.dvector("b")

    if Op is CholeskySolve:
        # cho_solve expects a tuple (c, lower) as the first input
        c = fn((pt.linalg.cholesky(a), True), b, b_ndim=1)
    else:
        c = fn(a, b, b_ndim=1)

    f = function([a, b], c, mode="FAST_RUN")
    nodes = f.maker.fgraph.apply_nodes

    assert not any(isinstance(node.op, Op) for node in nodes)

    a_val = rng.normal(size=(1, 1)).astype(pytensor.config.floatX)
    b_val = rng.normal(size=(1,)).astype(pytensor.config.floatX)

    c_val = np.linalg.solve(a_val, b_val)
    np.testing.assert_allclose(
        f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
    )
50 changes: 50 additions & 0 deletions tests/tensor/rewriting/test_math.py
@@ -2031,6 +2031,56 @@ def test_exp_log_nested(self, nested_expression, expected_switches):
        assert len(ops_graph) == expected_switches


class TestSqrSqrt:
    def setup_method(self):
        mode = get_default_mode()
        self.mode = mode.including(
            "local_sqrt_sqr",
        ).excluding("fusion")
        self.rng = np.random.default_rng()
    def test_sqr_sqrt(self):
        # sqrt(x) ** 2 -> x
        data = self.rng.normal(size=(4, 3)).astype(config.floatX)
        x = pt.tensor("x", shape=(None, None))
        f = function([x], sqr(sqrt(x)), mode=self.mode)
        graph = f.maker.fgraph.toposort()
        assert not any(
            isinstance(node.op, Elemwise)
            and isinstance(node.op.scalar_op, ps.Sqr | ps.Sqrt)
            for node in graph
        )

        np.testing.assert_array_equal(f(data), data)

    def test_sqrt_sqr(self):
        # sqrt(x ** 2) -> |x|
        data = self.rng.normal(size=(4, 3)).astype(config.floatX)
        x = pt.tensor("x", shape=(None, None))
        f = function([x], sqrt(sqr(x)), mode=self.mode)
        graph = f.maker.fgraph.toposort()
        assert not any(
            isinstance(node.op, Elemwise)
            and isinstance(node.op.scalar_op, ps.Sqr | ps.Sqrt)
            for node in graph
        )
        assert any(
            isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Abs)
            for node in graph
        )
        np.testing.assert_array_equal(f(data), np.abs(data))

    def test_sqr_sqrt_integer_upcast(self):
        x = ivector("x")
        f = function([x], sqrt(sqr(x)), mode=self.mode)
        nodes = f.maker.fgraph.toposort()
        assert not any(
            isinstance(node.op, Elemwise)
            and isinstance(node.op.scalar_op, ps.Sqr | ps.Sqrt)
            for node in nodes
        )
        assert f([1, 2, 3]).dtype in ["float32", "float64"]


class TestLocalSwitchSink:
    def setup_method(self):
        # condition values