Description
Describe the issue:
While investigating PR #1452, I ran into the following issue.
The rewrite log1pmexp_to_log1mexp (part of the "stabilize" phase) is generally not applied. It tries to match (log1p, (neg, (exp, "x"))), but during the preceding "canonicalize" phase the neg is converted to -1.0*, so the pattern above no longer matches.
This happens both when compiling and when applying rewrite_graph.
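To make the failure mode concrete, here is a toy model of the tuple-pattern matching described above, written in plain Python rather than with PyTensor. The Var class, the match function and the nested-tuple graphs are hypothetical stand-ins, not PyTensor's actual pattern-rewriting machinery; they only mirror the shape of the (log1p, (neg, (exp, "x"))) pattern. The literal pattern matches the graph as built, but not the Mul(-1.0, ...) form produced by local_neg_to_mul:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    """Hypothetical pattern variable: matches any sub-expression."""
    name: str


def match(pattern, expr, bindings=None):
    """Toy structural matcher: return a dict of bindings, or None on mismatch.

    Graphs and patterns are nested tuples (op_name, *inputs); op names and
    constants must be equal, Var objects bind to whatever sits in that slot.
    """
    if bindings is None:
        bindings = {}
    if isinstance(pattern, Var):
        # Bind the variable, or check it agrees with an earlier binding
        if pattern.name in bindings and bindings[pattern.name] != expr:
            return None
        return {**bindings, pattern.name: expr}
    if isinstance(pattern, tuple):
        # Same op and same number of inputs, matched position by position
        if not isinstance(expr, tuple) or len(pattern) != len(expr):
            return None
        for p, e in zip(pattern, expr):
            bindings = match(p, e, bindings)
            if bindings is None:
                return None
        return bindings
    # Op names and constants must match literally
    return bindings if pattern == expr else None


X = Var("x")
PATTERN = ("log1p", ("neg", ("exp", X)))

# Graph as built by pt.log1p(-pt.exp(x))
original = ("log1p", ("neg", ("exp", "x")))
# The same graph after a Neg(e) -> Mul(-1.0, e) canonicalization
canonicalized = ("log1p", ("mul", -1.0, ("exp", "x")))

print(match(PATTERN, original))       # {'x': 'x'}
print(match(PATTERN, canonicalized))  # None
```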
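See the MRE below, which gives the output shown next. In the log, local_neg_to_mul ("canonicalize") replaces the Neg with Mul(-1.0, Exp.0) before "stabilize" runs, and local_mul_specialize ("specialize") later turns it back into a Neg, so the final graph looks unchanged even though log1pmexp_to_log1mexp never fired: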
Log1p [id A]
 └─ Neg [id B]
    └─ Exp [id C]
       └─ x [id D]
====================
rewriting: rewrite local_neg_to_mul replaces Neg.0 of Neg(Exp.0) with Mul.0 of Mul(-1.0, Exp.0)
rewriting: rewrite local_mul_specialize replaces Mul.0 of Mul(-1.0, Exp.0) with Neg.0 of Neg(Exp.0)
====================
Log1p [id A]
 └─ Neg [id B]
    └─ Exp [id C]
       └─ x [id D]
If one removes the "canonicalize" step from rewrite_graph, the rewrite is applied correctly:
Log1p [id A]
 └─ Neg [id B]
    └─ Exp [id C]
       └─ x [id D]
====================
rewriting: rewrite e(Log1p, e(Neg, e(Exp, ~x))) -> e(Scalar_log1mexp, ~x) replaces Log1p.0 of Log1p(Neg.0) with Scalar_log1mexp.0 of Scalar_log1mexp(x)
====================
Scalar_log1mexp [id A]
 └─ x [id B]
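The phase-ordering problem can be reproduced in the same toy setting. Below is a minimal sketch, assuming each phase runs its rewrites to a fixed point before the next phase starts. The PHASES table, rewrite_once and rewrite_graph_toy are hypothetical, and the three rewrite functions only imitate local_neg_to_mul, log1pmexp_to_log1mexp and local_mul_specialize as they appear in the logs above. It reproduces both outcomes: with "canonicalize" the Neg is restored by "specialize" and the stabilization is missed; without it, the graph becomes log1mexp(x).

```python
# Toy model of the three rewrite phases run by rewrite_graph in the MRE.
# Graphs are nested tuples (op, *inputs); leaves are plain strings or floats.
# All functions here are hypothetical stand-ins, not PyTensor code.


def neg_to_mul(node):
    # "canonicalize" (imitates local_neg_to_mul): Neg(e) -> Mul(-1.0, e)
    if node[0] == "neg":
        return ("mul", -1.0, node[1])
    return None


def log1pmexp_to_log1mexp(node):
    # "stabilize": Log1p(Neg(Exp(x))) -> Log1mexp(x); only the literal Neg form matches
    inner = node[1] if len(node) > 1 else None
    if (
        node[0] == "log1p"
        and isinstance(inner, tuple)
        and inner[0] == "neg"
        and isinstance(inner[1], tuple)
        and inner[1][0] == "exp"
    ):
        return ("log1mexp", inner[1][1])
    return None


def mul_to_neg(node):
    # "specialize" (imitates local_mul_specialize): Mul(-1.0, e) -> Neg(e)
    if node[0] == "mul" and len(node) == 3 and node[1] == -1.0:
        return ("neg", node[2])
    return None


PHASES = {
    "canonicalize": [neg_to_mul],
    "stabilize": [log1pmexp_to_log1mexp],
    "specialize": [mul_to_neg],
}


def rewrite_once(expr, rewrites):
    """One bottom-up pass over the graph; returns (new_expr, changed)."""
    if not isinstance(expr, tuple):
        return expr, False
    op, *args = expr
    results = [rewrite_once(a, rewrites) for a in args]
    changed = any(c for _, c in results)
    expr = (op, *(a for a, _ in results))
    for rewrite in rewrites:
        new = rewrite(expr)
        if new is not None:
            return new, True
    return expr, changed


def rewrite_graph_toy(expr, include):
    """Run the included phases in a fixed order, each until no rewrite fires."""
    for phase in ("canonicalize", "stabilize", "specialize"):
        if phase in include:
            changed = True
            while changed:
                expr, changed = rewrite_once(expr, PHASES[phase])
    return expr


graph = ("log1p", ("neg", ("exp", "x")))  # log1p(-exp(x))
print(rewrite_graph_toy(graph, ("canonicalize", "stabilize", "specialize")))
# ('log1p', ('neg', ('exp', 'x')))  ; Neg restored by "specialize", rewrite missed
print(rewrite_graph_toy(graph, ("stabilize", "specialize")))
# ('log1mexp', 'x')
```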
Reproducible code example:
import pytensor
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.scalar("x")
out = pt.log1p(-pt.exp(x))
pytensor.dprint(out)
print("=" * 20)

# Print every rewrite that fires, to see which ones touch the Neg node
with pytensor.config.change_flags(optimizer_verbose=True):
    # Compiling instead of calling rewrite_graph shows the same behavior:
    # fn = pytensor.function([x], out, mode="FAST_RUN")
    fn = rewrite_graph(out, include=("canonicalize", "stabilize", "specialize"))

print("=" * 20)
pytensor.dprint(fn)
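The missed rewrite matters numerically: for x close to 0, exp(x) rounds to exactly 1.0 in float64, so log1p(-exp(x)) loses the answer entirely, while a dedicated log1mexp can stay finite. The sketch below, in plain Python, contrasts the literal evaluation with the common two-branch formulation (cutoff at -log(2)). Both functions are hypothetical illustrations; this is not necessarily how PyTensor implements Scalar_log1mexp.

```python
import math


def naive_log1p_neg_exp(x):
    """log1p(-exp(x)) evaluated literally, as the graph does when the rewrite is skipped."""
    y = -math.exp(x)
    if y == -1.0:
        # exp(x) rounded to exactly 1.0; math.log1p(-1.0) raises ValueError,
        # so report -inf, which is what NumPy's np.log1p(-1.0) returns
        return float("-inf")
    return math.log1p(y)


def stable_log1mexp(x):
    """log(1 - exp(x)) for x < 0, using a two-branch formulation (hypothetical helper)."""
    if x >= 0:
        raise ValueError("log(1 - exp(x)) is only finite for x < 0")
    if x > -math.log(2.0):
        # Near 0, expm1 keeps the tiny difference 1 - exp(x) accurate
        return math.log(-math.expm1(x))
    # Far from 0, exp(x) is small and log1p(-exp(x)) is already accurate
    return math.log1p(-math.exp(x))


x = -1e-17
print(naive_log1p_neg_exp(x))  # -inf
print(stable_log1mexp(x))      # about -39.14, i.e. log(1e-17)
```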
Error message:
PyTensor version information:
PyTensor 2.31.3
Context for the issue:
No response