Skip to content

Commit 9e72c33

Browse files
Improve reliability of reading tflite shapes
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent ed9b987 commit 9e72c33

File tree

3 files changed

+29
-8
lines changed

3 files changed

+29
-8
lines changed

tf2onnx/tflite_handlers/tfl_math.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ class TFlSoftmaxOp:
222222
@classmethod
223223
def to_tf(cls, ctx, node, **kwargs):
224224
beta = node.get_attr_value("beta")
225-
beta_node = ctx.make_const(utils.make_name("beta"), np.array(beta, dtype=np.float32))
226-
mul_node = ctx.insert_new_node_on_output("Mul", node.output[0], name=utils.make_name(node.name))
227-
ctx.replace_inputs(mul_node, [node.output[0], beta_node.output[0]])
225+
if beta != 1:
226+
beta_node = ctx.make_const(utils.make_name("beta"), np.array(beta, dtype=np.float32))
227+
mul_node = ctx.insert_new_node_on_output("Mul", node.output[0], name=utils.make_name(node.name))
228+
ctx.replace_inputs(mul_node, [node.output[0], beta_node.output[0]])

tf2onnx/tflite_utils.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from onnx import helper, onnx_pb, numpy_helper
1212
from tensorflow.core.framework import types_pb2, tensor_pb2
1313
from tensorflow.python.framework import tensor_util
14+
import tensorflow as tf
1415
from tf2onnx.tflite.TensorType import TensorType as TFLiteTensorType
1516
from tf2onnx.tflite.Model import Model
1617
from tf2onnx.flexbuffers import read_flexbuffer
@@ -138,7 +139,19 @@ def read_tflite_model(tflite_path):
138139
code = op_code.CustomCode().decode()
139140
opcodes_map[i] = code
140141
tflite_graphs = [model.Subgraphs(i) for i in range(model.SubgraphsLength())]
141-
return tflite_graphs, opcodes_map, model
142+
# Shapes stored in tflite models are not always reliable so we get them from the interpreter if possible.
143+
interpreter = tf.lite.Interpreter(tflite_path)
144+
interpreter.allocate_tensors()
145+
tensor_cnt = model.Subgraphs(0).TensorsLength()
146+
tensor_shapes = {}
147+
for i in range(tensor_cnt):
148+
name = model.Subgraphs(0).Tensors(i).Name().decode()
149+
details = interpreter._get_tensor_details(i) # pylint: disable=protected-access
150+
if "shape_signature" in details:
151+
tensor_shapes[name] = details["shape_signature"].tolist()
152+
elif "shape" in details:
153+
tensor_shapes[name] = details["shape"].tolist()
154+
return tflite_graphs, opcodes_map, model, tensor_shapes
142155

143156

144157
def get_quantization_attr(quant_params):
@@ -153,7 +166,7 @@ def get_quantization_attr(quant_params):
153166
return attr
154167

155168

156-
def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix=''):
169+
def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix='', tensor_shapes_override=None):
157170
"""
158171
Returns a Graph object along with some op count stats. All tflite op types are prefixed with "TFL_".
159172
Names of graph inputs are optionally prefixed with a string to prevent name conflicts in subgraphs.
@@ -165,6 +178,8 @@ def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix=''):
165178
output_shapes = {}
166179
dtypes = {}
167180
tensor_names = {}
181+
if tensor_shapes_override is None:
182+
tensor_shapes_override = {}
168183
# Map tensor name to tflite Tensor object so we can fetch quantization info as needed
169184
name_to_tensor = {}
170185
# If a node takes a quantized tensor as input, we must add a dequantize op after it.
@@ -183,7 +198,9 @@ def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix=''):
183198
tensor_names[i] = name
184199
name_to_tensor[name] = tensor
185200

186-
if tensor.ShapeIsNone():
201+
if name in tensor_shapes_override:
202+
output_shapes[name] = tensor_shapes_override[name]
203+
elif tensor.ShapeIsNone():
187204
output_shapes[name] = None
188205
elif tensor.ShapeSignatureIsNone():
189206
# The shape signature uses -1 to signify unknown dims. Old models don't have this and use Shape instead.

tf2onnx/tfonnx.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,14 +461,17 @@ def rename_tensors_in_nodes(onnx_nodes):
461461
n.output[:] = rename_tensors_in_list(n.output)
462462

463463
if tflite_path is not None:
464-
tflite_graphs, opcodes, model = read_tflite_model(tflite_path)
464+
tflite_graphs, opcodes, model, tensor_shapes = read_tflite_model(tflite_path)
465465
main_g = None
466466
inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw)
467467
for i in reversed(range(len(tflite_graphs))):
468468
tfl_graph = tflite_graphs[i]
469469
prefix = '' if i == 0 else tfl_graph.Name().decode() + '_'
470+
tensor_shapes_from_interpreter = None
471+
if i == 0:
472+
tensor_shapes_from_interpreter = tensor_shapes
470473
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, f_inputs, f_outputs, graph_name = \
471-
parse_tflite_graph(tfl_graph, opcodes, model, prefix)
474+
parse_tflite_graph(tfl_graph, opcodes, model, prefix, tensor_shapes_from_interpreter)
472475
g_inputs = f_inputs
473476
g_outputs = f_outputs
474477
if i == 0:

0 commit comments

Comments
 (0)