@@ -164,9 +164,20 @@ def make_node(self, x, y):
164
164
165
165
x_shape_dict = dict (zip (x .type .dims , x .type .shape ))
166
166
y_shape_dict = dict (zip (y .type .dims , y .type .shape ))
167
- shape_dict = {** x_shape_dict , ** y_shape_dict }
167
+
168
+ # Check for dimension size mismatches (concrete only)
169
+ for dim in self .dims :
170
+ x_shape = x_shape_dict .get (dim , None )
171
+ y_shape = y_shape_dict .get (dim , None )
172
+ if (
173
+ isinstance (x_shape , int )
174
+ and isinstance (y_shape , int )
175
+ and x_shape != y_shape
176
+ ):
177
+ raise ValueError (f"Size of dim '{ dim } ' does not match" )
168
178
169
179
# Determine output dimensions
180
+ shape_dict = {** x_shape_dict , ** y_shape_dict }
170
181
out_dims = tuple (d for d in shape_dict if d not in self .dims )
171
182
172
183
# Determine output shape
@@ -231,17 +242,6 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
231
242
if d not in union :
232
243
raise ValueError (f"Dimension { d } not found in either input" )
233
244
234
- # Check for dimension size mismatches (concrete only)
235
- for dim in intersection :
236
- x_idx = x .type .dims .index (dim )
237
- y_idx = y .type .dims .index (dim )
238
- if (
239
- isinstance (x .type .shape [x_idx ], int )
240
- and isinstance (y .type .shape [y_idx ], int )
241
- and x .type .shape [x_idx ] != y .type .shape [y_idx ]
242
- ):
243
- raise ValueError (f"Size of dim '{ dim } ' does not match" )
244
-
245
245
result = XDot (dims = tuple (dim_set ))(x , y )
246
246
247
247
return result
0 commit comments