Skip to content

Commit 92a29d2

Browse files
author
Luca Citi
committed
Absorbed the rewrite log1pexp_to_softplus into the new rewrite for log1mexp
1 parent 8cd1b52 commit 92a29d2

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
log,
6565
log1mexp,
6666
log1p,
67+
log1pexp,
6768
makeKeepDims,
6869
maximum,
6970
mul,
@@ -2999,12 +3000,6 @@ def _is_1(expr):
29993000
tracks=[sigmoid],
30003001
get_nodes=get_clients_at_depth2,
30013002
)
3002-
log1pexp_to_softplus = PatternNodeRewriter(
3003-
(log1p, (exp, "x")),
3004-
(softplus, "x"),
3005-
values_eq_approx=values_eq_approx_remove_inf,
3006-
allow_multiple_clients=True,
3007-
)
30083003
log1p_neg_sigmoid = PatternNodeRewriter(
30093004
(log1p, (neg, (sigmoid, "x"))),
30103005
(neg, (softplus, "x")),
@@ -3016,7 +3011,6 @@ def _is_1(expr):
30163011

30173012
register_stabilize(logsigm_to_softplus, name="logsigm_to_softplus")
30183013
register_stabilize(log1msigm_to_softplus, name="log1msigm_to_softplus")
3019-
register_stabilize(log1pexp_to_softplus, name="log1pexp_to_softplus")
30203014
register_stabilize(log1p_neg_sigmoid, name="log1p_neg_sigmoid")
30213015
register_specialize(log1p_neg_sigmoid, name="log1p_neg_sigmoid")
30223016

@@ -3584,8 +3578,10 @@ def local_reciprocal_1_plus_exp(fgraph, node):
35843578

35853579
@register_stabilize
35863580
@node_rewriter([log1p])
3587-
def log1pmexp_to_log1mexp(fgraph, node):
3588-
"""``log1p(-exp(x)) -> log1mexp(x)``
3581+
def local_log1p_plusminus_exp(fgraph, node):
3582+
"""Transforms log1p of ±exp(x) into log1pexp (aka softplus) / log1mexp
3583+
``log1p(exp(x)) -> log1pexp(x)``
3584+
``log1p(-exp(x)) -> log1mexp(x)``
35893585
where "-" can be "neg" or any other expression detected by "is_neg"
35903586
"""
35913587
(log1p_arg,) = node.inputs
@@ -3595,7 +3591,7 @@ def log1pmexp_to_log1mexp(fgraph, node):
35953591
if exp_neg:
35963592
return [log1mexp(exp_arg)]
35973593
else:
3598-
return # We could return [log1pexp(exp_arg)] here but that would conflict with log1pexp_to_softplus
3594+
return [log1pexp(exp_arg)] # aka softplus
35993595

36003596

36013597
@register_stabilize

0 commit comments

Comments
 (0)