Commit d791944

fix unit test errors
Signed-off-by: Francesco Salvetti <[email protected]>
1 parent: 41acb39


tf2onnx/onnx_opset/math.py

Lines changed: 115 additions & 71 deletions
@@ -562,155 +562,199 @@ class CumProd:
     @classmethod
     def version_10(cls, ctx, node, **kwargs):
         # opset 10 required for Slice to support starts/ends/axes/steps as inputs
-        axis_node = ctx.make_node("Cast", inputs=[node.input[1]], attr={"to": onnx_pb.TensorProto.INT64},
+        axis_node = node.inputs[1]
+        is_axis_const = axis_node.is_const()
+        if is_axis_const:  # we can compute the axis value right now
+            axis = axis_node.get_tensor_value()
+            axis_node = ctx.make_const(utils.make_name("axis"), np.array([axis], dtype=np.int64))
+        else:
+            axis_node = ctx.make_node("Cast", inputs=[axis_node.output[0]], attr={"to": onnx_pb.TensorProto.INT64},
                                       op_name_scope=node.name, outputs=[utils.make_name("axis")])
-        axis = GraphBuilder(ctx).make_unsqueeze({'data': axis_node.output[0], 'axes': [0]})
+            axis_node = GraphBuilder(ctx).make_unsqueeze({'data': axis_node.output[0], 'axes': [0]}, return_node=True)
+            axis = axis_node.output[0]
+
+        input_rank = len(ctx.get_shape(node.input[0]))
         cond_true_node = ctx.make_const(utils.make_name("cond_in"), np.ones((), dtype=bool))
-        input_shape_node = ctx.make_node("Shape", inputs=[node.input[0]], op_name_scope=node.name, outputs=[utils.make_name("input_shape")])
-        input_rank_node = ctx.make_node("Shape", inputs=[input_shape_node.output[0]], op_name_scope=node.name, outputs=[utils.make_name("input_rank")])
+        input_shape_node = ctx.make_node("Shape", inputs=[node.input[0]], op_name_scope=node.name,
+                                         outputs=[utils.make_name("input_shape")])
         axis_length_node = ctx.make_node("Gather", inputs=[input_shape_node.output[0], node.input[1]],
                                          op_name_scope=node.name, outputs=[utils.make_name("axis_length")])
         one_node = ctx.make_const(utils.make_name("one"), np.array([1], "int64"))
         axis_length_plus_one_node = ctx.make_node("Add", inputs=[axis_length_node.output[0], one_node.output[0]],
-                                                  op_name_scope=node.name, outputs=[utils.make_name("axis_length_plus_one")])
+                                                  op_name_scope=node.name,
+                                                  outputs=[utils.make_name("axis_length_plus_one")])
         num_iter_node = ctx.make_node("Sub", inputs=[axis_length_node.output[0], one_node.output[0]],
                                       op_name_scope=node.name, outputs=[utils.make_name("num_iter")])

         if node.get_attr_value("exclusive"):  # one iter less, crop the input, then pad the output
             num_iter_node = ctx.make_node("Sub", inputs=[num_iter_node.output[0], one_node.output[0]],
                                           op_name_scope=node.name, outputs=[utils.make_name("num_iter")])
             zero_node = ctx.make_const(utils.make_name("zero"), np.array([0], "int64"))
             if node.get_attr_value("reverse"):
-                pad_tensors = [zero_node.output[0], one_node.output[0]]
+                pad_axis = [0, 1]
                 start_slice = one_node.output[0]
                 end_slice = axis_length_plus_one_node.output[0]
             else:
                 minus_one_node = ctx.make_const(utils.make_name("minus_one"), np.array([-1], "int64"))
-                pad_tensors = [one_node.output[0], zero_node.output[0]]
+                pad_axis = [1, 0]
                 start_slice = zero_node.output[0]
                 end_slice = minus_one_node.output[0]
-            pads_node = cls.get_pads_node(ctx, pad_tensors, axis, input_rank_node.output[0], node.name)
+            pads_node = cls.get_pads_node(ctx, pad_axis, axis, input_rank, base_name=node.name)
             slice_shape = [-1] * len(ctx.get_shape(node.input[0]))
-            inputs_node = ctx.make_node("Slice", inputs=[node.input[0], start_slice, end_slice, axis],
+            inputs_node = ctx.make_node("Slice", inputs=[node.input[0], start_slice, end_slice, axis_node.output[0]],
                                         op_name_scope=node.name, outputs=[utils.make_name("slice")],
                                         shapes=[slice_shape], dtypes=[ctx.get_dtype(node.input[0])])
             inputs = inputs_node.output[0]
         else:
             inputs = node.input[0]

-        loop_graph = cls.make_loop_graph(ctx, node, inputs)
+        loop_graph = cls.make_loop_graph(ctx, node, inputs, input_rank, axis)
         loop_graph.parent_graph = ctx

-        loop_inputs = [num_iter_node.output[0], cond_true_node.output[0], inputs, input_rank_node.output[0], axis, axis_length_plus_one_node.output[0], inputs]
-        loop_outputs = [utils.make_name("loop_inputs_out"), utils.make_name("loop_rank_out"), utils.make_name("loop_axis_out"),
-                        utils.make_name("loop_axis_length_plus_one_out"), utils.make_name("loop_accumulator_out")]
+        loop_inputs = [num_iter_node.output[0], cond_true_node.output[0], inputs,
+                       axis_length_plus_one_node.output[0], inputs]
+        loop_outputs = [utils.make_name("loop_inputs_out"), utils.make_name("loop_axis_length_plus_one_out"),
+                        utils.make_name("loop_accumulator_out")]
+        if not is_axis_const:  # axis is a tensor, we need to feed it to the loop graph
+            loop_inputs.append(axis)
+            loop_outputs.append(utils.make_name("loop_axis_out"))
         loop_outputs_shapes = [loop_graph.get_shape(o) for o in loop_graph.outputs[1:]]
         loop_outputs_dtypes = [loop_graph.get_dtype(o) for o in loop_graph.outputs[1:]]

         loop_node = ctx.make_node("Loop", inputs=loop_inputs, branches={"body": loop_graph}, outputs=loop_outputs,
                                   shapes=loop_outputs_shapes, dtypes=loop_outputs_dtypes, op_name_scope=node.name)

         if node.get_attr_value("exclusive"):  # pad the output
-            if ctx.get_dtype(loop_node.output[-1]) != ctx.get_dtype(one_node.output[0]):
-                pad_const_node = ctx.make_node("Cast", inputs=[one_node.output[0]], attr={"to": ctx.get_dtype(loop_node.output[-1])},
+            if ctx.get_dtype(loop_node.output[2]) != ctx.get_dtype(one_node.output[0]):
+                pad_const_node = ctx.make_node("Cast", inputs=[one_node.output[0]],
+                                               attr={"to": ctx.get_dtype(loop_node.output[2])},
                                                op_name_scope=node.name, outputs=[utils.make_name("pad_const")])
             else:
                 pad_const_node = one_node
-            output_node = ctx.make_node("Pad", inputs=[loop_node.output[-1], pads_node.output[0], pad_const_node.output[0]],
-                                        op_name_scope=node.name, outputs=[utils.make_name("cumprod_out")])
+            output_node = ctx.make_node("Pad", op_name_scope=node.name, outputs=[utils.make_name("cumprod_out")],
+                                        inputs=[loop_node.output[2], pads_node.output[0], pad_const_node.output[0]])
             output = output_node.output[0]
         else:
-            output = loop_node.output[-1]
+            output = loop_node.output[2]
         output_node = ctx.make_node("Identity", inputs=[output], outputs=[utils.make_name("cumprod_out")],
                                     shapes=[ctx.get_shape(node.input[0])], dtypes=[ctx.get_dtype(node.input[0])])
         ctx.insert_node_on_output(output_node, node.output[0])
         ctx.remove_node(node.name)

     @classmethod
-    def make_loop_graph(cls, ctx, node, inputs_tensor):
+    def make_loop_graph(cls, ctx, node, inputs_tensor, input_rank, axis):
+        inputs_tensor_shape = ctx.get_shape(inputs_tensor)
+        inputs_tensor_dtype = ctx.get_dtype(inputs_tensor)
+
         graph = ctx.create_new_graph_with_same_config()
         graph.add_graph_input(utils.make_name("iteration_num"), onnx_pb.TensorProto.INT64, [])
         graph.add_graph_input(utils.make_name("condition_in"), onnx_pb.TensorProto.BOOL, [])
-        graph.add_graph_input(utils.make_name("inputs"), ctx.get_dtype(node.input[0]), ctx.get_shape(inputs_tensor))
-        graph.add_graph_input(utils.make_name("rank"), onnx_pb.TensorProto.INT64, [1])
-        graph.add_graph_input(utils.make_name("axis"), onnx_pb.TensorProto.INT64, [1])
+        graph.add_graph_input(utils.make_name("inputs"), inputs_tensor_dtype, inputs_tensor_shape)
         graph.add_graph_input(utils.make_name("axis_length_plus_one"), onnx_pb.TensorProto.INT64, [1])
-        graph.add_graph_input(utils.make_name("accumulator"), ctx.get_dtype(node.input[0]), ctx.get_shape(inputs_tensor))
+        graph.add_graph_input(utils.make_name("accumulator"), inputs_tensor_dtype, inputs_tensor_shape)
+        if not isinstance(axis, int):  # axis is a tensor, we need to feed it to the loop graph
+            graph.add_graph_input(utils.make_name("axis"), onnx_pb.TensorProto.INT64, [1])
+            axis = graph.input_names[-1]
+            axis_node = graph.get_node_by_output(axis)
+        else:
+            axis_node = graph.make_const(utils.make_name("axis"), np.array([axis], "int64"))

         # main loop graph
         loop_name = node.name + "/loop"
         iter_num = GraphBuilder(graph).make_unsqueeze({'data': graph.input_names[0], 'axes': [0]})
         one_node = graph.make_const(utils.make_name("one"), np.array(1, "int64"))
         zero_node = graph.make_const(utils.make_name("zero"), np.array([0], "int64"))

         add_node = graph.make_node("Add", inputs=[iter_num, one_node.output[0]],
                                    outputs=[utils.make_name("add")], op_name_scope=loop_name)

         if node.get_attr_value("reverse"):
-            pad_tensors = [zero_node.output[0], add_node.output[0]]
+            pad_axis = [zero_node.output[0], add_node.output[0]]
             start_slice = add_node.output[0]
-            end_slice = graph.input_names[5]
+            end_slice = graph.input_names[3]
         else:
-            neg_node = graph.make_node("Neg", inputs=[add_node.output[0]], outputs=[utils.make_name("neg")], op_name_scope=loop_name)
-            pad_tensors = [add_node.output[0], zero_node.output[0]]
+            neg_node = graph.make_node("Neg", inputs=[add_node.output[0]],
+                                       outputs=[utils.make_name("neg")], op_name_scope=loop_name)
+            pad_axis = [add_node.output[0], zero_node.output[0]]
             start_slice = zero_node.output[0]
             end_slice = neg_node.output[0]

-        pads_node = cls.get_pads_node(graph, pad_tensors, graph.input_names[4], graph.input_names[3], loop_name)
-        slice_node = graph.make_node("Slice", inputs=[graph.input_names[2], start_slice, end_slice, graph.input_names[4]],
-                                     op_name_scope=loop_name, outputs=[utils.make_name("slice")])
+        pads_node = cls.get_pads_node(graph, pad_axis, axis, input_rank, is_pad_axis_const=False, base_name=loop_name)
+        slice_node = graph.make_node("Slice", op_name_scope=loop_name, outputs=[utils.make_name("slice")],
+                                     inputs=[graph.input_names[2], start_slice, end_slice, axis_node.output[0]])
         if graph.get_dtype(slice_node.output[0]) != graph.get_dtype(one_node.output[0]):
-            pad_const_node = graph.make_node("Cast", inputs=[one_node.output[0]], attr={"to": graph.get_dtype(slice_node.output[0])},
-                                             op_name_scope=loop_name, outputs=[utils.make_name("pad_const")])
+            pad_const_node = graph.make_node("Cast", inputs=[one_node.output[0]],
+                                             attr={"to": graph.get_dtype(slice_node.output[0])},
+                                             op_name_scope=loop_name, outputs=[utils.make_name("pad_const")])
         else:
             pad_const_node = one_node
         pad_node = graph.make_node("Pad", inputs=[slice_node.output[0], pads_node.output[0], pad_const_node.output[0]],
                                    op_name_scope=loop_name, outputs=[utils.make_name("pad")])
-        mul_node = graph.make_node("Mul", inputs=[graph.input_names[-1], pad_node.output[0]],
-                                   op_name_scope=loop_name, outputs=[utils.make_name("mul")])
+        mul_node = graph.make_node("Mul", inputs=[graph.input_names[4], pad_node.output[0]],
+                                   op_name_scope=loop_name, outputs=[utils.make_name("mul")],
+                                   shapes=[inputs_tensor_shape], dtypes=[inputs_tensor_dtype])

         # manage loop outputs
         output_cond_node = graph.make_node("Identity", inputs=[graph.input_names[1]], op_name_scope=loop_name,
                                            outputs=[utils.make_name("condition_out")])
         output_inp_node = graph.make_node("Identity", inputs=[graph.input_names[2]], op_name_scope=loop_name,
-                                          outputs=[utils.make_name("inputs_out")])
-        output_rank_node = graph.make_node("Identity", inputs=[graph.input_names[3]], op_name_scope=loop_name,
-                                           outputs=[utils.make_name("rank_out")])
-        output_axis_node = graph.make_node("Identity", inputs=[graph.input_names[4]], op_name_scope=loop_name,
-                                           outputs=[utils.make_name("axis_out")])
-        output_axis_length_plus_one_node = graph.make_node("Identity", inputs=[graph.input_names[5]],
-                                                           op_name_scope=loop_name, outputs=[utils.make_name("axis_length_plus_one_out")])
+                                          outputs=[utils.make_name("inputs_out")])
+        output_axis_length_plus_one_node = graph.make_node("Identity", inputs=[graph.input_names[3]],
+                                                           op_name_scope=loop_name,
+                                                           outputs=[utils.make_name("axis_length_plus_one_out")])
         output_acc_node = graph.make_node("Identity", inputs=[mul_node.output[0]], op_name_scope=loop_name,
-                                          outputs=[utils.make_name("accumulator_out")],
-                                          shapes=[ctx.get_shape(node.input[0])], dtypes=[ctx.get_dtype(node.input[0])])
+                                          outputs=[utils.make_name("accumulator_out")])

         graph.add_graph_output(output_cond_node.output[0])  # 1 condition output
         graph.add_graph_output(output_inp_node.output[0])  # N loop carried dependencies outputs
-        graph.add_graph_output(output_rank_node.output[0])  # N loop carried dependencies outputs
-        graph.add_graph_output(output_axis_node.output[0])  # N loop carried dependencies outputs
         graph.add_graph_output(output_axis_length_plus_one_node.output[0])  # N loop carried dependencies outputs
         graph.add_graph_output(output_acc_node.output[0])  # N loop carried dependencies outputs
+
+        if not isinstance(axis, int):  # axis is a tensor, we need to pass it through the loop graph
+            output_axis_node = graph.make_node("Identity", inputs=[axis], op_name_scope=loop_name,
+                                               outputs=[utils.make_name("axis_out")])
+            graph.add_graph_output(output_axis_node.output[0])  # N loop carried dependencies outputs
         return graph

     @classmethod
-    def get_pads_node(cls, graph, pad_tensors, axis_tensor, rank_tensor, base_name=""):
-        zero_node = graph.make_const(utils.make_name("zero"), np.array([0], "int64"))
-        one_node = graph.make_const(utils.make_name("zero"), np.array([1], "int64"))
-
-        post_repeat_node = graph.make_node("Sub", inputs=[rank_tensor, axis_tensor],
-                                           outputs=[utils.make_name("post_repeat")], op_name_scope=base_name)
-        post_repeat_node = graph.make_node("Sub", inputs=[post_repeat_node.output[0], one_node.output[0]],
-                                           outputs=[utils.make_name("post_repeat")], op_name_scope=base_name)
-
-        pre_pad_node = graph.make_node("Tile", inputs=[zero_node.output[0], axis_tensor],
-                                       attr={"axis": 0}, outputs=[utils.make_name("pre_pad")], op_name_scope=base_name)
-
-        post_pad_node = graph.make_node("Tile", inputs=[zero_node.output[0], post_repeat_node.output[0]],
-                                        attr={"axis": 0}, outputs=[utils.make_name("post_pad")], op_name_scope=base_name)
-
-        pads_node = graph.make_node("Concat", attr={"axis": 0}, outputs=[utils.make_name("pads")], op_name_scope=base_name,
-                                    inputs=[pre_pad_node.output[0], pad_tensors[0], post_pad_node.output[0],
-                                            pre_pad_node.output[0], pad_tensors[1], post_pad_node.output[0]])
+    def get_pads_node(cls, graph, pad_axis, axis, rank, is_pad_axis_const=True, base_name=""):
+        if isinstance(axis, int):  # axis is a const, we can directly compute the padding values
+            pre_pad = np.zeros(axis, "int64")
+            post_pad = np.zeros(rank - axis - 1, "int64")
+            if is_pad_axis_const:  # pylint: disable=R1705
+                pads = np.concatenate([pre_pad, pad_axis[0:1], post_pad,
+                                       pre_pad, pad_axis[1:2], post_pad])
+                pads_node = graph.make_const(utils.make_name("pads"), pads)
+                return pads_node
+            else:
+                pre_pad_node = graph.make_const(utils.make_name("pre_pad"), pre_pad)
+                post_pad_node = graph.make_const(utils.make_name("post_pad"), post_pad)
+
+        else:  # axis is a tensor, we need to compute the padding values at runtime
+            if is_pad_axis_const:
+                pad_axis = [graph.make_const(utils.make_name("pad"),
+                                             np.array([pad], "int64")).output[0] for pad in pad_axis]
+
+            rank_tensor = graph.make_const(utils.make_name("rank"), np.array([rank], "int64")).output[0]
+            zero_node = graph.make_const(utils.make_name("zero"), np.array([0], "int64"))
+            one_node = graph.make_const(utils.make_name("one"), np.array([1], "int64"))
+
+            post_repeat_node = graph.make_node("Sub", inputs=[rank_tensor, axis],
+                                               outputs=[utils.make_name("post_repeat")], op_name_scope=base_name)
+            post_repeat_node = graph.make_node("Sub", inputs=[post_repeat_node.output[0], one_node.output[0]],
+                                               outputs=[utils.make_name("post_repeat")], op_name_scope=base_name)
+
+            pre_pad_node = graph.make_node("Tile", inputs=[zero_node.output[0], axis], op_name_scope=base_name,
+                                           attr={"axis": 0}, outputs=[utils.make_name("pre_pad")])
+
+            post_pad_node = graph.make_node("Tile", inputs=[zero_node.output[0], post_repeat_node.output[0]],
+                                            attr={"axis": 0}, outputs=[utils.make_name("post_pad")],
+                                            op_name_scope=base_name)
+
+        pads_node = graph.make_node("Concat", attr={"axis": 0}, outputs=[utils.make_name("pads")],
+                                    op_name_scope=base_name,
+                                    inputs=[pre_pad_node.output[0], pad_axis[0], post_pad_node.output[0],
+                                            pre_pad_node.output[0], pad_axis[1], post_pad_node.output[0]])
         return pads_node

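Note: the graph built by version_10 has to reproduce tf.math.cumprod, including its exclusive and reverse attributes. A minimal NumPy sketch of those semantics can serve as a checking oracle against the converted graph; the helper name cumprod_reference is illustrative, not part of tf2onnx:

    import numpy as np

    def cumprod_reference(x, axis=0, exclusive=False, reverse=False):
        # NumPy model of tf.math.cumprod, used only for checking
        if reverse:
            x = np.flip(x, axis)
        out = np.cumprod(x, axis)
        if exclusive:
            # drop the last slice along `axis` and pad a leading 1,
            # mirroring the Slice + Pad pair built in version_10
            out = np.delete(out, -1, axis)
            pad = [(0, 0)] * x.ndim
            pad[axis] = (1, 0)
            out = np.pad(out, pad, constant_values=1)
        if reverse:
            out = np.flip(out, axis)
        return out

    # cumprod_reference(np.array([2., 3., 4.]), exclusive=True) -> [1., 2., 6.]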
716760
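The Loop body assembled by make_loop_graph computes the cumulative product by shift-and-multiply: iteration m multiplies the accumulator by the input shifted m + 1 slots along the axis, with the vacated slots padded with ones, which is what the Slice, Pad, and Mul nodes above implement. A rough NumPy model of one forward (non-reverse) iteration, with illustrative names:

    import numpy as np

    def loop_body_reference(acc, x, m, axis=0):
        # Slice(starts=0, ends=-(m+1)), then Pad with ones in front, then Mul
        shift = m + 1
        shifted = np.delete(x, np.s_[-shift:], axis)
        pad = [(0, 0)] * x.ndim
        pad[axis] = (shift, 0)
        return acc * np.pad(shifted, pad, constant_values=1)

    x = np.array([2., 3., 4.])
    acc = x.copy()                 # accumulator starts as the input itself
    for m in range(len(x) - 1):    # num_iter = axis_length - 1
        acc = loop_body_reference(acc, x, m)
    # acc -> [2., 6., 24.] == np.cumprod(x)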

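get_pads_node builds the pads input of ONNX Pad, which lists all per-axis begin counts followed by all end counts (length 2 * rank). The new const-axis fast path simply materializes that vector with NumPy. A small worked example, with values assumed for illustration (rank=3, axis=1, pad_axis=[1, 0], the exclusive non-reverse case):

    import numpy as np

    rank, axis, pad_axis = 3, 1, [1, 0]              # illustrative values
    pre_pad = np.zeros(axis, "int64")                # zeros for axes before `axis`
    post_pad = np.zeros(rank - axis - 1, "int64")    # zeros for axes after `axis`
    pads = np.concatenate([pre_pad, pad_axis[0:1], post_pad,
                           pre_pad, pad_axis[1:2], post_pad])
    # pads -> [0 1 0 0 0 0]: pad one slot at the beginning of axis 1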