Skip to content

Commit 34c8512

Browse files
authored
Update the way to get opset version for running tests. (#2033)
Signed-off-by: Jay Zhang <[email protected]>
1 parent 42d7222 commit 34c8512

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

tests/keras2onnx_unit_tests/test_layers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22

33
import pytest
44
import numpy as np
5-
from tf2onnx.keras2onnx_api import get_maximum_opset_supported
65
from mock_keras2onnx.proto.tfcompat import is_tf2, tensorflow as tf
76
from mock_keras2onnx.proto import (keras, is_tf_keras,
87
is_tensorflow_older_than, is_tensorflow_later_than,
98
is_keras_older_than, is_keras_later_than, python_keras_is_deprecated)
10-
from test_utils import no_loops_in_tf2, all_recurrents_should_bidirectional, convert_keras_for_test as convert_keras
9+
from test_utils import no_loops_in_tf2, all_recurrents_should_bidirectional, convert_keras_for_test as convert_keras, get_max_opset_supported_for_test as get_maximum_opset_supported
1110

1211
K = keras.backend
1312
Activation = keras.layers.Activation

tests/keras2onnx_unit_tests/test_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from mock_keras2onnx.proto import keras, is_keras_older_than
1010
from mock_keras2onnx.proto.tfcompat import is_tf2
1111
from packaging.version import Version
12-
from tf2onnx.keras2onnx_api import convert_keras
12+
from tf2onnx.keras2onnx_api import convert_keras, get_maximum_opset_supported
1313
import time
1414
import json
1515
import urllib
@@ -323,10 +323,13 @@ def get_max_opset_supported_by_ort():
323323
return None
324324

325325

326+
def get_max_opset_supported_for_test():
327+
return min(get_max_opset_supported_by_ort(), get_maximum_opset_supported())
328+
329+
326330
def convert_keras_for_test(model, name=None, target_opset=None, **kwargs):
327331
if target_opset is None:
328332
target_opset = get_max_opset_supported_by_ort()
329333

330334
print("Trying to run test with opset version: {}".format(target_opset))
331-
332-
return convert_keras(model=model, name=name, target_opset=target_opset, **kwargs)
335+
return convert_keras(model=model, name=name, target_opset=target_opset, **kwargs)

0 commit comments

Comments
 (0)