     cast,
     constant,
     extract_constant,
-    fill,
     get_underlying_scalar_constant_value,
     ones_like,
     switch,
@@ -2041,8 +2040,6 @@ def local_zero_div(fgraph, node):
 @register_specialize
 @node_rewriter([at_pow])
 def local_pow_specialize(fgraph, node):
-    # here, we are past the point of canonicalization, so we don't want
-    # to put in un-necessary fills.
     if node.op == at_pow:
         # the idea here is that we have pow(x, y)
         odtype = node.outputs[0].dtype
@@ -2057,7 +2054,7 @@ def local_pow_specialize(fgraph, node):
             if np.all(y == 1):
                 rval = [xsym]
             if np.all(y == 0):
-                rval = [fill(xsym, np.asarray(1, dtype=odtype))]
+                rval = [alloc_like(1, xsym, fgraph)]
             if np.all(y == 0.5):
                 rval = [sqrt(xsym)]
             if np.all(y == -0.5):
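
For reference, a minimal usage sketch of the pow specializations touched above. It is not part of this commit, and it assumes only that PyTensor is installed and that its default compilation mode runs the specialize rewrites:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.dvector("x")

    # pow(x, 1) should collapse to x, pow(x, 0.5) to sqrt(x), and pow(x, 0)
    # to a constant 1 broadcast to x's shape (the alloc_like branch above).
    for exponent in (0.0, 0.5, 1.0):
        f = pytensor.function([x], x**exponent)
        pytensor.printing.debugprint(f)  # inspect the rewritten graph
        assert np.allclose(f(np.array([4.0, 9.0])), np.array([4.0, 9.0]) ** exponent)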
@@ -2159,9 +2156,7 @@ def local_mul_specialize(fgraph, node):
     mul(-1, x, y) -/-> neg(mul(x, y))
 
     """
-    # here, we are past the point of canonicalization, so we don't
-    # want to put in un-necessary fills.
-    #
+
     # at this point [post canonicalize], mul() may have many inputs.
     if node.op == mul:
         # the idea here is that we have pow(x, y)
@@ -2222,16 +2217,7 @@ def local_mul_specialize(fgraph, node):
 
 @register_specialize
 @node_rewriter([add])
-def local_add_specialize(fgraph, node):
-    """Remove zeros from ``add``s.
-
-    TODO: This should be a canonicalization, no?
-    """
-    # here, we are past the point of canonicalization, so we don't want
-    # to put in un-necessary fills.
-    if node.op != add:
-        return False
-
+def local_add_remove_zeros(fgraph, node):
     new_inputs = []
     for inp in node.inputs:
         try:
@@ -2254,12 +2240,12 @@ def local_add_specialize(fgraph, node):
             # Reuse call to constant for cache()
             cst = constant(np.zeros((1,) * ndim, dtype=dtype))
             assert cst.type.broadcastable == (True,) * ndim
-            return [broadcast_arrays(cst, *node.inputs)[0]]
+            return [alloc_like(cst, node_output, fgraph)]
 
         if len(new_inputs) == 1:
-            ret = [broadcast_arrays(new_inputs[0], *node.inputs)[0]]
+            ret = [alloc_like(new_inputs[0], node_output, fgraph)]
         else:
-            ret = [broadcast_arrays(add(*new_inputs), *node.inputs)[0]]
+            ret = [alloc_like(add(*new_inputs), node_output, fgraph)]
 
         # The dtype should not be changed. It can happen if the input
         # that was forcing upcasting was equal to 0.
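
A short check of the add-zero removal rewritten above, as a hedged sketch (it assumes PyTensor is installed and that the default mode applies these specializations):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.dvector("x")
    y = pt.dvector("y")

    # Zero terms are dropped from the sum; when everything cancels, the
    # rewrite returns a zero constant broadcast to the output shape.
    f = pytensor.function([x, y], x + 0.0 + y)
    pytensor.printing.debugprint(f)

    a = np.array([1.0, 2.0])
    b = np.array([3.0, 4.0])
    assert np.allclose(f(a, b), a + b)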
@@ -2377,7 +2363,7 @@ def local_log1p(fgraph, node):
                     ninp = nonconsts[0]
                 if ninp.dtype != log_arg.type.dtype:
                     ninp = ninp.astype(node.outputs[0].dtype)
-                return [broadcast_arrays(log1p(ninp), *scalar_inputs)[0]]
+                return [alloc_like(log1p(ninp), node.outputs[0], fgraph)]
 
     elif log_arg.owner and log_arg.owner.op == sub:
         one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True)
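
A usage sketch of the log(1 + x) -> log1p(x) rewrite in this hunk, again an illustration only (it assumes PyTensor is installed with its default stabilize/specialize rewrites enabled):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.dvector("x")
    f = pytensor.function([x], pt.log(1 + x))

    # log1p is the numerically stable form for tiny x, which is the point of
    # the rewrite; the compiled graph is expected to use log1p directly.
    pytensor.printing.debugprint(f)
    print(f(np.array([1e-12, 1e-15])), np.log1p([1e-12, 1e-15]))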
@@ -3573,10 +3559,11 @@ def local_reciprocal_1_plus_exp(fgraph, node):
                 if nonconsts[0].owner and nonconsts[0].owner.op == exp:
                     if scalars_ and np.allclose(np.sum(scalars_), 1):
                         out = [
-                            broadcast_arrays(
+                            alloc_like(
                                 sigmoid(neg(nonconsts[0].owner.inputs[0])),
-                                *scalar_inputs,
-                            )[0]
+                                node.outputs[0],
+                                fgraph,
+                            )
                         ]
                         # keep combined stack traces of
                         # exp(x): nonconsts[0],
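
A hedged end-to-end sketch of the reciprocal(1 + exp(x)) -> sigmoid(-x) stabilization shown above (illustration only; it assumes PyTensor is installed and its default rewrites are active):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.dvector("x")
    f = pytensor.function([x], 1.0 / (1.0 + pt.exp(x)))

    # A sigmoid node is expected in the compiled graph in place of the
    # explicit exp/add/division chain.
    pytensor.printing.debugprint(f)

    vals = np.array([-3.0, 0.0, 3.0])
    assert np.allclose(f(vals), 1.0 / (1.0 + np.exp(vals)))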