Skip to content

Commit 69268d9

Browse files
committed
Fix missing rank scalar
1 parent 916faed commit 69268d9

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

keras/src/backend/openvino/numpy.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,10 +1088,13 @@ def median(x, axis=None, keepdims=False):
10881088
ov_opset.constant([0], Type.i32).output(0),
10891089
ov_opset.constant([0], Type.i32).output(0),
10901090
).output(0)
1091+
x_rank_scalar = ov_opset.squeeze(
1092+
x_rank, ov_opset.constant([0], Type.i32).output(0)
1093+
).output(0)
10911094
axis_as_range = ov_opset.range(
10921095
ov_opset.constant(0, Type.i32).output(0),
1093-
x_rank,
1094-
ov_opset.constant([1], Type.i32).output(0),
1096+
x_rank_scalar,
1097+
ov_opset.constant(1, Type.i32).output(0),
10951098
"i32",
10961099
).output(0)
10971100
axis_compare = ov_opset.not_equal(
@@ -1137,15 +1140,19 @@ def median(x, axis=None, keepdims=False):
11371140
# negative axis values are incompatible with ov_opset.gather axis arguement,
11381141
# convert the values
11391142
if axis < 0:
1143+
x_shape = ov_opset.shape_of(x, Type.i32).output(0)
11401144
x_rank = ov_opset.gather(
1141-
ov_opset.shape_of(x, Type.i32).output(0),
1145+
ov_opset.shape_of(x_shape, Type.i32).output(0),
11421146
ov_opset.constant([0], Type.i32).output(0),
11431147
ov_opset.constant([0], Type.i32).output(0),
11441148
).output(0)
1149+
x_rank_scalar = ov_opset.squeeze(
1150+
x_rank, ov_opset.constant([0], Type.i32).output(0)
1151+
).output(0)
11451152
axis_as_range = ov_opset.range(
11461153
ov_opset.constant(0, Type.i32).output(0),
1147-
x_rank,
1148-
ov_opset.constant([1], Type.i32).output(0),
1154+
x_rank_scalar,
1155+
ov_opset.constant(1, Type.i32).output(0),
11491156
"i32",
11501157
).output(0)
11511158
ov_axis_positive = ov_opset.gather(

0 commit comments

Comments
 (0)