Skip to content

Commit 9136947

Browse files
committed
Lift subtensor through squeeze
1 parent 2427c3d commit 9136947

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,41 @@ def local_subtensor_of_expand_dims(fgraph, node):
461461
return [out]
462462

463463

464+
@register_canonicalize
465+
@register_specialize
466+
@node_rewriter([Subtensor])
467+
def local_subtensor_of_squeeze(fgraph, node):
468+
"""Lift subtensor through a squeeze operation"""
469+
x, *idxs_vars = node.inputs
470+
if not (
471+
x.owner is not None
472+
and isinstance(x.owner.op, DimShuffle)
473+
and x.owner.op.is_squeeze
474+
):
475+
return None
476+
477+
[x_before_squeeze] = x.owner.inputs
478+
idxs = indices_from_subtensor(idxs_vars, node.op.idx_list)
479+
dropped_dims = x.owner.op.drop
480+
481+
# Apply indices directly on x
482+
# Add empty slices on the axis that squeeze would have removed
483+
new_idxs = np.insert(np.array(idxs, dtype=object), dropped_dims, slice(None))
484+
x_indexed = x_before_squeeze[tuple(new_idxs)]
485+
486+
# Reapply squeeze
487+
# Indexing may have squeezed some dimensions, so we need to recalculate dropped_dims
488+
new_dropped_dims = np.array(dropped_dims)
489+
for i, new_idx in reversed(tuple(enumerate(new_idxs))):
490+
if not isinstance(new_idx, slice):
491+
# If it's not a slice, it's an integer which drops the dimension
492+
new_dropped_dims[new_dropped_dims > i] -= 1
493+
new_x = x_indexed.squeeze(tuple(new_dropped_dims))
494+
495+
copy_stack_trace(x, new_x)
496+
return [new_x]
497+
498+
464499
@register_canonicalize
465500
@register_specialize
466501
@node_rewriter([Subtensor])

0 commit comments

Comments
 (0)