8
8
9
9
from backend_test_base import Tf2OnnxBackendTestBase
10
10
from common import unittest_main , check_tf_min_version , check_tf_max_version , \
11
- check_onnxruntime_min_version , check_tfjs_max_version
11
+ check_onnxruntime_min_version , check_tfjs_max_version , skip_tflite
12
12
from tf2onnx .tf_loader import is_tf2
13
13
14
14
@@ -302,6 +302,21 @@ def func(i):
302
302
output_names_with_port = ["output:0" ]
303
303
self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
304
304
305
+ @skip_tflite ("shape inference fails with tflite" )
306
+ def test_while_loop_cond_subgraphs (self ):
307
+ # test for while_loop with subgraphs in cond
308
+ def func (x ):
309
+ x_dim = tf .shape (x )[0 ]
310
+ r = tf .cast (tf .zeros (1 ), x .dtype )
311
+ for i in range (tf .constant (10 )):
312
+ if i == x_dim :
313
+ break
314
+ r += x [i ]
315
+ return tf .identity (r , name = "output" )
316
+ input_names_with_port = ["input_1:0" ]
317
+ feed_dict = {"input_1:0" : np .arange (0 , 15 , dtype = np .int32 )}
318
+ output_names_with_port = ["output:0" ]
319
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port )
305
320
306
321
if __name__ == '__main__' :
307
- unittest_main ()
322
+ unittest_main ()
0 commit comments