Skip to content

Commit ce749ce

Browse files
committed
add more details and changes
Signed-off-by: Deyu Huang <[email protected]>
1 parent 6a780e8 commit ce749ce

File tree

7 files changed

+158
-82
lines changed

7 files changed

+158
-82
lines changed

examples/tf_custom_op/add_one.cc

Lines changed: 0 additions & 46 deletions
This file was deleted.

examples/tf_custom_op/add_one.so

-30.2 KB
Binary file not shown.

examples/tf_custom_op/custom_op.md

Lines changed: 51 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,31 @@ Specify the name of your op, its inputs (types and names) and outputs (types and
2424
```
2525
#include "tensorflow/core/framework/op.h"
2626
#include "tensorflow/core/framework/shape_inference.h"
27+
#include "tensorflow/core/framework/register_types.h"
2728
2829
using namespace tensorflow;
2930
3031
31-
REGISTER_OP("AddOne")
32-
.Input("add_one: int32")
33-
.Output("result: int32")
32+
// opregister
33+
REGISTER_OP("DoubleAndAddOne")
34+
.Input("x: T")
35+
.Output("result: T")
36+
.Attr("T: {float, double, int32}")
3437
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) {
35-
c->set_output(0, c->input(0));
38+
::tensorflow::shape_inference::ShapeHandle shape_x = c->input(0);
39+
if (!c->RankKnown(shape_x)) {
40+
c->set_output(0, c->UnknownShape());
41+
return Status::OK();
42+
}
43+
c->set_output(0, shape_x);
3644
return Status::OK();
37-
});
45+
})
46+
.Doc(R"doc(
47+
Calculate the value 2x + 1.
48+
x: A Tensor `Tensor`. Must be one of the types in `T`.
49+
50+
Returns: A `Tensor`. Has the same type with `x`.
51+
)doc");
3852
```
3953

4054
#### Implement the op kernel
@@ -43,35 +57,34 @@ Create a class that extends `OpKernel` and overrides the `Compute()` method. The
4357
```
4458
#include "tensorflow/core/framework/op_kernel.h"
4559
46-
void AddOneKernelLauncher(const Tensor* t_in, const int n, Tensor* t_out);
47-
48-
class AddOneOp : public OpKernel {
60+
template <typename T>
61+
class DoubleAndAddOneOp : public OpKernel {
4962
public:
50-
explicit AddOneOp(OpKernelConstruction* context) : OpKernel(context) {}
63+
explicit DoubleAndAddOneOp(OpKernelConstruction* context) : OpKernel(context) {}
5164
5265
void Compute(OpKernelContext* context) override {
53-
// Tensore in input
66+
// Grab the input tensor
5467
const Tensor& input_tensor = context->input(0);
55-
auto input = input_tensor.flat<int32>();
68+
auto input = input_tensor.flat<T>();
5669
57-
// Tensore in output
70+
// Tensor in output
5871
Tensor* output_tensor = NULL;
5972
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor));
60-
auto output = output_tensor->flat<int32>();
73+
auto output = output_tensor->flat<T>();
6174
62-
#if GOOGLE_CUDA
63-
AddOneKernelLauncher(input, input.size(), output);
64-
#else
6575
const int N = input.size();
66-
for (int i = 0; i < N; i++) output(i) += 1;
67-
#endif
68-
if (N > 0) output(0) = input(0);
76+
for (int i = 0; i < N; i++) {
77+
output(i) = output(i) * T(2) + T(1);
78+
}
6979
}
7080
};
7181
```
7282
Add the Register kernel build,
7383
```
74-
REGISTER_KERNEL_BUILDER(Name("AddOne").Device(DEVICE_CPU), AddOneOp);
84+
REGISTER_KERNEL_BUILDER(Name("DoubleAndAddOne")
85+
.Device(DEVICE_CPU)
86+
.TypeConstraint<int>("T"),
87+
DoubleAndAddOneOp<int>);
7588
```
7689
Save below code in C++ `.cc` file,
7790

@@ -80,9 +93,9 @@ Assuming you have g++ installed, here is the sequence of commands you can use to
8093
```
8194
TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
8295
TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )
83-
g++ -std=c++14 -shared add_one.cc -o add_one.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2
96+
g++ -std=c++14 -shared double_and_add_one.cc -o double_and_add_one.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2
8497
```
85-
After below steps, we can get a TensorFlow custom op library `add_one.so`.
98+
After below steps, we can get a TensorFlow custom op library `double_and_add_one.so`.
8699

87100

88101
### Convert the Operator to ONNX
@@ -105,21 +118,30 @@ from tf2onnx.tf_loader import tf_placeholder
105118
106119
DIR_PATH = os.path.realpath(os.path.dirname(__file__))
107120
saved_model_path = os.path.join(DIR_PATH, "model.onnx")
108-
tf_library_path = os.path.join(DIR_PATH, "add_one.so")
121+
tf_library_path = os.path.join(DIR_PATH, "double_and_add_one.so")
109122
110123
111-
@tf_op("AddOne", onnx_op="Add")
112-
class AddOne:
124+
@tf_op("DoubleAndAddOne")
125+
class DoubleAndAddOne:
113126
@classmethod
114127
def version_1(cls, ctx, node, **kwargs):
128+
node.type = "Mul"
115129
node_shape = ctx.get_shape(node.input[0])
116-
const_one = ctx.make_const(utils.make_name("const_one"), np.ones(node_shape, dtype = np.int32)).output[0]
117-
node.input.append(const_one)
130+
node_dtype = ctx.get_dtype(node.input[0])
131+
node_np_dtype = utils.map_onnx_to_numpy_type(node_dtype)
132+
133+
const_two = ctx.make_const(utils.make_name("cosnt_two"), np.array([2]).astype(node_np_dtype)).output[0]
134+
node.input.append(const_two)
135+
136+
const_one = ctx.make_const(utils.make_name("const_one"), np.ones(node_shape, dtype=node_np_dtype)).output[0]
137+
op_name = utils.make_name(node.name)
138+
ctx.insert_new_node_on_output("Add", node.output[0], inputs=[node.output[0], const_one], name=op_name)
139+
118140
119141
@tf.function
120142
def func(x):
121-
AddOne = tf.load_op_library(tf_library_path)
122-
x_ = AddOne.add_one(x)
143+
custom_op = tf.load_op_library(tf_library_path)
144+
x_ = custom_op.double_and_add_one(x)
123145
output = tf.identity(x_, name="output")
124146
return output
125147
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*/
4+
5+
#include "tensorflow/core/framework/op.h"
6+
#include "tensorflow/core/framework/shape_inference.h"
7+
#include "tensorflow/core/framework/op_kernel.h"
8+
#include "tensorflow/core/framework/register_types.h"
9+
10+
using namespace tensorflow;
11+
12+
13+
// opregister
14+
REGISTER_OP("DoubleAndAddOne")
15+
.Input("x: T")
16+
.Output("result: T")
17+
.Attr("T: {float, double, int32}")
18+
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) {
19+
::tensorflow::shape_inference::ShapeHandle shape_x = c->input(0);
20+
if (!c->RankKnown(shape_x)) {
21+
c->set_output(0, c->UnknownShape());
22+
return Status::OK();
23+
}
24+
c->set_output(0, shape_x);
25+
return Status::OK();
26+
})
27+
.Doc(R"doc(
28+
Calculate the value 2x + 1.
29+
x: A Tensor `Tensor`. Must be one of the types in `T`.
30+
31+
Returns: A `Tensor`. Has the same type with `x`.
32+
)doc");
33+
34+
35+
// keneldefinition
36+
template <typename T>
37+
class DoubleAndAddOneOp : public OpKernel {
38+
public:
39+
explicit DoubleAndAddOneOp(OpKernelConstruction* context) : OpKernel(context) {}
40+
41+
void Compute(OpKernelContext* context) override {
42+
// Grab the input tensor
43+
const Tensor& input_tensor = context->input(0);
44+
auto input = input_tensor.flat<T>();
45+
46+
// Tensor in output
47+
Tensor* output_tensor = NULL;
48+
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor));
49+
auto output = output_tensor->flat<T>();
50+
51+
const int N = input.size();
52+
for (int i = 0; i < N; i++) {
53+
output(i) = output(i) * T(2) + T(1);
54+
}
55+
}
56+
};
57+
58+
59+
REGISTER_KERNEL_BUILDER(Name("DoubleAndAddOne")
60+
.Device(DEVICE_CPU)
61+
.TypeConstraint<float>("T"),
62+
DoubleAndAddOneOp<float>);
63+
REGISTER_KERNEL_BUILDER(Name("DoubleAndAddOne")
64+
.Device(DEVICE_CPU)
65+
.TypeConstraint<double>("T"),
66+
DoubleAndAddOneOp<double>);
67+
REGISTER_KERNEL_BUILDER(Name("DoubleAndAddOne")
68+
.Device(DEVICE_CPU)
69+
.TypeConstraint<int>("T"),
70+
DoubleAndAddOneOp<int>);
71+
72+
73+
#define REGISTER_KERNEL(type) \
74+
REGISTER_KERNEL_BUILDER( \
75+
Name("DoubleAndAddOne").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
76+
DoubleAndAddOneOp<type>)
77+
78+
REGISTER_KERNEL(float);
79+
REGISTER_KERNEL(double);
80+
REGISTER_KERNEL(int);
81+
82+
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
83+
#undef REGISTER_KERNEL
84+
70.5 KB
Binary file not shown.

examples/tf_custom_op/addone_custom_op.py renamed to examples/tf_custom_op/double_and_add_one_custom_op.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,30 @@
1313

1414
DIR_PATH = os.path.realpath(os.path.dirname(__file__))
1515
saved_model_path = os.path.join(DIR_PATH, "model.onnx")
16-
tf_library_path = os.path.join(DIR_PATH, "add_one.so")
16+
tf_library_path = os.path.join(DIR_PATH, "double_and_add_one.so")
1717

1818

19-
@tf_op("AddOne", onnx_op="Add")
20-
class AddOne:
19+
@tf_op("DoubleAndAddOne")
20+
class DoubleAndAddOne:
2121
@classmethod
2222
def version_1(cls, ctx, node, **kwargs):
23+
node.type = "Mul"
2324
node_shape = ctx.get_shape(node.input[0])
24-
const_one = ctx.make_const(utils.make_name("const_one"), np.ones(node_shape, dtype = np.int32)).output[0]
25-
node.input.append(const_one)
25+
node_dtype = ctx.get_dtype(node.input[0])
26+
node_np_dtype = utils.map_onnx_to_numpy_type(node_dtype)
27+
28+
const_two = ctx.make_const(utils.make_name("cosnt_two"), np.array([2]).astype(node_np_dtype)).output[0]
29+
node.input.append(const_two)
30+
31+
const_one = ctx.make_const(utils.make_name("const_one"), np.ones(node_shape, dtype=node_np_dtype)).output[0]
32+
op_name = utils.make_name(node.name)
33+
ctx.insert_new_node_on_output("Add", node.output[0], inputs=[node.output[0], const_one], name=op_name)
34+
2635

2736
@tf.function
2837
def func(x):
29-
AddOne = tf.load_op_library(tf_library_path)
30-
x_ = AddOne.add_one(x)
38+
custom_op = tf.load_op_library(tf_library_path)
39+
x_ = custom_op.double_and_add_one(x)
3140
output = tf.identity(x_, name="output")
3241
return output
3342

@@ -50,3 +59,10 @@ def func(x):
5059

5160
ort_outs = ort_session.run(None, ort_inputs)
5261
print("input:", input, "\nort_outs:", ort_outs)
62+
63+
'''
64+
input: [[0 1 2]
65+
[3 4 5]]
66+
ort_outs: [array([[ 1, 3, 5],
67+
[ 7, 9, 11]], dtype=int32)]
68+
'''

examples/tf_custom_op/model.onnx

363 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)