
Commit c8eed58

Generalize dot rewrites to work with Blockwise
Parent: 9136947
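
For readers unfamiliar with the term: Blockwise is PyTensor's wrapper that applies a core Op (here, Dot) over leading batch dimensions, so batched matmul graphs are built from Blockwise(Dot) nodes rather than plain Dot nodes. Below is a minimal sketch (not part of the commit) of the kind of node the generalized rewrites now have to recognize; it only assumes the public pytensor.tensor API and the pytensor.tensor.blockwise.Blockwise class.

import pytensor.tensor as pt
from pytensor.tensor.blockwise import Blockwise

x = pt.tensor3("x")  # stack of matrices, leading batch dimension
y = pt.tensor3("y")
out = x @ y  # batched matmul

# The batched product is represented by a Blockwise node wrapping the core dot Op.
assert isinstance(out.owner.op, Blockwise)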

2 files changed: +52, -33 lines

pytensor/tensor/rewriting/math.py

Lines changed: 11 additions & 10 deletions
@@ -44,6 +44,7 @@
     Prod,
     Sum,
     _conj,
+    _dot,
     _inner_prod,
     _matrix_matrix_matmul,
     _matrix_vec_prod,
@@ -97,6 +98,7 @@
     register_useless,
 )
 from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
+from pytensor.tensor.rewriting.linalg import is_matrix_transpose
 from pytensor.tensor.shape import Shape, Shape_i
 from pytensor.tensor.subtensor import Subtensor
 from pytensor.tensor.type import (
@@ -174,21 +176,20 @@ def local_lift_transpose_through_dot(fgraph, node):
     These rewrites "lift" (propagate towards the inputs) `DimShuffle`
     through dot product. It allows to put the graph in a more standard shape,
     and to later merge consecutive `DimShuffle`\s.
-
-    The transformation should be apply whether or not the transpose is
-    inplace. The newly-introduced transpositions are not inplace, this will
-    be taken care of in a later rewrite phase.
-
     """
-    if not (isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0)):
-        return False
-    if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)):
+
+    if not (
+        is_matrix_transpose(node.out)
+        and node.inputs[0].owner
+        and ((dot_op := node.inputs[0].owner.op) in (_dot, _matmul))
+    ):
         return False
+
     x, y = node.inputs[0].owner.inputs
 
-    if x.ndim == y.ndim == 2:
+    if x.ndim >= y.ndim >= 2:
         # Output is dot product of transposed inputs in reverse order
-        ret = [dot(y.T, x.T)]
+        ret = [dot_op(y.mT, x.mT)]
 
         # Copy over stack trace to output from result of dot-product
         copy_stack_trace(node.inputs[0], ret)
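
A minimal sketch (not part of the commit) of the identity the generalized rewrite exploits: the matrix transpose of a (possibly batched) product equals the product of the matrix-transposed operands in reverse order, which is why graphs like (x @ y).mT can be rewritten to y.mT @ x.mT. The snippet only uses the public pytensor.tensor API and checks the identity numerically with NumPy; it does not inspect the rewrite machinery itself.

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x")  # batched stacks of matrices
y = pt.tensor3("y")

out = (x @ y).mT  # the pattern the lift rewrite targets
f = pytensor.function([x, y], out)

rng = np.random.default_rng(0)
xv = rng.normal(size=(5, 3, 4)).astype(x.dtype)
yv = rng.normal(size=(5, 4, 2)).astype(y.dtype)

# Same values as the matmul of matrix-transposed operands in reverse order.
expected = np.swapaxes(yv, -1, -2) @ np.swapaxes(xv, -1, -2)
np.testing.assert_allclose(f(xv, yv), expected, rtol=1e-6)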

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 41 additions & 23 deletions
@@ -5,7 +5,7 @@
 
 from pytensor import Variable
 from pytensor.compile import optdb
-from pytensor.graph import Constant, FunctionGraph, node_rewriter
+from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph
 from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
 from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
 from pytensor.scalar import basic as ps
@@ -119,21 +119,43 @@ def local_subtensor_of_dot(fgraph, node):
     the remaining entries of ``idxs`` (if any), modified to skip the
     second-to-last dimension of ``B`` (because dot sums over this dimension).
     """
-    if not isinstance(node.op, Subtensor):
-        return
-    if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)):
+    x, *idx_vars = node.inputs
+    if not (
+        x.owner is not None
+        and (
+            isinstance(x.owner.op, Dot)
+            or (
+                isinstance(x.owner.op, Blockwise)
+                and isinstance(x.owner.op.core_op, Dot)
+            )
+        )
+    ):
         return
     # If there is other node that use the outputs of the dot
     # We don't want to compute twice the sub part.
-    if len(fgraph.clients[node.inputs[0]]) > 1:
+    if len(fgraph.clients[x]) > 1:
         return
 
-    a = node.inputs[0].owner.inputs[0]
-    b = node.inputs[0].owner.inputs[1]
+    a = x.owner.inputs[0]
+    b = x.owner.inputs[1]
+    idx_list = indices_from_subtensor(idx_vars, node.op.idx_list)
 
-    idx_list = get_idx_list(node.inputs, node.op.idx_list)
+    batch_ndim = (
+        x.owner.op.batch_ndim(x.owner) if isinstance(x.owner.op, Blockwise) else 0
+    )
+
+    if batch_ndim:
+        batch_idx_list, idx_list = idx_list[:batch_ndim], idx_list[batch_ndim:]
+        if not idx_list:
+            # Indexing only over batch dimensions of Blockwise, that can be handled by another rewrite
+            return None
+        # We perform the rest of the rewrite on dummy a, b that correspond to the core case
+        a = a.type.clone(shape=a.type.shape[batch_ndim:])()
+        b = b.type.clone(shape=b.type.shape[batch_ndim:])()
 
-    num_a_indices = min(a.ndim - 1, len(idx_list))
+    a_ndim = a.ndim
+    b_ndim = b.ndim
+    num_a_indices = min(a_ndim - 1, len(idx_list))
     a_indices = idx_list[:num_a_indices]
     b_indices = idx_list[num_a_indices:]
 
@@ -142,26 +164,22 @@ def local_subtensor_of_dot(fgraph, node):
     # This wasn't necessary for a, because we just omitted the last index.
     # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
     # (dot also handles b.ndim < 2 as a special case)
-    if b.ndim > 1 and len(b_indices) >= b.ndim - 1:
+    if b_ndim > 1 and len(b_indices) >= b_ndim - 1:
         b_indices = (
-            b_indices[: b.ndim - 2]
+            b_indices[: b_ndim - 2]
             + (slice(None, None, None),)
-            + b_indices[b.ndim - 2 :]
+            + b_indices[b_ndim - 2 :]
         )
 
-    a_sub = a.__getitem__(tuple(a_indices))
-    b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b
+    a_sub = a[tuple(a_indices)]
+    b_sub = b[tuple(b_indices)] if b_indices else b
+    r = dot(a_sub, b_sub)
 
-    # Copy over previous output stacktrace to a_sub and b_sub,
-    # because an error in the subtensor operation (e.g. an index error)
-    # on either a or b must correspond to an error in the
-    # subtensor operation on their dot product.
-    copy_stack_trace(node.outputs[0], [a_sub, b_sub])
+    if batch_ndim:
+        # Replace dummy inputs by the original batch ones
+        r = vectorize_graph(r, replace={a: x.owner.inputs[0], b: x.owner.inputs[1]})
+        r = r[tuple(batch_idx_list)]
 
-    # Copy over previous output stacktrace and previous dot product stacktrace,
-    # because an error here may correspond to an either in either the original
-    # dot product, or in the dot product after the subtensor operation.
-    r = dot(a_sub, b_sub)
     copy_stack_trace([node.outputs[0], node.inputs[0]], r)
 
     return [r]
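
A minimal sketch (not part of the commit) of the vectorize_graph pattern the updated rewrite relies on: build the graph on dummy core (unbatched) inputs, then re-batch it by replacing the dummies with the original batched variables. Variable names such as a_core and a_batch are illustrative only; vectorize_graph and its replace argument are the public API used by the diff above.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph import vectorize_graph

# Core (unbatched) graph built on dummy matrix inputs.
a_core = pt.matrix("a_core")
b_core = pt.matrix("b_core")
core_out = pt.dot(a_core, b_core)[0]  # e.g. take the first row of the core product

# Re-batch the core graph by swapping in batched inputs.
a_batch = pt.tensor3("a_batch")
b_batch = pt.tensor3("b_batch")
batched_out = vectorize_graph(core_out, replace={a_core: a_batch, b_core: b_batch})

f = pytensor.function([a_batch, b_batch], batched_out)

rng = np.random.default_rng(0)
av = rng.normal(size=(5, 3, 4)).astype(a_batch.dtype)
bv = rng.normal(size=(5, 4, 2)).astype(b_batch.dtype)

# Equivalent to taking row 0 of each product in the batch.
np.testing.assert_allclose(f(av, bv), (av @ bv)[:, 0], rtol=1e-6)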
