Skip to content

Commit d5631b5

Browse files
committed
prelu
1 parent bfd4e70 commit d5631b5

File tree

1 file changed

+44
-5
lines changed

1 file changed

+44
-5
lines changed

tools/pnnx/tests/onnx/test_onnx_activation_ops.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
from onnxscript import opset19 as op
1313

1414
@script()
15-
def Model(x: FLOAT["C","H","W"]):
15+
def Model(x: FLOAT["N","C","W"], y: FLOAT["N","C","H","W"], z: FLOAT["N","C","D","H","W"]):
16+
17+
prelu_slope = op.RandomNormal(seed=0.0, shape=[12])
18+
1619
return (
1720
op.Celu(x),
1821
op.Clip(x, min=0, max=2.3),
@@ -23,29 +26,65 @@ def Model(x: FLOAT["C","H","W"]):
2326
op.LeakyRelu(x),
2427
op.LogSoftmax(x),
2528
op.Mish(x),
26-
# op.PRelu(x),
29+
op.PRelu(x, op.Unsqueeze(prelu_slope, axes=[1])),
2730
op.Relu(x),
2831
op.Selu(x),
2932
op.Sigmoid(x),
3033
op.Softmax(x),
3134
op.Softplus(x),
3235
# op.Swish(x),
36+
37+
op.Celu(y),
38+
op.Clip(y, min=0, max=2.3),
39+
op.Elu(y),
40+
# op.Gelu(y),
41+
op.HardSigmoid(y),
42+
op.HardSwish(y),
43+
op.LeakyRelu(y),
44+
op.LogSoftmax(y),
45+
op.Mish(y),
46+
op.PRelu(y, op.Unsqueeze(prelu_slope, axes=[1,2])),
47+
op.Relu(y),
48+
op.Selu(y),
49+
op.Sigmoid(y),
50+
op.Softmax(y),
51+
op.Softplus(y),
52+
# op.Swish(y),
53+
54+
# op.Celu(z),
55+
op.Clip(z, min=0, max=2.3),
56+
# op.Elu(z),
57+
# op.Gelu(z),
58+
op.HardSigmoid(z),
59+
op.HardSwish(z),
60+
op.LeakyRelu(z),
61+
# op.LogSoftmax(z),
62+
op.Mish(z),
63+
op.PRelu(z, op.Unsqueeze(prelu_slope, axes=[1,2,3])),
64+
op.Relu(z),
65+
# op.Selu(z),
66+
op.Sigmoid(z),
67+
# op.Softmax(z),
68+
# op.Softplus(z),
69+
# op.Swish(z),
3370
)
3471

3572
def test():
3673
# save onnx
3774
onnx.save(Model.to_model_proto(), "test_onnx_activation_ops.onnx")
3875

3976
torch.manual_seed(0)
40-
x = torch.rand(3, 4, 5)
77+
x = torch.rand(1, 12, 64)
78+
y = torch.rand(1, 12, 48, 64)
79+
z = torch.rand(1, 12, 21, 28, 44)
4180

4281
# ort inference
4382
sess = ort.InferenceSession("test_onnx_activation_ops.onnx")
44-
a = tuple(torch.from_numpy(out) for out in sess.run(None, {"x": x.numpy()}))
83+
a = tuple(torch.from_numpy(out) for out in sess.run(None, {"x": x.numpy(), "y": y.numpy(), "z": z.numpy()}))
4584

4685
# onnx to pnnx and ncnn
4786
import os
48-
os.system("../../src/pnnx test_onnx_activation_ops.onnx inputshape=[3,4,5] inputshape2=[13,14,15]")
87+
os.system("../../src/pnnx test_onnx_activation_ops.onnx inputshape=[1,12,64],[1,12,48,64],[1,12,21,28,44] inputshape2=[7,12,22],[8,12,33,11],[9,12,9,12,13] fp16=0")
4988

5089
# pnnx inference
5190
import test_onnx_activation_ops_pnnx

0 commit comments

Comments
 (0)