@@ -1052,127 +1052,101 @@ def median(x, axis=None, keepdims=False):
1052
1052
1053
1053
x = get_ov_output (x )
1054
1054
x_type = x .get_element_type ()
1055
+ x_rank_org = x .get_partial_shape ().rank .get_length ()
1055
1056
if x_type == Type .boolean or x_type .is_integral ():
1056
- x = ov_opset .convert (x , Type .f32 ).output (0 )
1057
- x_type = x .get_element_type ()
1057
+ x_type = OPENVINO_DTYPES [config .floatx ()]
1058
+ x = ov_opset .convert (x , x_type ).output (0 )
1059
+
1058
1060
x_shape_original = ov_opset .shape_of (x , Type .i32 ).output (0 )
1059
1061
1060
1062
if axis is None :
1061
1063
flatten_shape = ov_opset .constant ([- 1 ], Type .i32 ).output (0 )
1062
1064
x = ov_opset .reshape (x , flatten_shape , False ).output (0 )
1063
1065
axis = 0
1064
- ov_axis = get_ov_output (axis )
1066
+ axis_norm = axis
1067
+ ov_axis_positive = get_ov_output (axis )
1065
1068
flattened = True
1066
- k_value = ov_opset .gather (
1067
- ov_opset .shape_of (x , Type .i32 ).output (0 ),
1068
- ov_opset .constant ([0 ], Type .i32 ).output (0 ),
1069
- ov_axis ,
1070
- ).output (0 )
1069
+ k_value = x .get_partial_shape ().get_dimension (index = 0 ).get_length ()
1071
1070
elif isinstance (axis , int ):
1072
1071
flattened = False
1073
- ov_axis = get_ov_output (axis )
1074
- x_shape = ov_opset .shape_of (x , Type .i32 ).output (0 )
1075
- k_value = ov_opset .gather (
1076
- x_shape , ov_axis , ov_opset .constant ([0 ], Type .i32 ).output (0 )
1077
- ).output (0 )
1072
+ x_rank = x .get_partial_shape ().rank .get_length ()
1073
+ if axis < 0 :
1074
+ axis_norm = x_rank + axis
1075
+ else :
1076
+ axis_norm = axis
1077
+ ov_axis_positive = ov_axis = get_ov_output (axis )
1078
+ k_value = (
1079
+ x .get_partial_shape ().get_dimension (index = axis_norm ).get_length ()
1080
+ )
1078
1081
else :
1079
1082
# where axis is tuple or list of integers, move 'axis' dims to the
1080
1083
# rightmost positions and flatten them
1081
1084
flattened = False
1082
1085
if isinstance (axis , (tuple , list )):
1083
- ov_axis = convert_to_tensor (axis )
1084
- else :
1085
- ov_axis = get_ov_output (axis )
1086
- x_rank = ov_opset .gather (
1087
- ov_opset .shape_of (x_shape_original , Type .i32 ).output (0 ),
1088
- ov_opset .constant ([0 ], Type .i32 ).output (0 ),
1089
- ov_opset .constant ([0 ], Type .i32 ).output (0 ),
1090
- ).output (0 )
1091
- x_rank_scalar = ov_opset .squeeze (
1092
- x_rank , ov_opset .constant ([0 ], Type .i32 ).output (0 )
1093
- ).output (0 )
1086
+ ov_axis = axis = list (axis )
1087
+ ov_axis = ov_opset .constant (axis , Type .i32 ).output (0 )
1088
+ x_rank = x .get_partial_shape ().rank .get_length ()
1094
1089
axis_as_range = ov_opset .range (
1095
1090
ov_opset .constant (0 , Type .i32 ).output (0 ),
1096
- x_rank_scalar ,
1091
+ x_rank ,
1097
1092
ov_opset .constant (1 , Type .i32 ).output (0 ),
1098
- "i32" ,
1099
- ).output (0 )
1100
- axis_compare = ov_opset .not_equal (
1101
- ov_opset .unsqueeze (axis_as_range , 1 ).output (0 ),
1102
- ov_opset .unsqueeze (ov_axis , 0 ).output (0 ),
1103
- "NUMPY" ,
1104
- ).output (0 )
1105
- keep_axes = ov_opset .reduce_logical_or (
1106
- axis_compare , ov_opset .constant ([1 ], Type .i32 ).output (0 )
1093
+ Type .i32 ,
1107
1094
).output (0 )
1108
- nz = ov_opset .non_zero (keep_axes , Type .i32 ).output (0 )
1109
- keep_axes = ov_opset .reduce_sum (
1110
- nz , ov_opset .constant ([1 ], Type .i32 ).output (0 )
1111
- ).output (0 )
1112
- reordered_axes = ov_opset .concat (
1113
- [keep_axes , ov_axis ], ov_opset .constant ([0 ], Type .i32 ).output (0 )
1095
+ # normalise any negative axes to their positive indices
1096
+ ov_axis_positive = ov_opset .gather (
1097
+ axis_as_range , ov_axis , ov_opset .constant ([0 ], Type .i32 )
1114
1098
).output (0 )
1115
- x = ov_opset .transpose (x , reordered_axes ).output (0 )
1099
+ # only move axis dims if tuple contains more than 1 axis
1100
+ if ov_axis_positive .get_partial_shape ().rank .get_length () > 1 :
1101
+ axis_compare = ov_opset .not_equal (
1102
+ ov_opset .unsqueeze (axis_as_range , 1 ).output (0 ),
1103
+ ov_opset .unsqueeze (ov_axis_positive , 0 ).output (0 ),
1104
+ ).output (0 )
1105
+ keep_axes = ov_opset .reduce_logical_or (
1106
+ axis_compare , ov_opset .constant ([1 ], Type .i32 ).output (0 )
1107
+ ).output (0 )
1108
+ nz = ov_opset .non_zero (keep_axes , Type .i32 ).output (0 )
1109
+ keep_axes = ov_opset .reduce_sum (
1110
+ nz , ov_opset .constant ([1 ], Type .i32 ).output (0 )
1111
+ ).output (0 )
1112
+ reordered_axes = ov_opset .concat (
1113
+ [keep_axes , ov_axis_positive ], 0
1114
+ ).output (0 )
1115
+ x = ov_opset .transpose (x , reordered_axes ).output (0 )
1116
1116
1117
- flat_rank = ov_opset .subtract (
1118
- x_rank , ov_opset .constant ([1 ], Type .i32 )
1119
- ).output (0 )
1120
- flatten_shape = ov_opset .broadcast (
1121
- ov_opset .constant ([0 ], Type .i32 ).output (0 ), flat_rank
1122
- ).output (0 )
1123
- flatten_shape = ov_opset .scatter_elements_update (
1124
- flatten_shape ,
1125
- ov_opset .constant ([- 1 ], Type .i32 ).output (0 ),
1126
- ov_opset .constant ([- 1 ], Type .i32 ).output (0 ),
1127
- 0 ,
1128
- "sum" ,
1129
- ).output (0 )
1117
+ flat_rank = ov_opset .subtract (
1118
+ x_rank , ov_opset .constant ([1 ], Type .i64 ). output ( 0 )
1119
+ ).output (0 )
1120
+ flatten_shape = ov_opset .broadcast (
1121
+ ov_opset .constant ([0 ], Type .i32 ).output (0 ), flat_rank
1122
+ ).output (0 )
1123
+ flatten_shape = ov_opset .scatter_elements_update (
1124
+ flatten_shape ,
1125
+ ov_opset .constant ([- 1 ], Type .i32 ).output (0 ),
1126
+ ov_opset .constant ([- 1 ], Type .i32 ).output (0 ),
1127
+ 0 ,
1128
+ "sum" ,
1129
+ ).output (0 )
1130
1130
1131
- x = ov_opset .reshape (x , flatten_shape , True ).output (0 )
1131
+ x = ov_opset .reshape (x , flatten_shape , True ).output (0 )
1132
1132
axis = - 1
1133
- x_shape = ov_opset .shape_of (x , Type .i32 ).output (0 )
1134
- k_value = ov_opset .gather (
1135
- x_shape ,
1136
- ov_opset .constant ([- 1 ], Type .i32 ).output (0 ),
1137
- ov_opset .constant ([0 ], Type .i32 ).output (0 ),
1138
- ).output (0 )
1139
-
1140
- # negative axis values are incompatible with ov_opset.gather axis arguement,
1141
- # convert the values
1142
- if axis < 0 :
1143
- x_shape = ov_opset .shape_of (x , Type .i32 ).output (0 )
1144
- x_rank = ov_opset .gather (
1145
- ov_opset .shape_of (x_shape , Type .i32 ).output (0 ),
1146
- ov_opset .constant ([0 ], Type .i32 ).output (0 ),
1147
- ov_opset .constant ([0 ], Type .i32 ).output (0 ),
1148
- ).output (0 )
1149
- x_rank_scalar = ov_opset .squeeze (
1150
- x_rank , ov_opset .constant ([0 ], Type .i32 ).output (0 )
1151
- ).output (0 )
1152
- axis_as_range = ov_opset .range (
1153
- ov_opset .constant (0 , Type .i32 ).output (0 ),
1154
- x_rank_scalar ,
1155
- ov_opset .constant (1 , Type .i32 ).output (0 ),
1156
- "i32" ,
1157
- ).output (0 )
1158
- ov_axis_positive = ov_opset .gather (
1159
- axis_as_range , ov_axis , ov_opset .constant ([0 ], Type .i32 )
1160
- ).output (0 )
1161
- else :
1162
- ov_axis_positive = ov_axis
1133
+ x_rank = x .get_partial_shape ().rank .get_length ()
1134
+ axis_norm = x_rank + axis
1135
+ ov_axis_positive = get_ov_output (axis_norm )
1136
+ k_value = (
1137
+ x .get_partial_shape ().get_dimension (index = axis_norm ).get_length ()
1138
+ )
1163
1139
1164
- k_scalar = ov_opset .squeeze (
1165
- k_value , ov_opset .constant ([0 ], Type .i32 ).output (0 )
1166
- ).output (0 )
1167
1140
x_sorted = ov_opset .topk (
1168
- x , k_scalar , axis , "min" , "value" , stable = True
1141
+ x , k_value , axis_norm , "min" , "value" , stable = True
1169
1142
).output (0 )
1143
+ k_value = ov_opset .convert (k_value , x_type ).output (0 )
1170
1144
half_index = ov_opset .floor (
1171
- ov_opset .divide (k_value , ov_opset .constant ([2 ], Type . i32 )).output (0 )
1145
+ ov_opset .divide (k_value , ov_opset .constant ([2 ], x_type )).output (0 )
1172
1146
).output (0 )
1173
1147
half_index = ov_opset .convert (half_index , Type .i32 ).output (0 )
1174
- x_mod = ov_opset .mod (k_value , ov_opset .constant ([2 ], Type . i32 )).output (0 )
1175
- is_even = ov_opset .equal (x_mod , ov_opset .constant ([0 ], Type . i32 )).output (0 )
1148
+ x_mod = ov_opset .mod (k_value , ov_opset .constant ([2 ], x_type )).output (0 )
1149
+ is_even = ov_opset .equal (x_mod , ov_opset .constant ([0 ], x_type )).output (0 )
1176
1150
1177
1151
med_0 = ov_opset .gather (x_sorted , half_index , ov_axis_positive ).output (0 )
1178
1152
med_1 = ov_opset .select (
@@ -1188,10 +1162,9 @@ def median(x, axis=None, keepdims=False):
1188
1162
).output (0 )
1189
1163
1190
1164
median_odd = med_0
1191
- median_type = med_0 .get_element_type ()
1192
1165
median_even = ov_opset .divide (
1193
1166
ov_opset .add (med_1 , med_0 ).output (0 ),
1194
- ov_opset .constant ([2 ], median_type ),
1167
+ ov_opset .constant ([2 ], x_type ),
1195
1168
).output (0 )
1196
1169
1197
1170
median_eval = ov_opset .select (is_even , median_even , median_odd ).output (0 )
@@ -1205,7 +1178,9 @@ def median(x, axis=None, keepdims=False):
1205
1178
median_eval , median_shape , False
1206
1179
).output (0 )
1207
1180
else :
1208
- median_eval = ov_opset .unsqueeze (median_eval , ov_axis ).output (0 )
1181
+ if median_eval .get_partial_shape ().rank .get_length () != x_rank_org :
1182
+ median_eval = ov_opset .unsqueeze (median_eval , ov_axis ).output (0 )
1183
+
1209
1184
else :
1210
1185
median_eval = ov_opset .squeeze (median_eval , ov_axis_positive ).output (0 )
1211
1186
0 commit comments