24
24
"""
25
25
26
26
import logging
27
- from typing import TYPE_CHECKING , Optional , Union
27
+ from typing import Union
28
28
29
29
import numpy as np
30
30
66
66
)
67
67
from pytensor .tensor .elemwise import DimShuffle , Elemwise
68
68
from pytensor .tensor .exceptions import NotScalarConstantError
69
- from pytensor .tensor .extra_ops import broadcast_shape , broadcast_to
69
+ from pytensor .tensor .extra_ops import broadcast_arrays
70
70
from pytensor .tensor .math import Sum , add
71
71
from pytensor .tensor .math import all as at_all
72
72
from pytensor .tensor .math import eq
73
- from pytensor .tensor .shape import Shape_i
73
+ from pytensor .tensor .shape import Shape_i , shape_padleft
74
74
from pytensor .tensor .sort import TopKOp
75
75
from pytensor .tensor .type import DenseTensorType , TensorType
76
76
from pytensor .tensor .var import TensorConstant , TensorVariable
77
77
from pytensor .utils import NoDuplicateOptWarningFilter
78
78
79
79
80
- if TYPE_CHECKING :
81
- from pytensor .tensor .rewriting .shape import ShapeFeature
82
-
83
-
84
80
_logger = logging .getLogger ("pytensor.tensor.rewriting.basic" )
85
81
_logger .addFilter (NoDuplicateOptWarningFilter ())
86
82
@@ -262,31 +258,16 @@ def local_scalar_tensor_scalar(fgraph, node):
262
258
def local_elemwise_alloc (fgraph , node ):
263
259
r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s.
264
260
265
- `Alloc`\s are effectively a type of `Elemwise` operation
266
- (e.g. ``Elemwise{second}(y, x)`` is the same as ``Alloc(x, *y.shape)``), so
267
- this rewrite uses that fact to reduce `Elemwise`\s on `Alloc`\s to
268
- `Elemwise`\s of the `Alloc`\s first/value input (i.e. the value it
269
- broadcasts).
270
-
271
- In other words, this rewrite causes `Elemwise` `Op`\s to "absorb" redundant
272
- `Alloc`\s.
273
-
274
261
The rewrite essentially performs the following replacement:
275
- ``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``,
276
- when ``y.shape`` for some input ``y`` (or the combined shapes of the
277
- non-`Alloc`\s) is sufficient to maintain the same/correct output shape.
262
+ ``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``
278
263
279
- In it's current form, it also explicitly accounts for `DimShuffle`\s of
264
+ In its current form, it also explicitly accounts for `DimShuffle`\s of
280
265
`Alloc`\s. This is largely due to `local_alloc_sink_dimshuffle`, which
281
266
introduces them as a canonicalization of `Alloc`'s with leading
282
267
broadcastable dimensions.
283
268
"""
284
- # Rewrite is only applicable when there are at least two inputs
285
269
if len (node .inputs ) == 1 :
286
- return False
287
-
288
- if len (node .outputs ) > 1 :
289
- return False
270
+ return None
290
271
291
272
def dimshuffled_alloc (i ):
292
273
return (
@@ -306,76 +287,40 @@ def dimshuffled_alloc(i):
306
287
if len (alloc_idxs ) == 0 :
307
288
return False
308
289
309
- # Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a
310
- # baseline for the dimensions.
311
- ref_var_idx = None
312
- for idx , i in enumerate (node .inputs ):
313
- if i .type .broadcastable == node .outputs [0 ].type .broadcastable :
314
- # Prefer an input that is not an `Alloc` nor a `DimShuffle` of an
315
- # `Alloc`, so that all `Alloc`s can be rewritten.
316
- if idx not in alloc_idxs :
317
- ref_var_idx = idx
318
- break
319
-
320
- # If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one
321
- if ref_var_idx is None :
322
- for idx , i in enumerate (node .inputs ):
323
- # XXX: This broadcastable comparison doesn't work
324
- if (
325
- i .type .broadcastable == node .outputs [0 ].type .broadcastable
326
- ) and idx in alloc_idxs :
327
- ref_var_idx = idx
328
- break
329
-
330
- if not hasattr (fgraph , "shape_feature" ):
331
- return False
332
-
333
- input_shapes = [
334
- tuple (fgraph .shape_feature .get_shape (i , j ) for j in range (i .type .ndim ))
335
- for i in node .inputs
336
- ]
337
- bcasted_shape = broadcast_shape (
338
- * input_shapes ,
339
- arrays_are_shapes = True ,
340
- )
341
-
342
290
new_inputs = list (node .inputs )
343
291
for idx in alloc_idxs :
344
292
i = node .inputs [idx ]
345
293
346
- # Remove `Alloc`
294
+ # Remove simple `Alloc`
347
295
if isinstance (i .owner .op , Alloc ):
348
- new_alloc = broadcast_to ( i .owner .inputs [0 ], bcasted_shape )
296
+ new_inp = i .owner .inputs [0 ]
349
297
350
- # TODO FIXME: This shouldn't be handled here.
351
- # `DimShuffle`s should be lifted through `Alloc`s
352
- # by other, more general rewrites.
353
- # Remove `Alloc` in `DimShuffle`
298
+ # Remove `Dimshuffle(Alloc)`
354
299
elif isinstance (i .owner .op , DimShuffle ):
355
300
old_alloc = i .owner .inputs [0 ]
356
- new_alloc = old_alloc .owner .inputs [0 ]
301
+ old_alloc_inp = old_alloc .owner .inputs [0 ]
302
+ missing_ndims = old_alloc .type .ndim - old_alloc_inp .type .ndim
303
+ if missing_ndims > 0 :
304
+ # The `Alloc` added new dimensions to the left.
305
+ # We replace those cases with a `DimShuffle` here.
306
+ # Nested dimshuffles will be merged later by other rewrites.
307
+ old_alloc_inp = shape_padleft (old_alloc_inp , missing_ndims )
357
308
# We need to keep the old `DimShuffle`. It could swap axes or
358
309
# add dimensions anywhere.
359
- if new_alloc .ndim != old_alloc .ndim :
360
- # The `Alloc` can add dimensions to the value.
361
- # We replace those cases with a `DimShuffle` here.
362
- nb_dim_to_add = old_alloc .ndim - new_alloc .ndim
363
- new_alloc = new_alloc .dimshuffle (
364
- ["x" ] * nb_dim_to_add + list (range (new_alloc .ndim ))
365
- )
366
- new_alloc = broadcast_to (i .owner .op (new_alloc ), bcasted_shape )
310
+ new_inp = i .owner .op (old_alloc_inp )
367
311
368
- copy_stack_trace (i , new_alloc )
369
- new_inputs [idx ] = new_alloc
312
+ copy_stack_trace (i , new_inp )
313
+ new_inputs [idx ] = new_inp
370
314
371
- # If this assert is triggered, it means we are recreating an equivalent graph
372
- # which would result in cyclical merge rewrites.
373
- if all (new is old for new , old in zip (new_inputs , node .inputs )):
374
- return
315
+ new_outs = node .op (* new_inputs , return_list = True )
375
316
376
- ret = node .op (* new_inputs , return_list = True )
377
- copy_stack_trace (node .outputs , ret )
378
- return ret
317
+ if new_outs [0 ].type .broadcastable != node .outputs [0 ].type .broadcastable :
318
+ new_outs = [
319
+ alloc_like (new_out , node .outputs [0 ], fgraph ) for new_out in new_outs
320
+ ]
321
+
322
+ copy_stack_trace (node .outputs , new_outs )
323
+ return new_outs
379
324
380
325
381
326
@register_canonicalize ("shape_unsafe" )
@@ -407,6 +352,7 @@ def local_fill_sink(fgraph, node):
407
352
408
353
# The newly created node c doesn't has 'clients',
409
354
# so this iteration is took place with node.outputs[0]
355
+ # TODO: This should just be a WalkingGraphRewrite!
410
356
replacements = {node .outputs [0 ]: c }
411
357
for client , cl_idx in fgraph .clients [node .outputs [0 ]]:
412
358
if (
@@ -439,23 +385,15 @@ def local_fill_to_alloc(fgraph, node):
439
385
with their dependencies on those tensors' shapes, and sometimes those
440
386
shapes can be computed without needing to compute the tensors themselves.
441
387
442
- XXX: This rewrite can produce inconsistent results, so do *not* consider
443
- making it a canonicalization until those inconsistencies are
444
- resolved/justified.
388
+ Like `local_fill_sink` this rewrites assumes non-broadcastable shapes are equivalent,
389
+ which could mask shape errors.
445
390
"""
446
391
shape_ref , values_ref = node .inputs
447
392
out_type = node .outputs [0 ].type
448
393
449
394
if values_ref .type .broadcastable == out_type .broadcastable :
450
395
# The assumption here is that `values_ref` already has the same shape
451
396
# as `shape_ref`, so a `fill`/`Alloc` is unnecessary.
452
-
453
- # XXX FIXME TODO: The only way this can be determined is if one
454
- # absolutely knows that the shapes of `shape_ref` and `values_ref` are
455
- # equal.
456
- # This is an old rewrite, and it's only a
457
- # "specialization/stabilization", so we're going to leave it be for
458
- # now.
459
397
return [values_ref ]
460
398
461
399
if shape_ref .type .broadcastable == out_type .broadcastable :
@@ -466,6 +404,9 @@ def local_fill_to_alloc(fgraph, node):
466
404
copy_stack_trace (node .outputs [0 ], o )
467
405
return [o ]
468
406
407
+ # The case that is not covered is when `shape_ref` is broadcasted by `values_ref`
408
+ # TODO: Return broadcast_to(values_ref, broadcast_shapes(values_ref.shape, shape_ref.shape))
409
+
469
410
return
470
411
471
412
@@ -1015,36 +956,30 @@ def local_sum_make_vector(fgraph, node):
1015
956
return [element_sum ]
1016
957
1017
958
1018
- @register_useless ("local_remove_switch_const_cond " )
1019
- @register_canonicalize ("fast_compile" , "local_remove_switch_const_cond " )
1020
- @register_specialize
1021
- @node_rewriter ([Elemwise ])
959
+ @register_useless ("shape_unsafe " )
960
+ @register_canonicalize ("fast_compile" , "shape_unsafe " )
961
+ @register_specialize ( "shape_unsafe" )
962
+ @node_rewriter ([switch ])
1022
963
def local_useless_switch (fgraph , node ):
1023
964
"""
1024
965
This rewrite makes the following changes in a graph:
1025
966
1026
- at. switch(cond, left, right) ->
1027
- if cond is constant and cond == 0: right
1028
- if cond is constant and cond != 0: left
1029
- if left is right -> left
967
+ switch(cond, left, right) ->
968
+ if cond is constant and cond == 0: right
969
+ if cond is constant and cond != 0: left
970
+ if left is right -> left
1030
971
1031
972
and
1032
973
1033
- at. switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
974
+ switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
1034
975
1035
976
"""
1036
- if not isinstance (node .op .scalar_op , aes .Switch ):
1037
- return False
1038
-
1039
- shape_feature : Optional ["ShapeFeature" ] = getattr (fgraph , "shape_feature" , None )
1040
-
1041
- if shape_feature is None :
1042
- return False
1043
977
1044
978
left = node .inputs [1 ]
1045
979
right = node .inputs [2 ]
1046
980
cond_var = node .inputs [0 ]
1047
981
cond = extract_constant (cond_var , only_process_constants = True )
982
+ out_bcast = node .outputs [0 ].type .broadcastable
1048
983
1049
984
if (isinstance (cond , np .ndarray ) and cond .ndim == 0 ) or isinstance (
1050
985
cond , (np .number , np .bool_ )
@@ -1059,14 +994,8 @@ def local_useless_switch(fgraph, node):
1059
994
else :
1060
995
out = correct_out
1061
996
1062
- input_shapes = [
1063
- tuple (shape_feature .get_shape (inp , i ) for i in range (inp .type .ndim ))
1064
- for inp in node .inputs
1065
- ]
1066
-
1067
- out_shape = broadcast_shape (* input_shapes , arrays_are_shapes = True )
1068
-
1069
- out = alloc (out , * out_shape )
997
+ if out .type .broadcastable != out_bcast :
998
+ out = broadcast_arrays (out , * node .inputs )[0 ]
1070
999
1071
1000
# Copy over stacktrace from selected output to new output
1072
1001
copy_stack_trace (node .outputs + correct_out , out )
@@ -1076,10 +1005,10 @@ def local_useless_switch(fgraph, node):
1076
1005
if left == right :
1077
1006
# Note: No need to copy over stacktrace, because the input node
1078
1007
# already has its own stacktrace
1079
- if cond . type . is_super ( left .type ) :
1008
+ if left .type . broadcastable == out_bcast :
1080
1009
return [left ]
1081
1010
1082
- ret = fill ( cond , left )
1011
+ ret = broadcast_arrays ( left , cond )[ 0 ]
1083
1012
1084
1013
# Copy over stacktrace from switch output and correct branch
1085
1014
copy_stack_trace (node .outputs + left , ret )
0 commit comments