Skip to content

Commit 08fae17

Browse files
[OpenVINO backend] support take_along_axis
1 parent be9b002 commit 08fae17

File tree

2 files changed

+111
-4
lines changed

2 files changed

+111
-4
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ NumpyDtypeTest::test_std
5656
NumpyDtypeTest::test_subtract
5757
NumpyDtypeTest::test_sum
5858
NumpyDtypeTest::test_swapaxes
59-
NumpyDtypeTest::test_take_along_axis
6059
NumpyDtypeTest::test_tensordot_
6160
NumpyDtypeTest::test_tile
6261
NumpyDtypeTest::test_trace
@@ -145,7 +144,6 @@ NumpyTwoInputOpsCorrectnessTest::test_inner
145144
NumpyTwoInputOpsCorrectnessTest::test_linspace
146145
NumpyTwoInputOpsCorrectnessTest::test_logspace
147146
NumpyTwoInputOpsCorrectnessTest::test_quantile
148-
NumpyTwoInputOpsCorrectnessTest::test_take_along_axis
149147
NumpyTwoInputOpsCorrectnessTest::test_tensordot
150148
NumpyTwoInputOpsCorrectnessTest::test_vdot
151149
NumpyOneInputOpsDynamicShapeTest::test_angle

keras/src/backend/openvino/numpy.py

Lines changed: 111 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,10 +1442,119 @@ def take(x, indices, axis=None):
14421442

14431443

14441444
def take_along_axis(x, indices, axis=None):
1445-
raise NotImplementedError(
1446-
"`take_along_axis` is not supported with openvino backend"
1445+
x = get_ov_output(x)
1446+
indices = get_ov_output(indices)
1447+
1448+
if axis is None:
1449+
target_shape = ov_opset.constant([-1], dtype=Type.i32).output(0)
1450+
x_flat = ov_opset.reshape(x, target_shape, False).output(0)
1451+
indices_flat = ov_opset.reshape(indices, target_shape, False).output(0)
1452+
result = ov_opset.gather_elements(x_flat, indices_flat, 0).output(0)
1453+
return OpenVINOKerasTensor(result)
1454+
1455+
x_rank = len(x.get_partial_shape())
1456+
if axis < 0:
1457+
axis += x_rank
1458+
1459+
x_shape = ov_opset.shape_of(x, Type.i32).output(0)
1460+
indices_shape = ov_opset.shape_of(indices, Type.i32).output(0)
1461+
1462+
# fix negative indices by adding dimension size
1463+
axis_index = ov_opset.constant([axis], dtype=Type.i32).output(0)
1464+
zero_const = ov_opset.constant(0, dtype=Type.i32).output(0)
1465+
dim_size = ov_opset.gather(x_shape, axis_index, zero_const).output(0)
1466+
dim_size = ov_opset.squeeze(dim_size, zero_const).output(0)
1467+
1468+
zero_scalar = ov_opset.constant(0, indices.get_element_type()).output(0)
1469+
is_neg = ov_opset.less(indices, zero_scalar).output(0)
1470+
dim_size_cast = ov_opset.convert(
1471+
dim_size, indices.get_element_type()
1472+
).output(0)
1473+
adjusted_indices = ov_opset.add(indices, dim_size_cast).output(0)
1474+
indices = ov_opset.select(is_neg, adjusted_indices, indices).output(0)
1475+
1476+
indices = ov_opset.convert(indices, Type.i32).output(0)
1477+
1478+
one_const = ov_opset.constant(1, dtype=Type.i32).output(0)
1479+
1480+
# Create modified shapes with axis dimension set to 1
1481+
x_shape_modified = []
1482+
indices_shape_modified = []
1483+
1484+
for i in range(x_rank):
1485+
dim_index = ov_opset.constant([i], dtype=Type.i32).output(0)
1486+
if i == axis:
1487+
x_shape_modified.append(
1488+
ov_opset.unsqueeze(one_const, zero_const).output(0)
1489+
)
1490+
indices_shape_modified.append(
1491+
ov_opset.unsqueeze(one_const, zero_const).output(0)
1492+
)
1493+
else:
1494+
x_dim = ov_opset.gather(x_shape, dim_index, zero_const).output(0)
1495+
indices_dim = ov_opset.gather(
1496+
indices_shape, dim_index, zero_const
1497+
).output(0)
1498+
x_shape_modified.append(x_dim)
1499+
indices_shape_modified.append(indices_dim)
1500+
1501+
x_shape_mod = ov_opset.concat(x_shape_modified, axis=0).output(0)
1502+
indices_shape_mod = ov_opset.concat(indices_shape_modified, axis=0).output(
1503+
0
14471504
)
14481505

1506+
# Compute broadcast shape (maximum of each dimension)
1507+
broadcast_shape_parts = []
1508+
for i in range(x_rank):
1509+
dim_index = ov_opset.constant([i], dtype=Type.i32).output(0)
1510+
x_dim = ov_opset.gather(x_shape_mod, dim_index, zero_const).output(0)
1511+
indices_dim = ov_opset.gather(
1512+
indices_shape_mod, dim_index, zero_const
1513+
).output(0)
1514+
max_dim = ov_opset.maximum(x_dim, indices_dim).output(0)
1515+
broadcast_shape_parts.append(max_dim)
1516+
1517+
broadcast_shape = ov_opset.concat(broadcast_shape_parts, axis=0).output(0)
1518+
1519+
# Create target shapes: broadcast shape but with original axis dimensions
1520+
x_target_shape_parts = []
1521+
indices_target_shape_parts = []
1522+
1523+
for i in range(x_rank):
1524+
dim_index = ov_opset.constant([i], dtype=Type.i32).output(0)
1525+
if i == axis:
1526+
x_orig_dim = ov_opset.gather(x_shape, dim_index, zero_const).output(
1527+
0
1528+
)
1529+
indices_orig_dim = ov_opset.gather(
1530+
indices_shape, dim_index, zero_const
1531+
).output(0)
1532+
x_target_shape_parts.append(x_orig_dim)
1533+
indices_target_shape_parts.append(indices_orig_dim)
1534+
else:
1535+
broadcast_dim = ov_opset.gather(
1536+
broadcast_shape, dim_index, zero_const
1537+
).output(0)
1538+
x_target_shape_parts.append(broadcast_dim)
1539+
indices_target_shape_parts.append(broadcast_dim)
1540+
1541+
x_target_shape = ov_opset.concat(x_target_shape_parts, axis=0).output(0)
1542+
indices_target_shape = ov_opset.concat(
1543+
indices_target_shape_parts, axis=0
1544+
).output(0)
1545+
1546+
# Broadcast to target shapes
1547+
x_broadcasted = ov_opset.broadcast(x, x_target_shape).output(0)
1548+
indices_broadcasted = ov_opset.broadcast(
1549+
indices, indices_target_shape
1550+
).output(0)
1551+
1552+
# Use gather_elements for element-wise selection
1553+
result = ov_opset.gather_elements(
1554+
x_broadcasted, indices_broadcasted, axis
1555+
).output(0)
1556+
return OpenVINOKerasTensor(result)
1557+
14491558

14501559
def tan(x):
14511560
x = get_ov_output(x)

0 commit comments

Comments
 (0)