Skip to content

Commit 6d15f95

Browse files
Rewrite sqr(sqrt(x)) -> |x| and sqrt(sqr(x)) -> x
1 parent d2c7495 commit 6d15f95

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,37 @@ def local_exp_log(fgraph, node):
400400
return [exp(x)]
401401

402402

403+
@register_canonicalize
404+
@register_specialize
405+
@node_rewriter([sqrt, sqr])
406+
def local_sqrt_sqr(fgraph, node):
407+
x = node.inputs[0]
408+
409+
if not (x.owner and isinstance(x.owner.op, Elemwise)):
410+
return
411+
412+
prev_op = x.owner.op.scalar_op
413+
node_op = node.op.scalar_op
414+
415+
# Case for sqrt(sqr(x)) -> |x|
416+
if isinstance(prev_op, ps.Sqrt) and isinstance(node_op, ps.Sqr):
417+
new_out = pt_abs(x.owner.inputs[0])
418+
old_out = node.outputs[0]
419+
420+
# Handle potential integer to float cast by sqr
421+
if new_out.dtype != old_out.dtype:
422+
new_out = cast(new_out, old_out.dtype)
423+
return [new_out]
424+
425+
# Case for sqr(sqrt(x)) -> x
426+
if isinstance(prev_op, ps.Sqr) and isinstance(node_op, ps.Sqrt):
427+
x = x.owner.inputs[0]
428+
old_out = node.outputs[0]
429+
new_out = switch(ge(x, 0), x, np.asarray(np.nan, old_out.dtype))
430+
431+
return [new_out]
432+
433+
403434
@register_specialize
404435
@node_rewriter([exp, expm1])
405436
def local_exp_log_nan_switch(fgraph, node):

tests/tensor/rewriting/test_math.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2031,6 +2031,44 @@ def test_exp_log_nested(self, nested_expression, expected_switches):
20312031
assert len(ops_graph) == expected_switches
20322032

20332033

2034+
class TestSqrSqrt:
2035+
def setup_method(self):
2036+
mode = get_default_mode()
2037+
self.mode = mode.including(
2038+
"local_sqrt_sqr",
2039+
).excluding("fusion")
2040+
self.rng = np.random.default_rng()
2041+
2042+
def test_sqr_sqrt(self):
2043+
# sqrt(x) ** 2 -> x
2044+
x = pt.tensor("x", shape=(None, None))
2045+
out = sqr(sqrt(x))
2046+
out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])
2047+
2048+
assert equal_computations([out], [pt_abs(x)])
2049+
2050+
def test_sqrt_sqr(self):
2051+
x = pt.tensor("x", shape=(None, None))
2052+
out = sqrt(sqr(x))
2053+
out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])
2054+
2055+
expected = switch(
2056+
ge(x, np.zeros((1, 1), dtype="int8")),
2057+
x,
2058+
np.full((1, 1), np.nan, dtype=x.type.dtype),
2059+
)
2060+
2061+
assert equal_computations([out], [expected])
2062+
2063+
def test_sqr_sqrt_integer_upcast(self):
2064+
x = ivector("x")
2065+
out = sqr(sqrt(x))
2066+
out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])
2067+
2068+
expected = pt.cast(pt_abs(x), dtype=config.floatX)
2069+
assert equal_computations([out], [expected])
2070+
2071+
20342072
class TestLocalSwitchSink:
20352073
def setup_method(self):
20362074
# condition values

0 commit comments

Comments
 (0)