Skip to content

Commit f8b4870

Browse files
Rewrite sqr(sqrt(x)) -> |x| and sqrt(sqr(x)) -> x
1 parent d2c7495 commit f8b4870

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,37 @@ def local_exp_log(fgraph, node):
400400
return [exp(x)]
401401

402402

403+
@register_canonicalize
404+
@register_specialize
405+
@node_rewriter([sqrt, sqr])
406+
def local_sqrt_sqr(fgraph, node):
407+
x = node.inputs[0]
408+
409+
if not (x.owner and isinstance(x.owner.op, Elemwise)):
410+
return
411+
412+
prev_op = x.owner.op.scalar_op
413+
node_op = node.op.scalar_op
414+
415+
# Case for sqrt(sqr(x)) -> |x|
416+
if isinstance(prev_op, ps.Sqrt) and isinstance(node_op, ps.Sqr):
417+
new_out = pt_abs(x.owner.inputs[0])
418+
old_out = node.outputs[0]
419+
420+
# Handle potential integer to float cast by sqr
421+
if new_out.dtype != old_out.dtype:
422+
new_out = cast(new_out, old_out.dtype)
423+
return [new_out]
424+
425+
# Case for sqr(sqrt(x)) -> x
426+
if isinstance(prev_op, ps.Sqr) and isinstance(node_op, ps.Sqrt):
427+
x = x.owner.inputs[0]
428+
old_out = node.outputs[0]
429+
new_out = switch(ge(x, 0), x, np.asarray(np.nan, old_out.dtype))
430+
431+
return [new_out]
432+
433+
403434
@register_specialize
404435
@node_rewriter([exp, expm1])
405436
def local_exp_log_nan_switch(fgraph, node):

tests/tensor/rewriting/test_math.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2031,6 +2031,66 @@ def test_exp_log_nested(self, nested_expression, expected_switches):
20312031
assert len(ops_graph) == expected_switches
20322032

20332033

2034+
class TestSqrSqrt:
2035+
def setup_method(self):
2036+
mode = get_default_mode()
2037+
self.mode = mode.including(
2038+
"local_sqrt_sqr",
2039+
).excluding("fusion")
2040+
self.rng = np.random.default_rng()
2041+
2042+
def test_sqr_sqrt(self):
2043+
# sqrt(x) ** 2 -> x
2044+
x = pt.tensor("x", shape=(None, None))
2045+
out = sqr(sqrt(x))
2046+
out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])
2047+
2048+
fg = FunctionGraph(inputs=[x], outputs=[out])
2049+
2050+
assert not any(
2051+
isinstance(node.op, Elemwise)
2052+
and isinstance(node.op.scalar_op, ps.Sqr | ps.Sqrt)
2053+
for node in fg.toposort()
2054+
)
2055+
assert any(
2056+
isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Abs)
2057+
for node in fg.toposort()
2058+
)
2059+
2060+
assert equal_computations([out], [pt_abs(x)])
2061+
2062+
def test_sqrt_sqr(self):
2063+
x = pt.tensor("x", shape=(None, None))
2064+
out = sqrt(sqr(x))
2065+
out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])
2066+
2067+
fg = FunctionGraph(inputs=[x], outputs=[out])
2068+
2069+
assert not any(
2070+
isinstance(node.op, Elemwise)
2071+
and isinstance(node.op.scalar_op, ps.Sqr | ps.Sqrt)
2072+
for node in fg.toposort()
2073+
)
2074+
expected = switch(
2075+
ge(x, np.zeros((1, 1), dtype="int8")),
2076+
x,
2077+
np.full((1, 1), np.nan, dtype=x.type.dtype),
2078+
)
2079+
2080+
assert equal_computations([out], [expected])
2081+
2082+
def test_sqr_sqrt_integer_upcast(self):
2083+
x = ivector("x")
2084+
f = function([x], sqrt(sqr(x)), mode=self.mode)
2085+
nodes = f.maker.fgraph.toposort()
2086+
assert not any(
2087+
isinstance(node.op, Elemwise)
2088+
and isinstance(node.op.scalar_op, ps.Sqr | ps.Sqrt)
2089+
for node in nodes
2090+
)
2091+
assert f([1, 2, 3]).dtype in ["float32", "float64"]
2092+
2093+
20342094
class TestLocalSwitchSink:
20352095
def setup_method(self):
20362096
# condition values

0 commit comments

Comments
 (0)