
Commit d3efba4

Simplify rewrites by assuming Elemwise / Alloc shapes are correct
1 parent 7cb078d commit d3efba4

File tree: 2 files changed (+86, -150 lines)

pytensor/tensor/rewriting/basic.py

Lines changed: 47 additions & 118 deletions
@@ -24,7 +24,7 @@
 """

 import logging
-from typing import TYPE_CHECKING, Optional, Union
+from typing import Union

 import numpy as np

@@ -66,21 +66,17 @@
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.extra_ops import broadcast_shape, broadcast_to
+from pytensor.tensor.extra_ops import broadcast_arrays
 from pytensor.tensor.math import Sum, add
 from pytensor.tensor.math import all as at_all
 from pytensor.tensor.math import eq
-from pytensor.tensor.shape import Shape_i
+from pytensor.tensor.shape import Shape_i, shape_padleft
 from pytensor.tensor.sort import TopKOp
 from pytensor.tensor.type import DenseTensorType, TensorType
 from pytensor.tensor.var import TensorConstant, TensorVariable
 from pytensor.utils import NoDuplicateOptWarningFilter


-if TYPE_CHECKING:
-    from pytensor.tensor.rewriting.shape import ShapeFeature
-
-
 _logger = logging.getLogger("pytensor.tensor.rewriting.basic")
 _logger.addFilter(NoDuplicateOptWarningFilter())

@@ -262,31 +258,16 @@ def local_scalar_tensor_scalar(fgraph, node):
 def local_elemwise_alloc(fgraph, node):
     r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s.

-    `Alloc`\s are effectively a type of `Elemwise` operation
-    (e.g. ``Elemwise{second}(y, x)`` is the same as ``Alloc(x, *y.shape)``), so
-    this rewrite uses that fact to reduce `Elemwise`\s on `Alloc`\s to
-    `Elemwise`\s of the `Alloc`\s first/value input (i.e. the value it
-    broadcasts).
-
-    In other words, this rewrite causes `Elemwise` `Op`\s to "absorb" redundant
-    `Alloc`\s.
-
     The rewrite essentially performs the following replacement:
-    ``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``,
-    when ``y.shape`` for some input ``y`` (or the combined shapes of the
-    non-`Alloc`\s) is sufficient to maintain the same/correct output shape.
+    ``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``

-    In it's current form, it also explicitly accounts for `DimShuffle`\s of
+    In its current form, it also explicitly accounts for `DimShuffle`\s of
     `Alloc`\s. This is largely due to `local_alloc_sink_dimshuffle`, which
     introduces them as a canonicalization of `Alloc`'s with leading
     broadcastable dimensions.
     """
-    # Rewrite is only applicable when there are at least two inputs
     if len(node.inputs) == 1:
-        return False
-
-    if len(node.outputs) > 1:
-        return False
+        return None

     def dimshuffled_alloc(i):
         return (

@@ -306,76 +287,40 @@ def dimshuffled_alloc(i):
     if len(alloc_idxs) == 0:
         return False

-    # Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a
-    # baseline for the dimensions.
-    ref_var_idx = None
-    for idx, i in enumerate(node.inputs):
-        if i.type.broadcastable == node.outputs[0].type.broadcastable:
-            # Prefer an input that is not an `Alloc` nor a `DimShuffle` of an
-            # `Alloc`, so that all `Alloc`s can be rewritten.
-            if idx not in alloc_idxs:
-                ref_var_idx = idx
-                break
-
-    # If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one
-    if ref_var_idx is None:
-        for idx, i in enumerate(node.inputs):
-            # XXX: This broadcastable comparison doesn't work
-            if (
-                i.type.broadcastable == node.outputs[0].type.broadcastable
-            ) and idx in alloc_idxs:
-                ref_var_idx = idx
-                break
-
-    if not hasattr(fgraph, "shape_feature"):
-        return False
-
-    input_shapes = [
-        tuple(fgraph.shape_feature.get_shape(i, j) for j in range(i.type.ndim))
-        for i in node.inputs
-    ]
-    bcasted_shape = broadcast_shape(
-        *input_shapes,
-        arrays_are_shapes=True,
-    )
-
     new_inputs = list(node.inputs)
     for idx in alloc_idxs:
         i = node.inputs[idx]

-        # Remove `Alloc`
+        # Remove simple `Alloc`
         if isinstance(i.owner.op, Alloc):
-            new_alloc = broadcast_to(i.owner.inputs[0], bcasted_shape)
+            new_inp = i.owner.inputs[0]

-        # TODO FIXME: This shouldn't be handled here.
-        # `DimShuffle`s should be lifted through `Alloc`s
-        # by other, more general rewrites.
-        # Remove `Alloc` in `DimShuffle`
+        # Remove `Dimshuffle(Alloc)`
         elif isinstance(i.owner.op, DimShuffle):
             old_alloc = i.owner.inputs[0]
-            new_alloc = old_alloc.owner.inputs[0]
+            old_alloc_inp = old_alloc.owner.inputs[0]
+            missing_ndims = old_alloc.type.ndim - old_alloc_inp.type.ndim
+            if missing_ndims > 0:
+                # The `Alloc` added new dimensions to the left.
+                # We replace those cases with a `DimShuffle` here.
+                # Nested dimshuffles will be merged later by other rewrites.
+                old_alloc_inp = shape_padleft(old_alloc_inp, missing_ndims)
             # We need to keep the old `DimShuffle`. It could swap axes or
             # add dimensions anywhere.
-            if new_alloc.ndim != old_alloc.ndim:
-                # The `Alloc` can add dimensions to the value.
-                # We replace those cases with a `DimShuffle` here.
-                nb_dim_to_add = old_alloc.ndim - new_alloc.ndim
-                new_alloc = new_alloc.dimshuffle(
-                    ["x"] * nb_dim_to_add + list(range(new_alloc.ndim))
-                )
-            new_alloc = broadcast_to(i.owner.op(new_alloc), bcasted_shape)
+            new_inp = i.owner.op(old_alloc_inp)

-        copy_stack_trace(i, new_alloc)
-        new_inputs[idx] = new_alloc
+        copy_stack_trace(i, new_inp)
+        new_inputs[idx] = new_inp

-    # If this assert is triggered, it means we are recreating an equivalent graph
-    # which would result in cyclical merge rewrites.
-    if all(new is old for new, old in zip(new_inputs, node.inputs)):
-        return
+    new_outs = node.op(*new_inputs, return_list=True)

-    ret = node.op(*new_inputs, return_list=True)
-    copy_stack_trace(node.outputs, ret)
-    return ret
+    if new_outs[0].type.broadcastable != node.outputs[0].type.broadcastable:
+        new_outs = [
+            alloc_like(new_out, node.outputs[0], fgraph) for new_out in new_outs
+        ]
+
+    copy_stack_trace(node.outputs, new_outs)
+    return new_outs


 @register_canonicalize("shape_unsafe")
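
For illustration only (not part of the commit): a minimal sketch of the pattern `local_elemwise_alloc` targets, written against the public `pytensor` API. The variable names are made up, and the exact printed graph can vary by version.

import pytensor
import pytensor.tensor as pt

x = pt.scalar("x")
y = pt.vector("y")

# Elemwise{add}(Alloc(x, y.shape[0]), y): the Alloc only broadcasts x up to y's
# shape, so the rewrite can drop it and let Elemwise broadcasting do the work.
z = pt.alloc(x, y.shape[0]) + y

f = pytensor.function([x, y], z)
pytensor.dprint(f)  # the compiled graph should no longer contain an Alloc
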
@@ -407,6 +352,7 @@ def local_fill_sink(fgraph, node):

     # The newly created node c doesn't has 'clients',
     # so this iteration is took place with node.outputs[0]
+    # TODO: This should just be a WalkingGraphRewrite!
     replacements = {node.outputs[0]: c}
     for client, cl_idx in fgraph.clients[node.outputs[0]]:
         if (

@@ -439,23 +385,15 @@ def local_fill_to_alloc(fgraph, node):
     with their dependencies on those tensors' shapes, and sometimes those
     shapes can be computed without needing to compute the tensors themselves.

-    XXX: This rewrite can produce inconsistent results, so do *not* consider
-    making it a canonicalization until those inconsistencies are
-    resolved/justified.
+    Like `local_fill_sink` this rewrites assumes non-broadcastable shapes are equivalent,
+    which could mask shape errors.
     """
     shape_ref, values_ref = node.inputs
     out_type = node.outputs[0].type

     if values_ref.type.broadcastable == out_type.broadcastable:
         # The assumption here is that `values_ref` already has the same shape
         # as `shape_ref`, so a `fill`/`Alloc` is unnecessary.
-
-        # XXX FIXME TODO: The only way this can be determined is if one
-        # absolutely knows that the shapes of `shape_ref` and `values_ref` are
-        # equal.
-        # This is an old rewrite, and it's only a
-        # "specialization/stabilization", so we're going to leave it be for
-        # now.
         return [values_ref]

     if shape_ref.type.broadcastable == out_type.broadcastable:

@@ -466,6 +404,9 @@ def local_fill_to_alloc(fgraph, node):
         copy_stack_trace(node.outputs[0], o)
         return [o]

+    # The case that is not covered is when `shape_ref` is broadcasted by `values_ref`
+    # TODO: Return broadcast_to(values_ref, broadcast_shapes(values_ref.shape, shape_ref.shape))
+
     return

@@ -1015,36 +956,30 @@ def local_sum_make_vector(fgraph, node):
     return [element_sum]


-@register_useless("local_remove_switch_const_cond")
-@register_canonicalize("fast_compile", "local_remove_switch_const_cond")
-@register_specialize
-@node_rewriter([Elemwise])
+@register_useless("shape_unsafe")
+@register_canonicalize("fast_compile", "shape_unsafe")
+@register_specialize("shape_unsafe")
+@node_rewriter([switch])
 def local_useless_switch(fgraph, node):
     """
     This rewrite makes the following changes in a graph:

-        at.switch(cond, left, right) ->
-            if cond is constant and cond == 0: right
-            if cond is constant and cond != 0: left
-            if left is right -> left
+        switch(cond, left, right) ->
+            if cond is constant and cond == 0: right
+            if cond is constant and cond != 0: left
+            if left is right -> left

     and

-        at.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
+        switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)

     """
-    if not isinstance(node.op.scalar_op, aes.Switch):
-        return False
-
-    shape_feature: Optional["ShapeFeature"] = getattr(fgraph, "shape_feature", None)
-
-    if shape_feature is None:
-        return False

     left = node.inputs[1]
     right = node.inputs[2]
     cond_var = node.inputs[0]
     cond = extract_constant(cond_var, only_process_constants=True)
+    out_bcast = node.outputs[0].type.broadcastable

     if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
         cond, (np.number, np.bool_)

@@ -1059,14 +994,8 @@ def local_useless_switch(fgraph, node):
         else:
             out = correct_out

-        input_shapes = [
-            tuple(shape_feature.get_shape(inp, i) for i in range(inp.type.ndim))
-            for inp in node.inputs
-        ]
-
-        out_shape = broadcast_shape(*input_shapes, arrays_are_shapes=True)
-
-        out = alloc(out, *out_shape)
+        if out.type.broadcastable != out_bcast:
+            out = broadcast_arrays(out, *node.inputs)[0]

         # Copy over stacktrace from selected output to new output
         copy_stack_trace(node.outputs + correct_out, out)

@@ -1076,10 +1005,10 @@ def local_useless_switch(fgraph, node):
     if left == right:
         # Note: No need to copy over stacktrace, because the input node
        # already has its own stacktrace
-        if cond.type.is_super(left.type):
+        if left.type.broadcastable == out_bcast:
             return [left]

-        ret = fill(cond, left)
+        ret = broadcast_arrays(left, cond)[0]

         # Copy over stacktrace from switch output and correct branch
         copy_stack_trace(node.outputs + left, ret)
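
For illustration only (not part of the commit): a sketch of the constant-condition branch of `local_useless_switch`, including the new `broadcast_arrays`-based shape handling. Variable names are illustrative; the exact rewritten graph may vary by version.

import pytensor
import pytensor.tensor as pt

left = pt.scalar("left")
right = pt.vector("right")

# A non-zero constant condition selects `left`, but the switch output is a
# vector, so the selected branch still has to be broadcast against the other inputs.
out = pt.switch(pt.constant(1), left, right)

f = pytensor.function([left, right], out)
pytensor.dprint(f)
# expected: no switch remains; `left` is broadcast to `right`'s shape via
# broadcast_arrays rather than the old shape_feature/alloc path
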
