@@ -461,6 +461,41 @@ def local_subtensor_of_expand_dims(fgraph, node):
461
461
return [out ]
462
462
463
463
464
+ @register_canonicalize
465
+ @register_specialize
466
+ @node_rewriter ([Subtensor ])
467
+ def local_subtensor_of_squeeze (fgraph , node ):
468
+ """Lift subtensor through a squeeze operation"""
469
+ x , * idxs_vars = node .inputs
470
+ if not (
471
+ x .owner is not None
472
+ and isinstance (x .owner .op , DimShuffle )
473
+ and x .owner .op .is_squeeze
474
+ ):
475
+ return None
476
+
477
+ [x_before_squeeze ] = x .owner .inputs
478
+ idxs = indices_from_subtensor (idxs_vars , node .op .idx_list )
479
+ dropped_dims = x .owner .op .drop
480
+
481
+ # Apply indices directly on x
482
+ # Add empty slices on the axis that squeeze would have removed
483
+ new_idxs = np .insert (np .array (idxs , dtype = object ), dropped_dims , slice (None ))
484
+ x_indexed = x_before_squeeze [tuple (new_idxs )]
485
+
486
+ # Reapply squeeze
487
+ # Indexing may have squeezed some dimensions, so we need to recalculate dropped_dims
488
+ new_dropped_dims = np .array (dropped_dims )
489
+ for i , new_idx in reversed (tuple (enumerate (new_idxs ))):
490
+ if not isinstance (new_idx , slice ):
491
+ # If it's not a slice, it's an integer which drops the dimension
492
+ new_dropped_dims [new_dropped_dims > i ] -= 1
493
+ new_x = x_indexed .squeeze (tuple (new_dropped_dims ))
494
+
495
+ copy_stack_trace (x , new_x )
496
+ return [new_x ]
497
+
498
+
464
499
@register_canonicalize
465
500
@register_specialize
466
501
@node_rewriter ([Subtensor ])
0 commit comments