Skip to content

Commit c0f6f4b

Browse files
committed
fix wrong shapes in loop body inputs if shape invariances are set in TF
Signed-off-by: Salvetti, Francesco <[email protected]>
1 parent 25c977c commit c0f6f4b

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tf2onnx/onnx_opset/controlflow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,8 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_
571571
g.set_dtype(func_inputs[0], onnx_pb.TensorProto.INT64)
572572
g.inputs = [g.get_node_by_output(inp) for inp in func_inputs]
573573

574-
for p, c in zip(loop_node.input, func_inputs):
574+
# we should use outputs shape, not inputs, since there may be shape invariants
575+
for p, c in zip(loop_node.output, func_inputs[2:]):
575576
g.copy_shape(p, c)
576577

577578
for i, node in enumerate(g.inputs):

0 commit comments

Comments
 (0)