@@ -1088,10 +1088,13 @@ def median(x, axis=None, keepdims=False):
1088
1088
ov_opset .constant ([0 ], Type .i32 ).output (0 ),
1089
1089
ov_opset .constant ([0 ], Type .i32 ).output (0 ),
1090
1090
).output (0 )
1091
+ x_rank_scalar = ov_opset .squeeze (
1092
+ x_rank , ov_opset .constant ([0 ], Type .i32 ).output (0 )
1093
+ ).output (0 )
1091
1094
axis_as_range = ov_opset .range (
1092
1095
ov_opset .constant (0 , Type .i32 ).output (0 ),
1093
- x_rank ,
1094
- ov_opset .constant ([ 1 ] , Type .i32 ).output (0 ),
1096
+ x_rank_scalar ,
1097
+ ov_opset .constant (1 , Type .i32 ).output (0 ),
1095
1098
"i32" ,
1096
1099
).output (0 )
1097
1100
axis_compare = ov_opset .not_equal (
@@ -1137,15 +1140,19 @@ def median(x, axis=None, keepdims=False):
1137
1140
# negative axis values are incompatible with ov_opset.gather axis arguement,
1138
1141
# convert the values
1139
1142
if axis < 0 :
1143
+ x_shape = ov_opset .shape_of (x , Type .i32 ).output (0 )
1140
1144
x_rank = ov_opset .gather (
1141
- ov_opset .shape_of (x , Type .i32 ).output (0 ),
1145
+ ov_opset .shape_of (x_shape , Type .i32 ).output (0 ),
1142
1146
ov_opset .constant ([0 ], Type .i32 ).output (0 ),
1143
1147
ov_opset .constant ([0 ], Type .i32 ).output (0 ),
1144
1148
).output (0 )
1149
+ x_rank_scalar = ov_opset .squeeze (
1150
+ x_rank , ov_opset .constant ([0 ], Type .i32 ).output (0 )
1151
+ ).output (0 )
1145
1152
axis_as_range = ov_opset .range (
1146
1153
ov_opset .constant (0 , Type .i32 ).output (0 ),
1147
- x_rank ,
1148
- ov_opset .constant ([ 1 ] , Type .i32 ).output (0 ),
1154
+ x_rank_scalar ,
1155
+ ov_opset .constant (1 , Type .i32 ).output (0 ),
1149
1156
"i32" ,
1150
1157
).output (0 )
1151
1158
ov_axis_positive = ov_opset .gather (
0 commit comments