@@ -562,155 +562,199 @@ class CumProd:
     @classmethod
     def version_10(cls, ctx, node, **kwargs):
         # opset 10 required for Slice to support starts/ends/axes/steps as inputs
-        axis_node = ctx.make_node("Cast", inputs=[node.input[1]], attr={"to": onnx_pb.TensorProto.INT64},
+        axis_node = node.inputs[1]
+        is_axis_const = axis_node.is_const()
+        if is_axis_const:  # we can compute the axis value right now
+            axis = axis_node.get_tensor_value()
+            axis_node = ctx.make_const(utils.make_name("axis"), np.array([axis], dtype=np.int64))
+        else:
+            axis_node = ctx.make_node("Cast", inputs=[axis_node.output[0]], attr={"to": onnx_pb.TensorProto.INT64},
                                   op_name_scope=node.name, outputs=[utils.make_name("axis")])
-        axis = GraphBuilder(ctx).make_unsqueeze({'data': axis_node.output[0], 'axes': [0]})
+            axis_node = GraphBuilder(ctx).make_unsqueeze({'data': axis_node.output[0], 'axes': [0]}, return_node=True)
+            axis = axis_node.output[0]
+
+        input_rank = len(ctx.get_shape(node.input[0]))
         cond_true_node = ctx.make_const(utils.make_name("cond_in"), np.ones((), dtype=bool))
-        input_shape_node = ctx.make_node("Shape", inputs=[node.input[0]], op_name_scope=node.name, outputs=[utils.make_name("input_shape")])
-        input_rank_node = ctx.make_node("Shape", inputs=[input_shape_node.output[0]], op_name_scope=node.name, outputs=[utils.make_name("input_rank")])
+        input_shape_node = ctx.make_node("Shape", inputs=[node.input[0]], op_name_scope=node.name,
+                                         outputs=[utils.make_name("input_shape")])
         axis_length_node = ctx.make_node("Gather", inputs=[input_shape_node.output[0], node.input[1]],
                                          op_name_scope=node.name, outputs=[utils.make_name("axis_length")])
         one_node = ctx.make_const(utils.make_name("one"), np.array([1], "int64"))
         axis_length_plus_one_node = ctx.make_node("Add", inputs=[axis_length_node.output[0], one_node.output[0]],
-                                                  op_name_scope=node.name, outputs=[utils.make_name("axis_length_plus_one")])
+                                                  op_name_scope=node.name,
+                                                  outputs=[utils.make_name("axis_length_plus_one")])
         num_iter_node = ctx.make_node("Sub", inputs=[axis_length_node.output[0], one_node.output[0]],
                                       op_name_scope=node.name, outputs=[utils.make_name("num_iter")])
-
+
         if node.get_attr_value("exclusive"):  # one iter less, crop the input, then pad the output
             num_iter_node = ctx.make_node("Sub", inputs=[num_iter_node.output[0], one_node.output[0]],
                                           op_name_scope=node.name, outputs=[utils.make_name("num_iter")])
             zero_node = ctx.make_const(utils.make_name("zero"), np.array([0], "int64"))
             if node.get_attr_value("reverse"):
-                pad_tensors = [zero_node.output[0], one_node.output[0]]
+                pad_axis = [0, 1]
                 start_slice = one_node.output[0]
                 end_slice = axis_length_plus_one_node.output[0]
             else:
                 minus_one_node = ctx.make_const(utils.make_name("minus_one"), np.array([-1], "int64"))
-                pad_tensors = [one_node.output[0], zero_node.output[0]]
+                pad_axis = [1, 0]
                 start_slice = zero_node.output[0]
                 end_slice = minus_one_node.output[0]
-            pads_node = cls.get_pads_node(ctx, pad_tensors, axis, input_rank_node.output[0], node.name)
+            pads_node = cls.get_pads_node(ctx, pad_axis, axis, input_rank, base_name=node.name)
             slice_shape = [-1] * len(ctx.get_shape(node.input[0]))
-            inputs_node = ctx.make_node("Slice", inputs=[node.input[0], start_slice, end_slice, axis],
+            inputs_node = ctx.make_node("Slice", inputs=[node.input[0], start_slice, end_slice, axis_node.output[0]],
                                         op_name_scope=node.name, outputs=[utils.make_name("slice")],
                                         shapes=[slice_shape], dtypes=[ctx.get_dtype(node.input[0])])
             inputs = inputs_node.output[0]
         else:
             inputs = node.input[0]
-
-        loop_graph = cls.make_loop_graph(ctx, node, inputs)
+
+        loop_graph = cls.make_loop_graph(ctx, node, inputs, input_rank, axis)
         loop_graph.parent_graph = ctx
-        loop_inputs = [num_iter_node.output[0], cond_true_node.output[0], inputs, input_rank_node.output[0], axis, axis_length_plus_one_node.output[0], inputs]
-        loop_outputs = [utils.make_name("loop_inputs_out"), utils.make_name("loop_rank_out"), utils.make_name("loop_axis_out"),
-                        utils.make_name("loop_axis_length_plus_one_out"), utils.make_name("loop_accumulator_out")]
+        loop_inputs = [num_iter_node.output[0], cond_true_node.output[0], inputs,
+                       axis_length_plus_one_node.output[0], inputs]
+        loop_outputs = [utils.make_name("loop_inputs_out"), utils.make_name("loop_axis_length_plus_one_out"),
+                        utils.make_name("loop_accumulator_out")]
+        if not is_axis_const:  # axis is a tensor, we need to feed it to the loop graph
+            loop_inputs.append(axis)
+            loop_outputs.append(utils.make_name("loop_axis_out"))
         loop_outputs_shapes = [loop_graph.get_shape(o) for o in loop_graph.outputs[1:]]
         loop_outputs_dtypes = [loop_graph.get_dtype(o) for o in loop_graph.outputs[1:]]
-
+
         loop_node = ctx.make_node("Loop", inputs=loop_inputs, branches={"body": loop_graph}, outputs=loop_outputs,
                                   shapes=loop_outputs_shapes, dtypes=loop_outputs_dtypes, op_name_scope=node.name)
-
+
         if node.get_attr_value("exclusive"):  # pad the output
-            if ctx.get_dtype(loop_node.output[-1]) != ctx.get_dtype(one_node.output[0]):
-                pad_const_node = ctx.make_node("Cast", inputs=[one_node.output[0]], attr={"to": ctx.get_dtype(loop_node.output[-1])},
+            if ctx.get_dtype(loop_node.output[2]) != ctx.get_dtype(one_node.output[0]):
+                pad_const_node = ctx.make_node("Cast", inputs=[one_node.output[0]],
+                                               attr={"to": ctx.get_dtype(loop_node.output[2])},
                                                op_name_scope=node.name, outputs=[utils.make_name("pad_const")])
             else:
                 pad_const_node = one_node
-            output_node = ctx.make_node("Pad", inputs=[loop_node.output[-1], pads_node.output[0], pad_const_node.output[0]],
-                                        op_name_scope=node.name, outputs=[utils.make_name("cumprod_out")])
+            output_node = ctx.make_node("Pad", op_name_scope=node.name, outputs=[utils.make_name("cumprod_out")],
+                                        inputs=[loop_node.output[2], pads_node.output[0], pad_const_node.output[0]])
             output = output_node.output[0]
         else:
-            output = loop_node.output[-1]
+            output = loop_node.output[2]
         output_node = ctx.make_node("Identity", inputs=[output], outputs=[utils.make_name("cumprod_out")],
                                     shapes=[ctx.get_shape(node.input[0])], dtypes=[ctx.get_dtype(node.input[0])])
         ctx.insert_node_on_output(output_node, node.output[0])
         ctx.remove_node(node.name)
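
For reference, the TensorFlow semantics this Loop-based lowering has to reproduce can be written in a few lines of NumPy. This is an illustrative sketch only (the function name cumprod_reference is ours, not part of the PR); it mirrors the exclusive and reverse flags of tf.math.cumprod, which is what the crop-the-input / pad-the-output branches above implement.

    import numpy as np

    def cumprod_reference(x, axis=0, exclusive=False, reverse=False):
        # Plain-NumPy model of tf.math.cumprod, used here only to document
        # the behavior the ONNX Loop decomposition is expected to match.
        x = np.asarray(x)
        if reverse:
            x = np.flip(x, axis)
        out = np.cumprod(x, axis=axis)
        if exclusive:
            # Drop the last slice along `axis` and prepend ones: this is
            # the same crop-then-pad trick used in version_10 above.
            out = np.take(out, np.arange(x.shape[axis] - 1), axis=axis)
            pads = [(0, 0)] * x.ndim
            pads[axis] = (1, 0)
            out = np.pad(out, pads, constant_values=1)
        if reverse:
            out = np.flip(out, axis)
        return out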

     @classmethod
-    def make_loop_graph(cls, ctx, node, inputs_tensor):
+    def make_loop_graph(cls, ctx, node, inputs_tensor, input_rank, axis):
+        inputs_tensor_shape = ctx.get_shape(inputs_tensor)
+        inputs_tensor_dtype = ctx.get_dtype(inputs_tensor)
+
         graph = ctx.create_new_graph_with_same_config()
         graph.add_graph_input(utils.make_name("iteration_num"), onnx_pb.TensorProto.INT64, [])
         graph.add_graph_input(utils.make_name("condition_in"), onnx_pb.TensorProto.BOOL, [])
-        graph.add_graph_input(utils.make_name("inputs"), ctx.get_dtype(node.input[0]), ctx.get_shape(inputs_tensor))
-        graph.add_graph_input(utils.make_name("rank"), onnx_pb.TensorProto.INT64, [1])
-        graph.add_graph_input(utils.make_name("axis"), onnx_pb.TensorProto.INT64, [1])
+        graph.add_graph_input(utils.make_name("inputs"), inputs_tensor_dtype, inputs_tensor_shape)
         graph.add_graph_input(utils.make_name("axis_length_plus_one"), onnx_pb.TensorProto.INT64, [1])
-        graph.add_graph_input(utils.make_name("accumulator"), ctx.get_dtype(node.input[0]), ctx.get_shape(inputs_tensor))
+        graph.add_graph_input(utils.make_name("accumulator"), inputs_tensor_dtype, inputs_tensor_shape)
+        if not isinstance(axis, int):  # axis is a tensor, we need to feed it to the loop graph
+            graph.add_graph_input(utils.make_name("axis"), onnx_pb.TensorProto.INT64, [1])
+            axis = graph.input_names[-1]
+            axis_node = graph.get_node_by_output(axis)
+        else:
+            axis_node = graph.make_const(utils.make_name("axis"), np.array([axis], "int64"))

         # main loop graph
         loop_name = node.name + "/loop"
         iter_num = GraphBuilder(graph).make_unsqueeze({'data': graph.input_names[0], 'axes': [0]})
         one_node = graph.make_const(utils.make_name("one"), np.array(1, "int64"))
         zero_node = graph.make_const(utils.make_name("zero"), np.array([0], "int64"))
-
+
         add_node = graph.make_node("Add", inputs=[iter_num, one_node.output[0]],
                                    outputs=[utils.make_name("add")], op_name_scope=loop_name)

         if node.get_attr_value("reverse"):
-            pad_tensors = [zero_node.output[0], add_node.output[0]]
+            pad_axis = [zero_node.output[0], add_node.output[0]]
             start_slice = add_node.output[0]
-            end_slice = graph.input_names[5]
+            end_slice = graph.input_names[3]
         else:
-            neg_node = graph.make_node("Neg", inputs=[add_node.output[0]], outputs=[utils.make_name("neg")], op_name_scope=loop_name)
-            pad_tensors = [add_node.output[0], zero_node.output[0]]
+            neg_node = graph.make_node("Neg", inputs=[add_node.output[0]],
+                                       outputs=[utils.make_name("neg")], op_name_scope=loop_name)
+            pad_axis = [add_node.output[0], zero_node.output[0]]
             start_slice = zero_node.output[0]
             end_slice = neg_node.output[0]
-
-        pads_node = cls.get_pads_node(graph, pad_tensors, graph.input_names[4], graph.input_names[3], loop_name)
-        slice_node = graph.make_node("Slice", inputs=[graph.input_names[2], start_slice, end_slice, graph.input_names[4]],
-                                     op_name_scope=loop_name, outputs=[utils.make_name("slice")])
+
+        pads_node = cls.get_pads_node(graph, pad_axis, axis, input_rank, is_pad_axis_const=False, base_name=loop_name)
+        slice_node = graph.make_node("Slice", op_name_scope=loop_name, outputs=[utils.make_name("slice")],
+                                     inputs=[graph.input_names[2], start_slice, end_slice, axis_node.output[0]])
         if graph.get_dtype(slice_node.output[0]) != graph.get_dtype(one_node.output[0]):
-            pad_const_node = graph.make_node("Cast", inputs=[one_node.output[0]], attr={"to": graph.get_dtype(slice_node.output[0])},
-                                             op_name_scope=loop_name, outputs=[utils.make_name("pad_const")])
+            pad_const_node = graph.make_node("Cast", inputs=[one_node.output[0]],
+                                             attr={"to": graph.get_dtype(slice_node.output[0])},
+                                             op_name_scope=loop_name, outputs=[utils.make_name("pad_const")])
         else:
             pad_const_node = one_node
         pad_node = graph.make_node("Pad", inputs=[slice_node.output[0], pads_node.output[0], pad_const_node.output[0]],
-                                   op_name_scope=loop_name, outputs=[utils.make_name("pad")])
-        mul_node = graph.make_node("Mul", inputs=[graph.input_names[-1], pad_node.output[0]],
-                                   op_name_scope=loop_name, outputs=[utils.make_name("mul")])
-
+                                   op_name_scope=loop_name, outputs=[utils.make_name("pad")])
+        mul_node = graph.make_node("Mul", inputs=[graph.input_names[4], pad_node.output[0]],
+                                   op_name_scope=loop_name, outputs=[utils.make_name("mul")],
+                                   shapes=[inputs_tensor_shape], dtypes=[inputs_tensor_dtype])
+
         # manage loop outputs
         output_cond_node = graph.make_node("Identity", inputs=[graph.input_names[1]], op_name_scope=loop_name,
                                            outputs=[utils.make_name("condition_out")])
         output_inp_node = graph.make_node("Identity", inputs=[graph.input_names[2]], op_name_scope=loop_name,
-                                          outputs=[utils.make_name("inputs_out")])
-        output_rank_node = graph.make_node("Identity", inputs=[graph.input_names[3]], op_name_scope=loop_name,
-                                           outputs=[utils.make_name("rank_out")])
-        output_axis_node = graph.make_node("Identity", inputs=[graph.input_names[4]], op_name_scope=loop_name,
-                                           outputs=[utils.make_name("axis_out")])
-        output_axis_length_plus_one_node = graph.make_node("Identity", inputs=[graph.input_names[5]],
-                                                           op_name_scope=loop_name, outputs=[utils.make_name("axis_length_plus_one_out")])
+                                          outputs=[utils.make_name("inputs_out")])
+        output_axis_length_plus_one_node = graph.make_node("Identity", inputs=[graph.input_names[3]],
+                                                           op_name_scope=loop_name,
+                                                           outputs=[utils.make_name("axis_length_plus_one_out")])
         output_acc_node = graph.make_node("Identity", inputs=[mul_node.output[0]], op_name_scope=loop_name,
-                                          outputs=[utils.make_name("accumulator_out")],
-                                          shapes=[ctx.get_shape(node.input[0])], dtypes=[ctx.get_dtype(node.input[0])])
+                                          outputs=[utils.make_name("accumulator_out")])

         graph.add_graph_output(output_cond_node.output[0])  # 1 condition output
         graph.add_graph_output(output_inp_node.output[0])  # N loop carried dependencies outputs
-        graph.add_graph_output(output_rank_node.output[0])  # N loop carried dependencies outputs
-        graph.add_graph_output(output_axis_node.output[0])  # N loop carried dependencies outputs
         graph.add_graph_output(output_axis_length_plus_one_node.output[0])  # N loop carried dependencies outputs
         graph.add_graph_output(output_acc_node.output[0])  # N loop carried dependencies outputs
+
+        if not isinstance(axis, int):  # axis is a tensor, we need to pass it through as a loop-carried output
+            output_axis_node = graph.make_node("Identity", inputs=[axis], op_name_scope=loop_name,
+                                               outputs=[utils.make_name("axis_out")])
+            graph.add_graph_output(output_axis_node.output[0])  # N loop carried dependencies outputs
         return graph

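The loop body above computes the cumulative product by repeated shift-and-multiply: at iteration i it slices the input, re-pads it with ones so it is shifted by i + 1 positions along the axis, and multiplies it into the accumulator. A hedged NumPy sketch of one forward (non-reverse) iteration, with illustrative names of our own choosing:

    import numpy as np

    def loop_body_iteration(inputs, accumulator, axis, iteration_num):
        # Slice off the last (iteration_num + 1) elements along `axis`,
        # mirroring Slice(start=0, end=-(iteration_num + 1)) above.
        shift = iteration_num + 1
        sliced = np.take(inputs, np.arange(inputs.shape[axis] - shift), axis=axis)
        # Pad the front with ones, i.e. shift the input right by `shift`.
        pads = [(0, 0)] * inputs.ndim
        pads[axis] = (shift, 0)
        shifted = np.pad(sliced, pads, constant_values=1)
        # After axis_length - 1 iterations starting from accumulator = inputs,
        # accumulator[k] = x[k] * x[k-1] * ... * x[0], the inclusive cumprod.
        return accumulator * shifted

The reverse branch does the mirror image: it slices from iteration_num + 1 to the end and pads with ones at the back.
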
     @classmethod
-    def get_pads_node(cls, graph, pad_tensors, axis_tensor, rank_tensor, base_name=""):
-        zero_node = graph.make_const(utils.make_name("zero"), np.array([0], "int64"))
-        one_node = graph.make_const(utils.make_name("zero"), np.array([1], "int64"))
-
-        post_repeat_node = graph.make_node("Sub", inputs=[rank_tensor, axis_tensor],
-                                           outputs=[utils.make_name("post_repeat")], op_name_scope=base_name)
-        post_repeat_node = graph.make_node("Sub", inputs=[post_repeat_node.output[0], one_node.output[0]],
-                                           outputs=[utils.make_name("post_repeat")], op_name_scope=base_name)
-
-        pre_pad_node = graph.make_node("Tile", inputs=[zero_node.output[0], axis_tensor],
-                                       attr={"axis": 0}, outputs=[utils.make_name("pre_pad")], op_name_scope=base_name)
-
-        post_pad_node = graph.make_node("Tile", inputs=[zero_node.output[0], post_repeat_node.output[0]],
-                                        attr={"axis": 0}, outputs=[utils.make_name("post_pad")], op_name_scope=base_name)
-
-        pads_node = graph.make_node("Concat", attr={"axis": 0}, outputs=[utils.make_name("pads")], op_name_scope=base_name,
-                                    inputs=[pre_pad_node.output[0], pad_tensors[0], post_pad_node.output[0],
-                                            pre_pad_node.output[0], pad_tensors[1], post_pad_node.output[0]])
+    def get_pads_node(cls, graph, pad_axis, axis, rank, is_pad_axis_const=True, base_name=""):
+        if isinstance(axis, int):  # axis is a const, we can compute the padding values directly
+            pre_pad = np.zeros(axis, "int64")
+            post_pad = np.zeros(rank - axis - 1, "int64")
+            if is_pad_axis_const:  # pylint: disable=R1705
+                pads = np.concatenate([pre_pad, pad_axis[0:1], post_pad,
+                                       pre_pad, pad_axis[1:2], post_pad])
+                pads_node = graph.make_const(utils.make_name("pads"), pads)
+                return pads_node
+            else:
+                pre_pad_node = graph.make_const(utils.make_name("pre_pad"), pre_pad)
+                post_pad_node = graph.make_const(utils.make_name("post_pad"), post_pad)

+        else:  # axis is a tensor, we need to compute the padding values at runtime
+            if is_pad_axis_const:
+                pad_axis = [graph.make_const(utils.make_name("pad"),
+                                             np.array([pad], "int64")).output[0] for pad in pad_axis]
+
+            rank_tensor = graph.make_const(utils.make_name("rank"), np.array([rank], "int64")).output[0]
+            zero_node = graph.make_const(utils.make_name("zero"), np.array([0], "int64"))
+            one_node = graph.make_const(utils.make_name("one"), np.array([1], "int64"))
+
+            post_repeat_node = graph.make_node("Sub", inputs=[rank_tensor, axis],
+                                               outputs=[utils.make_name("post_repeat")], op_name_scope=base_name)
+            post_repeat_node = graph.make_node("Sub", inputs=[post_repeat_node.output[0], one_node.output[0]],
+                                               outputs=[utils.make_name("post_repeat")], op_name_scope=base_name)
+
+            pre_pad_node = graph.make_node("Tile", inputs=[zero_node.output[0], axis], op_name_scope=base_name,
+                                           attr={"axis": 0}, outputs=[utils.make_name("pre_pad")])
+
+            post_pad_node = graph.make_node("Tile", inputs=[zero_node.output[0], post_repeat_node.output[0]],
+                                            attr={"axis": 0}, outputs=[utils.make_name("post_pad")],
+                                            op_name_scope=base_name)
+
+        pads_node = graph.make_node("Concat", attr={"axis": 0}, outputs=[utils.make_name("pads")],
+                                    op_name_scope=base_name,
+                                    inputs=[pre_pad_node.output[0], pad_axis[0], post_pad_node.output[0],
+                                            pre_pad_node.output[0], pad_axis[1], post_pad_node.output[0]])
         return pads_node
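
As a concrete illustration of what get_pads_node builds: ONNX Pad expects a flat pads tensor laid out as [x1_begin, x2_begin, ..., x1_end, x2_end, ...]. A minimal sketch of the constant path, assuming rank = 3, axis = 1, and pad_axis = [1, 0]:

    import numpy as np

    # Constant path of get_pads_node for rank = 3, axis = 1, pad_axis = [1, 0]:
    pre_pad = np.zeros(1, "int64")           # `axis` zeros before the padded axis
    post_pad = np.zeros(3 - 1 - 1, "int64")  # rank - axis - 1 zeros after it
    pads = np.concatenate([pre_pad, [1], post_pad,   # begin pads, one per dimension
                           pre_pad, [0], post_pad])  # end pads, one per dimension
    print(pads)  # [0 1 0 0 0 0] -> pad one element at the start of axis 1 only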