Skip to content

Commit d8d14fd

Browse files
authored
[OpenVINO Backend] Support numpy exp and expand_dims (#21006)
Signed-off-by: Kazantsev, Roman <[email protected]>
1 parent 2ebe3d6 commit d8d14fd

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ NumpyDtypeTest::test_dot
2626
NumpyDtypeTest::test_einsum
2727
NumpyDtypeTest::test_empty
2828
NumpyDtypeTest::test_exp2
29-
NumpyDtypeTest::test_exp
3029
NumpyDtypeTest::test_expm1
3130
NumpyDtypeTest::test_eye
3231
NumpyDtypeTest::test_flip
@@ -102,8 +101,8 @@ NumpyOneInputOpsCorrectnessTest::test_diag
102101
NumpyOneInputOpsCorrectnessTest::test_diagonal
103102
NumpyOneInputOpsCorrectnessTest::test_diff
104103
NumpyOneInputOpsCorrectnessTest::test_dot
105-
NumpyOneInputOpsCorrectnessTest::test_exp
106-
NumpyOneInputOpsCorrectnessTest::test_expand_dims
104+
NumpyOneInputOpsCorrectnessTest::test_exp2
105+
NumpyOneInputOpsCorrectnessTest::test_expm1
107106
NumpyOneInputOpsCorrectnessTest::test_flip
108107
NumpyOneInputOpsCorrectnessTest::test_floor_divide
109108
NumpyOneInputOpsCorrectnessTest::test_hstack

keras/src/backend/openvino/numpy.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -533,14 +533,17 @@ def equal(x1, x2):
533533

534534
def exp(x):
535535
x = get_ov_output(x)
536+
x_type = x.get_element_type()
537+
if x_type.is_integral():
538+
ov_type = OPENVINO_DTYPES[config.floatx()]
539+
x = ov_opset.convert(x, ov_type)
536540
return OpenVINOKerasTensor(ov_opset.exp(x).output(0))
537541

538542

539543
def expand_dims(x, axis):
540-
if isinstance(x, OpenVINOKerasTensor):
541-
x = x.output
542-
else:
543-
assert False
544+
x = get_ov_output(x)
545+
if isinstance(axis, tuple):
546+
axis = list(axis)
544547
axis = ov_opset.constant(axis, Type.i32).output(0)
545548
return OpenVINOKerasTensor(ov_opset.unsqueeze(x, axis).output(0))
546549

0 commit comments

Comments
 (0)