64
64
log ,
65
65
log1mexp ,
66
66
log1p ,
67
+ log1pexp ,
67
68
makeKeepDims ,
68
69
maximum ,
69
70
mul ,
@@ -2999,12 +3000,6 @@ def _is_1(expr):
2999
3000
tracks = [sigmoid ],
3000
3001
get_nodes = get_clients_at_depth2 ,
3001
3002
)
3002
- log1pexp_to_softplus = PatternNodeRewriter (
3003
- (log1p , (exp , "x" )),
3004
- (softplus , "x" ),
3005
- values_eq_approx = values_eq_approx_remove_inf ,
3006
- allow_multiple_clients = True ,
3007
- )
3008
3003
log1p_neg_sigmoid = PatternNodeRewriter (
3009
3004
(log1p , (neg , (sigmoid , "x" ))),
3010
3005
(neg , (softplus , "x" )),
@@ -3016,7 +3011,6 @@ def _is_1(expr):
3016
3011
3017
3012
register_stabilize (logsigm_to_softplus , name = "logsigm_to_softplus" )
3018
3013
register_stabilize (log1msigm_to_softplus , name = "log1msigm_to_softplus" )
3019
- register_stabilize (log1pexp_to_softplus , name = "log1pexp_to_softplus" )
3020
3014
register_stabilize (log1p_neg_sigmoid , name = "log1p_neg_sigmoid" )
3021
3015
register_specialize (log1p_neg_sigmoid , name = "log1p_neg_sigmoid" )
3022
3016
@@ -3584,8 +3578,10 @@ def local_reciprocal_1_plus_exp(fgraph, node):
3584
3578
3585
3579
@register_stabilize
3586
3580
@node_rewriter ([log1p ])
3587
- def log1pmexp_to_log1mexp (fgraph , node ):
3588
- """``log1p(-exp(x)) -> log1mexp(x)``
3581
+ def local_log1p_plusminus_exp (fgraph , node ):
3582
+ """Transforms log1p of ±exp(x) into log1pexp (aka softplus) / log1mexp
3583
+ ``log1p(exp(x)) -> log1pexp(x)``
3584
+ ``log1p(-exp(x)) -> log1mexp(x)``
3589
3585
where "-" can be "neg" or any other expression detected by "is_neg"
3590
3586
"""
3591
3587
(log1p_arg ,) = node .inputs
@@ -3595,7 +3591,7 @@ def log1pmexp_to_log1mexp(fgraph, node):
3595
3591
if exp_neg :
3596
3592
return [log1mexp (exp_arg )]
3597
3593
else :
3598
- return # We could return [log1pexp(exp_arg)] here but that would conflict with log1pexp_to_softplus
3594
+ return [log1pexp (exp_arg )] # aka softplus
3599
3595
3600
3596
3601
3597
@register_stabilize
0 commit comments