Skip to content

Commit a3f2543

Browse files
Rewrite sqr(sqrt(x)) -> |x| and sqrt(sqr(x)) -> x
1 parent a7e44ba commit a3f2543

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,39 @@ def local_exp_log(fgraph, node):
400400
return [exp(x)]
401401

402402

403+
@register_canonicalize
404+
@register_specialize
405+
@node_rewriter([sqrt, sqr])
406+
def local_sqrt_sqr(fgraph, node):
407+
x = node.inputs[0]
408+
409+
if not (x.owner and isinstance(x.owner.op, Elemwise)):
410+
return
411+
412+
prev_op = x.owner.op.scalar_op
413+
node_op = node.op.scalar_op
414+
415+
# Case for sqrt(sqr(x)) -> |x|
416+
if isinstance(prev_op, ps.Sqrt) and isinstance(node_op, ps.Sqr):
417+
new_out = pt_abs(x.owner.inputs[0])
418+
old_out = node.outputs[0]
419+
420+
# Handle potential integer to float cast by sqr
421+
if new_out.dtype != old_out.dtype:
422+
new_out = cast(new_out, old_out.dtype)
423+
return [new_out]
424+
425+
# Case for sqr(sqrt(x)) -> x
426+
if isinstance(prev_op, ps.Sqr) and isinstance(node_op, ps.Sqrt):
427+
new_out = x.owner.inputs[0]
428+
old_out = node.outputs[0]
429+
430+
# Handle potential integer to float cast by sqrt
431+
if x.dtype != old_out.dtype:
432+
new_out = cast(new_out, old_out.dtype)
433+
return [new_out]
434+
435+
403436
@register_specialize
404437
@node_rewriter([exp, expm1])
405438
def local_exp_log_nan_switch(fgraph, node):

tests/tensor/rewriting/test_math.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2031,6 +2031,56 @@ def test_exp_log_nested(self, nested_expression, expected_switches):
20312031
assert len(ops_graph) == expected_switches
20322032

20332033

2034+
class TestSqrSqrt:
2035+
def setup_method(self):
2036+
mode = get_default_mode()
2037+
self.mode = mode.including(
2038+
"local_sqrt_sqr",
2039+
).excluding("fusion")
2040+
self.rng = np.random.default_rng()
2041+
2042+
def test_sqr_sqrt(self):
2043+
# sqrt(x) ** 2 -> x
2044+
data = self.rng.normal(size=(4, 3)).astype(config.floatX)
2045+
x = pt.tensor("x", shape=(None, None))
2046+
f = function([x], sqr(sqrt(x)), mode=self.mode)
2047+
graph = f.maker.fgraph.toposort()
2048+
assert not any(
2049+
isinstance(node.op, Elemwise)
2050+
and isinstance(node.op.scalar_op, ps.Sqr | ps.Sqrt)
2051+
for node in graph
2052+
)
2053+
assert any(
2054+
isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Abs)
2055+
for node in graph
2056+
)
2057+
2058+
np.testing.assert_array_equal(f(data), np.abs(data))
2059+
2060+
def test_sqrt_sqr(self):
2061+
data = self.rng.normal(size=(4, 3)).astype(config.floatX)
2062+
x = pt.tensor("x", shape=(None, None))
2063+
f = function([x], sqrt(sqr(x)), mode=self.mode)
2064+
graph = f.maker.fgraph.toposort()
2065+
assert not any(
2066+
isinstance(node.op, Elemwise)
2067+
and isinstance(node.op.scalar_op, ps.Sqr | ps.Sqrt)
2068+
for node in graph
2069+
)
2070+
np.testing.assert_array_equal(f(data), data)
2071+
2072+
def test_sqr_sqrt_integer_upcast(self):
2073+
x = ivector("x")
2074+
f = function([x], sqrt(sqr(x)), mode=self.mode)
2075+
nodes = f.maker.fgraph.toposort()
2076+
assert not any(
2077+
isinstance(node.op, Elemwise)
2078+
and isinstance(node.op.scalar_op, ps.Sqr | ps.Sqrt)
2079+
for node in nodes
2080+
)
2081+
assert f([1, 2, 3]).dtype == config.floatX
2082+
2083+
20342084
class TestLocalSwitchSink:
20352085
def setup_method(self):
20362086
# condition values

0 commit comments

Comments
 (0)