Skip to content

Commit d92b60a

Browse files
committed
fixed and passed local testing. Submit median for PR
1 parent 69268d9 commit d92b60a

File tree

1 file changed

+70
-95
lines changed

1 file changed

+70
-95
lines changed

keras/src/backend/openvino/numpy.py

Lines changed: 70 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,127 +1052,101 @@ def median(x, axis=None, keepdims=False):
10521052

10531053
x = get_ov_output(x)
10541054
x_type = x.get_element_type()
1055+
x_rank_org = x.get_partial_shape().rank.get_length()
10551056
if x_type == Type.boolean or x_type.is_integral():
1056-
x = ov_opset.convert(x, Type.f32).output(0)
1057-
x_type = x.get_element_type()
1057+
x_type = OPENVINO_DTYPES[config.floatx()]
1058+
x = ov_opset.convert(x, x_type).output(0)
1059+
10581060
x_shape_original = ov_opset.shape_of(x, Type.i32).output(0)
10591061

10601062
if axis is None:
10611063
flatten_shape = ov_opset.constant([-1], Type.i32).output(0)
10621064
x = ov_opset.reshape(x, flatten_shape, False).output(0)
10631065
axis = 0
1064-
ov_axis = get_ov_output(axis)
1066+
axis_norm = axis
1067+
ov_axis_positive = get_ov_output(axis)
10651068
flattened = True
1066-
k_value = ov_opset.gather(
1067-
ov_opset.shape_of(x, Type.i32).output(0),
1068-
ov_opset.constant([0], Type.i32).output(0),
1069-
ov_axis,
1070-
).output(0)
1069+
k_value = x.get_partial_shape().get_dimension(index=0).get_length()
10711070
elif isinstance(axis, int):
10721071
flattened = False
1073-
ov_axis = get_ov_output(axis)
1074-
x_shape = ov_opset.shape_of(x, Type.i32).output(0)
1075-
k_value = ov_opset.gather(
1076-
x_shape, ov_axis, ov_opset.constant([0], Type.i32).output(0)
1077-
).output(0)
1072+
x_rank = x.get_partial_shape().rank.get_length()
1073+
if axis < 0:
1074+
axis_norm = x_rank + axis
1075+
else:
1076+
axis_norm = axis
1077+
ov_axis_positive = ov_axis = get_ov_output(axis)
1078+
k_value = (
1079+
x.get_partial_shape().get_dimension(index=axis_norm).get_length()
1080+
)
10781081
else:
10791082
# where axis is tuple or list of integers, move 'axis' dims to the
10801083
# rightmost positions and flatten them
10811084
flattened = False
10821085
if isinstance(axis, (tuple, list)):
1083-
ov_axis = convert_to_tensor(axis)
1084-
else:
1085-
ov_axis = get_ov_output(axis)
1086-
x_rank = ov_opset.gather(
1087-
ov_opset.shape_of(x_shape_original, Type.i32).output(0),
1088-
ov_opset.constant([0], Type.i32).output(0),
1089-
ov_opset.constant([0], Type.i32).output(0),
1090-
).output(0)
1091-
x_rank_scalar = ov_opset.squeeze(
1092-
x_rank, ov_opset.constant([0], Type.i32).output(0)
1093-
).output(0)
1086+
ov_axis = axis = list(axis)
1087+
ov_axis = ov_opset.constant(axis, Type.i32).output(0)
1088+
x_rank = x.get_partial_shape().rank.get_length()
10941089
axis_as_range = ov_opset.range(
10951090
ov_opset.constant(0, Type.i32).output(0),
1096-
x_rank_scalar,
1091+
x_rank,
10971092
ov_opset.constant(1, Type.i32).output(0),
1098-
"i32",
1099-
).output(0)
1100-
axis_compare = ov_opset.not_equal(
1101-
ov_opset.unsqueeze(axis_as_range, 1).output(0),
1102-
ov_opset.unsqueeze(ov_axis, 0).output(0),
1103-
"NUMPY",
1104-
).output(0)
1105-
keep_axes = ov_opset.reduce_logical_or(
1106-
axis_compare, ov_opset.constant([1], Type.i32).output(0)
1093+
Type.i32,
11071094
).output(0)
1108-
nz = ov_opset.non_zero(keep_axes, Type.i32).output(0)
1109-
keep_axes = ov_opset.reduce_sum(
1110-
nz, ov_opset.constant([1], Type.i32).output(0)
1111-
).output(0)
1112-
reordered_axes = ov_opset.concat(
1113-
[keep_axes, ov_axis], ov_opset.constant([0], Type.i32).output(0)
1095+
# normalise any negative axes to their positive indices
1096+
ov_axis_positive = ov_opset.gather(
1097+
axis_as_range, ov_axis, ov_opset.constant([0], Type.i32)
11141098
).output(0)
1115-
x = ov_opset.transpose(x, reordered_axes).output(0)
1099+
# only move axis dims if tuple contains more than 1 axis
1100+
if ov_axis_positive.get_partial_shape().rank.get_length() > 1:
1101+
axis_compare = ov_opset.not_equal(
1102+
ov_opset.unsqueeze(axis_as_range, 1).output(0),
1103+
ov_opset.unsqueeze(ov_axis_positive, 0).output(0),
1104+
).output(0)
1105+
keep_axes = ov_opset.reduce_logical_or(
1106+
axis_compare, ov_opset.constant([1], Type.i32).output(0)
1107+
).output(0)
1108+
nz = ov_opset.non_zero(keep_axes, Type.i32).output(0)
1109+
keep_axes = ov_opset.reduce_sum(
1110+
nz, ov_opset.constant([1], Type.i32).output(0)
1111+
).output(0)
1112+
reordered_axes = ov_opset.concat(
1113+
[keep_axes, ov_axis_positive], 0
1114+
).output(0)
1115+
x = ov_opset.transpose(x, reordered_axes).output(0)
11161116

1117-
flat_rank = ov_opset.subtract(
1118-
x_rank, ov_opset.constant([1], Type.i32)
1119-
).output(0)
1120-
flatten_shape = ov_opset.broadcast(
1121-
ov_opset.constant([0], Type.i32).output(0), flat_rank
1122-
).output(0)
1123-
flatten_shape = ov_opset.scatter_elements_update(
1124-
flatten_shape,
1125-
ov_opset.constant([-1], Type.i32).output(0),
1126-
ov_opset.constant([-1], Type.i32).output(0),
1127-
0,
1128-
"sum",
1129-
).output(0)
1117+
flat_rank = ov_opset.subtract(
1118+
x_rank, ov_opset.constant([1], Type.i64).output(0)
1119+
).output(0)
1120+
flatten_shape = ov_opset.broadcast(
1121+
ov_opset.constant([0], Type.i32).output(0), flat_rank
1122+
).output(0)
1123+
flatten_shape = ov_opset.scatter_elements_update(
1124+
flatten_shape,
1125+
ov_opset.constant([-1], Type.i32).output(0),
1126+
ov_opset.constant([-1], Type.i32).output(0),
1127+
0,
1128+
"sum",
1129+
).output(0)
11301130

1131-
x = ov_opset.reshape(x, flatten_shape, True).output(0)
1131+
x = ov_opset.reshape(x, flatten_shape, True).output(0)
11321132
axis = -1
1133-
x_shape = ov_opset.shape_of(x, Type.i32).output(0)
1134-
k_value = ov_opset.gather(
1135-
x_shape,
1136-
ov_opset.constant([-1], Type.i32).output(0),
1137-
ov_opset.constant([0], Type.i32).output(0),
1138-
).output(0)
1139-
1140-
# negative axis values are incompatible with ov_opset.gather axis arguement,
1141-
# convert the values
1142-
if axis < 0:
1143-
x_shape = ov_opset.shape_of(x, Type.i32).output(0)
1144-
x_rank = ov_opset.gather(
1145-
ov_opset.shape_of(x_shape, Type.i32).output(0),
1146-
ov_opset.constant([0], Type.i32).output(0),
1147-
ov_opset.constant([0], Type.i32).output(0),
1148-
).output(0)
1149-
x_rank_scalar = ov_opset.squeeze(
1150-
x_rank, ov_opset.constant([0], Type.i32).output(0)
1151-
).output(0)
1152-
axis_as_range = ov_opset.range(
1153-
ov_opset.constant(0, Type.i32).output(0),
1154-
x_rank_scalar,
1155-
ov_opset.constant(1, Type.i32).output(0),
1156-
"i32",
1157-
).output(0)
1158-
ov_axis_positive = ov_opset.gather(
1159-
axis_as_range, ov_axis, ov_opset.constant([0], Type.i32)
1160-
).output(0)
1161-
else:
1162-
ov_axis_positive = ov_axis
1133+
x_rank = x.get_partial_shape().rank.get_length()
1134+
axis_norm = x_rank + axis
1135+
ov_axis_positive = get_ov_output(axis_norm)
1136+
k_value = (
1137+
x.get_partial_shape().get_dimension(index=axis_norm).get_length()
1138+
)
11631139

1164-
k_scalar = ov_opset.squeeze(
1165-
k_value, ov_opset.constant([0], Type.i32).output(0)
1166-
).output(0)
11671140
x_sorted = ov_opset.topk(
1168-
x, k_scalar, axis, "min", "value", stable=True
1141+
x, k_value, axis_norm, "min", "value", stable=True
11691142
).output(0)
1143+
k_value = ov_opset.convert(k_value, x_type).output(0)
11701144
half_index = ov_opset.floor(
1171-
ov_opset.divide(k_value, ov_opset.constant([2], Type.i32)).output(0)
1145+
ov_opset.divide(k_value, ov_opset.constant([2], x_type)).output(0)
11721146
).output(0)
11731147
half_index = ov_opset.convert(half_index, Type.i32).output(0)
1174-
x_mod = ov_opset.mod(k_value, ov_opset.constant([2], Type.i32)).output(0)
1175-
is_even = ov_opset.equal(x_mod, ov_opset.constant([0], Type.i32)).output(0)
1148+
x_mod = ov_opset.mod(k_value, ov_opset.constant([2], x_type)).output(0)
1149+
is_even = ov_opset.equal(x_mod, ov_opset.constant([0], x_type)).output(0)
11761150

11771151
med_0 = ov_opset.gather(x_sorted, half_index, ov_axis_positive).output(0)
11781152
med_1 = ov_opset.select(
@@ -1188,10 +1162,9 @@ def median(x, axis=None, keepdims=False):
11881162
).output(0)
11891163

11901164
median_odd = med_0
1191-
median_type = med_0.get_element_type()
11921165
median_even = ov_opset.divide(
11931166
ov_opset.add(med_1, med_0).output(0),
1194-
ov_opset.constant([2], median_type),
1167+
ov_opset.constant([2], x_type),
11951168
).output(0)
11961169

11971170
median_eval = ov_opset.select(is_even, median_even, median_odd).output(0)
@@ -1205,7 +1178,9 @@ def median(x, axis=None, keepdims=False):
12051178
median_eval, median_shape, False
12061179
).output(0)
12071180
else:
1208-
median_eval = ov_opset.unsqueeze(median_eval, ov_axis).output(0)
1181+
if median_eval.get_partial_shape().rank.get_length() != x_rank_org:
1182+
median_eval = ov_opset.unsqueeze(median_eval, ov_axis).output(0)
1183+
12091184
else:
12101185
median_eval = ov_opset.squeeze(median_eval, ov_axis_positive).output(0)
12111186

0 commit comments

Comments
 (0)