11
11
from onnx import helper , onnx_pb , numpy_helper
12
12
from tensorflow .core .framework import types_pb2 , tensor_pb2
13
13
from tensorflow .python .framework import tensor_util
14
+ import tensorflow as tf
14
15
from tf2onnx .tflite .TensorType import TensorType as TFLiteTensorType
15
16
from tf2onnx .tflite .Model import Model
16
17
from tf2onnx .flexbuffers import read_flexbuffer
@@ -138,7 +139,19 @@ def read_tflite_model(tflite_path):
138
139
code = op_code .CustomCode ().decode ()
139
140
opcodes_map [i ] = code
140
141
tflite_graphs = [model .Subgraphs (i ) for i in range (model .SubgraphsLength ())]
141
- return tflite_graphs , opcodes_map , model
142
+ # Shapes stored in tflite models are not always reliable so we get them from the interpreter if possible.
143
+ interpreter = tf .lite .Interpreter (tflite_path )
144
+ interpreter .allocate_tensors ()
145
+ tensor_cnt = model .Subgraphs (0 ).TensorsLength ()
146
+ tensor_shapes = {}
147
+ for i in range (tensor_cnt ):
148
+ name = model .Subgraphs (0 ).Tensors (i ).Name ().decode ()
149
+ details = interpreter ._get_tensor_details (i ) # pylint: disable=protected-access
150
+ if "shape_signature" in details :
151
+ tensor_shapes [name ] = details ["shape_signature" ].tolist ()
152
+ elif "shape" in details :
153
+ tensor_shapes [name ] = details ["shape" ].tolist ()
154
+ return tflite_graphs , opcodes_map , model , tensor_shapes
142
155
143
156
144
157
def get_quantization_attr (quant_params ):
@@ -153,7 +166,7 @@ def get_quantization_attr(quant_params):
153
166
return attr
154
167
155
168
156
- def parse_tflite_graph (tflite_g , opcodes_map , model , input_prefix = '' ):
169
+ def parse_tflite_graph (tflite_g , opcodes_map , model , input_prefix = '' , tensor_shapes_override = None ):
157
170
"""
158
171
Returns a Graph object along with some op count stats. All tflite op types are prefixed with "TFL_".
159
172
Names of graph inputs are optionally prefixed with a string to prevent name conflicts in subgraphs.
@@ -165,6 +178,8 @@ def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix=''):
165
178
output_shapes = {}
166
179
dtypes = {}
167
180
tensor_names = {}
181
+ if tensor_shapes_override is None :
182
+ tensor_shapes_override = {}
168
183
# Map tensor name to tflite Tensor object so we can fetch quantization info as needed
169
184
name_to_tensor = {}
170
185
# If a node takes a quantized tensor as input, we must add a dequantize op after it.
@@ -183,7 +198,9 @@ def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix=''):
183
198
tensor_names [i ] = name
184
199
name_to_tensor [name ] = tensor
185
200
186
- if tensor .ShapeIsNone ():
201
+ if name in tensor_shapes_override :
202
+ output_shapes [name ] = tensor_shapes_override [name ]
203
+ elif tensor .ShapeIsNone ():
187
204
output_shapes [name ] = None
188
205
elif tensor .ShapeSignatureIsNone ():
189
206
# The shape signature uses -1 to signify unknown dims. Old models don't have this and use Shape instead.
0 commit comments