Skip to content

Commit 1caf977

Browse files
committed
Moving shape check to make_node
1 parent 9f33117 commit 1caf977

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

pytensor/xtensor/math.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,20 @@ def make_node(self, x, y):
164164

165165
x_shape_dict = dict(zip(x.type.dims, x.type.shape))
166166
y_shape_dict = dict(zip(y.type.dims, y.type.shape))
167-
shape_dict = {**x_shape_dict, **y_shape_dict}
167+
168+
# Check for dimension size mismatches (concrete only)
169+
for dim in self.dims:
170+
x_shape = x_shape_dict.get(dim, None)
171+
y_shape = y_shape_dict.get(dim, None)
172+
if (
173+
isinstance(x_shape, int)
174+
and isinstance(y_shape, int)
175+
and x_shape != y_shape
176+
):
177+
raise ValueError(f"Size of dim '{dim}' does not match")
168178

169179
# Determine output dimensions
180+
shape_dict = {**x_shape_dict, **y_shape_dict}
170181
out_dims = tuple(d for d in shape_dict if d not in self.dims)
171182

172183
# Determine output shape
@@ -231,17 +242,6 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
231242
if d not in union:
232243
raise ValueError(f"Dimension {d} not found in either input")
233244

234-
# Check for dimension size mismatches (concrete only)
235-
for dim in intersection:
236-
x_idx = x.type.dims.index(dim)
237-
y_idx = y.type.dims.index(dim)
238-
if (
239-
isinstance(x.type.shape[x_idx], int)
240-
and isinstance(y.type.shape[y_idx], int)
241-
and x.type.shape[x_idx] != y.type.shape[y_idx]
242-
):
243-
raise ValueError(f"Size of dim '{dim}' does not match")
244-
245245
result = XDot(dims=tuple(dim_set))(x, y)
246246

247247
return result

0 commit comments

Comments
 (0)