Skip to content

Commit e7f39ed

Browse files
q-ycong-pYu Conghwangdeyu
authored
Skip existing const initializer node as input in _parse_graph_input (#2000)
If the Graph object being constructed already contains a node of the name of a Const node in orginal graph, do not add as input. Signed-off-by: Yu Cong <[email protected]> Co-authored-by: Yu Cong <[email protected]> Co-authored-by: Deyu Huang <[email protected]>
1 parent e896723 commit e7f39ed

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

tests/test_optimizers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1927,7 +1927,7 @@ def test_duplicated_duplicated_constant_and_initializer(self):
19271927

19281928
model_proto = self.make_model(graph, producer_name="onnx-tests")
19291929
self.run_merge_duplicated_nodes_compare(["OUT"], {}, model_proto, op_type="Constant", remaining_op_num=0,
1930-
graph_validator=lambda g: self._check_initializer_num(g, 2))
1930+
graph_validator=lambda g: self._check_initializer_num(g, 1))
19311931

19321932
def test_duplicated_node_is_graph_output(self):
19331933
node0 = helper.make_node('Add', inputs=["X", "X"], outputs=["value0"])

tf2onnx/graph.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,9 +1791,11 @@ def _parse_graph_input(g, graph_proto, const_node_names):
17911791
# because for subgraphs, the input orders matter.
17921792
for graph_input in graph_proto.input:
17931793
name = graph_input.name
1794-
shape = shapes[name]
1795-
dtype = dtypes[name]
1796-
if name not in const_node_names:
1797-
g.add_graph_input(name, dtype, shape)
1798-
else:
1799-
g.add_graph_input_with_default(name, g.get_node_by_name(name), dtype, shape)
1794+
const_initializer_node = g.get_node_by_output_in_current_graph(name)
1795+
if const_initializer_node is None: # is actual input rather than initializer
1796+
shape = shapes[name]
1797+
dtype = dtypes[name]
1798+
if name not in const_node_names:
1799+
g.add_graph_input(name, dtype, shape)
1800+
else:
1801+
g.add_graph_input_with_default(name, g.get_node_by_name(name), dtype, shape)

0 commit comments

Comments
 (0)