 from pytensor import Variable
 from pytensor.compile import optdb
-from pytensor.graph import Constant, FunctionGraph, node_rewriter
+from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph
 from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
 from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
 from pytensor.scalar import basic as ps
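
For context, `local_subtensor_of_dot` pushes a Subtensor applied to the output of a dot down onto the operands, so only the requested slice of the product is ever computed. A minimal NumPy check of the identity the rewrite relies on (shapes here are hypothetical):

    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.normal(size=(4, 3))
    B = rng.normal(size=(3, 5))

    # Indexing the product is equivalent to indexing the operands first:
    # rows of A select rows of the result, columns of B select columns.
    assert np.allclose((A @ B)[1:3, 2], A[1:3] @ B[:, 2])
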
@@ -119,21 +119,43 @@ def local_subtensor_of_dot(fgraph, node):
     the remaining entries of ``idxs`` (if any), modified to skip the
     second-to-last dimension of ``B`` (because dot sums over this dimension).
     """
-    if not isinstance(node.op, Subtensor):
-        return
-    if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)):
+    x, *idx_vars = node.inputs
+    if not (
+        x.owner is not None
+        and (
+            isinstance(x.owner.op, Dot)
+            or (
+                isinstance(x.owner.op, Blockwise)
+                and isinstance(x.owner.op.core_op, Dot)
+            )
+        )
+    ):
         return
     # If other nodes use the output of the dot,
     # we don't want to compute the sub-part twice.
-    if len(fgraph.clients[node.inputs[0]]) > 1:
+    if len(fgraph.clients[x]) > 1:
         return
 
-    a = node.inputs[0].owner.inputs[0]
-    b = node.inputs[0].owner.inputs[1]
+    a = x.owner.inputs[0]
+    b = x.owner.inputs[1]
+    idx_list = indices_from_subtensor(idx_vars, node.op.idx_list)
 
-    idx_list = get_idx_list(node.inputs, node.op.idx_list)
+    batch_ndim = (
+        x.owner.op.batch_ndim(x.owner) if isinstance(x.owner.op, Blockwise) else 0
+    )
+
+    if batch_ndim:
+        batch_idx_list, idx_list = idx_list[:batch_ndim], idx_list[batch_ndim:]
+        if not idx_list:
+            # Indexing only over batch dimensions of the Blockwise; that can be handled by another rewrite
+            return None
+        # We perform the rest of the rewrite on dummy a, b that correspond to the core case
+        a = a.type.clone(shape=a.type.shape[batch_ndim:])()
+        b = b.type.clone(shape=b.type.shape[batch_ndim:])()
 
-    num_a_indices = min(a.ndim - 1, len(idx_list))
+    a_ndim = a.ndim
+    b_ndim = b.ndim
+    num_a_indices = min(a_ndim - 1, len(idx_list))
     a_indices = idx_list[:num_a_indices]
     b_indices = idx_list[num_a_indices:]
 
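
The split above hands the first min(a_ndim - 1, len(idx_list)) indices to a and the rest to b; the next hunk then inserts a full slice where b's second-to-last dimension sits, since dot sums over that axis. A hedged NumPy sketch of this bookkeeping (shapes are illustrative):

    import numpy as np

    rng = np.random.default_rng(1)
    A = rng.normal(size=(4, 3))
    B = rng.normal(size=(5, 3, 2))  # np.dot sums over B's second-to-last axis

    # With idx_list = (1, 2, 0): the first a.ndim - 1 = 1 index applies to A,
    # the remaining two apply to B, skipping B's summed dimension via a full slice.
    assert np.allclose(np.dot(A, B)[1, 2, 0], np.dot(A[1], B[2, :, 0]))
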
@@ -142,26 +164,22 @@ def local_subtensor_of_dot(fgraph, node):
     # This wasn't necessary for a, because we just omitted the last index.
     # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
     # (dot also handles b.ndim < 2 as a special case)
-    if b.ndim > 1 and len(b_indices) >= b.ndim - 1:
+    if b_ndim > 1 and len(b_indices) >= b_ndim - 1:
         b_indices = (
-            b_indices[: b.ndim - 2]
+            b_indices[: b_ndim - 2]
             + (slice(None, None, None),)
-            + b_indices[b.ndim - 2 :]
+            + b_indices[b_ndim - 2 :]
         )
 
-    a_sub = a.__getitem__(tuple(a_indices))
-    b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b
+    a_sub = a[tuple(a_indices)]
+    b_sub = b[tuple(b_indices)] if b_indices else b
+    r = dot(a_sub, b_sub)
 
-    # Copy over previous output stacktrace to a_sub and b_sub,
-    # because an error in the subtensor operation (e.g. an index error)
-    # on either a or b must correspond to an error in the
-    # subtensor operation on their dot product.
-    copy_stack_trace(node.outputs[0], [a_sub, b_sub])
+    if batch_ndim:
+        # Replace dummy inputs by the original batch ones
+        r = vectorize_graph(r, replace={a: x.owner.inputs[0], b: x.owner.inputs[1]})
+        r = r[tuple(batch_idx_list)]
 
-    # Copy over previous output stacktrace and previous dot product stacktrace,
-    # because an error here may correspond to an error in either the original
-    # dot product, or in the dot product after the subtensor operation.
-    r = dot(a_sub, b_sub)
     copy_stack_trace([node.outputs[0], node.inputs[0]], r)
 
     return [r]
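
The batched branch above builds the rewritten graph on dummy core inputs, then maps it back over the original batched operands with vectorize_graph. A minimal sketch of that pattern (variable names here are hypothetical):

    import pytensor.tensor as pt
    from pytensor.graph import vectorize_graph

    # Core (unbatched) graph built on dummy stand-ins for one batch entry
    a_core = pt.matrix("a_core")
    b_core = pt.matrix("b_core")
    r_core = pt.dot(a_core[1:], b_core)  # the rewritten core computation

    # Swapping the dummies for batched inputs re-introduces the leading
    # batch dimension wherever it is needed
    a_batched = pt.tensor3("a_batched")
    b_batched = pt.tensor3("b_batched")
    r_batched = vectorize_graph(r_core, replace={a_core: a_batched, b_core: b_batched})
    assert r_batched.ndim == r_core.ndim + 1
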