Skip to content

Commit 8c8e4dd

Browse files
committed
sum
1 parent d5631b5 commit 8c8e4dd

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

tools/pnnx/src/pass_onnx.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,11 @@ void pass_onnx(const onnx::ModelProto& model, Graph& pnnx_graph)
687687
sim_op_type = "aten::einsum";
688688
}
689689

690+
if (op_type == "Sum")
691+
{
692+
sim_op_type = "aten::add";
693+
}
694+
690695
// unaryop
691696
if (op_type == "Abs") sim_op_type = "aten::abs";
692697
if (op_type == "Acos") sim_op_type = "aten::acos";
@@ -1109,6 +1114,39 @@ void pass_onnx(const onnx::ModelProto& model, Graph& pnnx_graph)
11091114
op->inputs.clear();
11101115
op->inputs.push_back(opm1_out);
11111116
}
1117+
1118+
if (op_type == "Sum")
1119+
{
1120+
// unroll sum
1121+
if (op->inputs.size() > 2)
1122+
{
1123+
std::vector<Operand*> more_inputs(op->inputs.begin() + 2, op->inputs.end());
1124+
op->inputs.resize(2);
1125+
1126+
Operator* last_op = op;
1127+
for (size_t j = 0; j < more_inputs.size(); j++)
1128+
{
1129+
Operand* x = more_inputs[j];
1130+
1131+
Operator* op1 = pnnx_graph.new_operator_after("aten::add", op->name + "_sum" + std::to_string(j + 1), last_op);
1132+
Operand* op1_in = pnnx_graph.new_operand(op->name + "_sum" + std::to_string(j));
1133+
Operand* op1_out = last_op->outputs[0];
1134+
op1_in->consumers.push_back(op1);
1135+
op1_in->producer = last_op;
1136+
last_op->outputs[0] = op1_in;
1137+
op1_out->producer = op1;
1138+
op1->inputs.push_back(op1_in);
1139+
op1->inputs.push_back(x);
1140+
x->consumers.clear();
1141+
x->consumers.push_back(op1);
1142+
op1->outputs.push_back(op1_out);
1143+
1144+
last_op = op1;
1145+
}
1146+
1147+
op->name = op->name + "_sum0";
1148+
}
1149+
}
11121150
}
11131151
else if (is_prim_op)
11141152
{

tools/pnnx/tests/onnx/test_onnx_math_ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@ def Model(x: FLOAT["C","H","W"], y: FLOAT["C","H","W"]):
4141
op.Div(x, y),
4242
op.Min(x, y),
4343
op.Max(x, y),
44-
op.Pow(x, op.Abs(y))
44+
op.Pow(x, op.Abs(y)),
45+
46+
op.Sum(x, op.Relu(y)),
47+
op.Sum(x, op.Floor(y), y, op.Sin(y)),
4548
)
4649

4750
def test():

0 commit comments

Comments
 (0)