|
| 1 | +<!--- SPDX-License-Identifier: Apache-2.0 --> |
| 2 | + |
| 3 | +## Example of converting TensorFlow model with custom op to ONNX |
| 4 | + |
| 5 | +This document describes the ways for exporting TensorFlow model with a custom operator, exporting the operator to ONNX format, and adding the operator to ONNX Runtime for model inference. Tensorflow provides abundant set of operators, and also provides the extending implmentation to register as the new operators. The new custom operators are usually not recognized by tf2onnx conversion and onnxruntime. So the TensorFlow custom ops should be exported using a combination of existing and/or new custom ONNX ops. Once the operator is converted to ONNX format, users can implement and register it with ONNX Runtime for model inference. This document explains the details of this process end-to-end, along with an example. |
| 6 | + |
| 7 | + |
| 8 | +### Required Steps |
| 9 | + |
| 10 | + - [1](#step1) - Adding the Tensorflow custom operator implementation in C++ and registering it with TensorFlow |
| 11 | + - [2](#step2) - Exporting the custom Operator to ONNX, using: |
| 12 | + <br /> - a combination of existing ONNX ops |
| 13 | + <br /> or |
| 14 | + <br /> - a custom ONNX Operator |
| 15 | + - [3](#step3) - Adding the custom operator implementation and registering it in ONNX Runtime (required only if using a custom ONNX op in step 2) |
| 16 | + |
| 17 | + |
| 18 | +### Implement the Custom Operator |
| 19 | +Firstly, try to install the TensorFlow latest version (Nighly is better) build refer to [here](https://github.com/tensorflow/tensorflow#install). And then implement the custom operators saving in TensorFlow library format and the file usually ends with `.so`. We have a simple example of `AddOne`, which is adding one for a tensor. |
| 20 | + |
| 21 | + |
| 22 | +#### Define the op interface |
| 23 | +Specify the name of your op, its inputs (types and names) and outputs (types and names), as well as docstrings and any attrs the op might require. |
| 24 | +``` |
| 25 | +#include "tensorflow/core/framework/op.h" |
| 26 | +#include "tensorflow/core/framework/shape_inference.h" |
| 27 | +
|
| 28 | +using namespace tensorflow; |
| 29 | +
|
| 30 | +
|
| 31 | +REGISTER_OP("AddOne") |
| 32 | + .Input("add_one: int32") |
| 33 | + .Output("result: int32") |
| 34 | + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) { |
| 35 | + c->set_output(0, c->input(0)); |
| 36 | + return Status::OK(); |
| 37 | + }); |
| 38 | +``` |
| 39 | + |
| 40 | +#### Implement the op kernel |
| 41 | +Create a class that extends `OpKernel` and overrides the `Compute()` method. The implementation is written to the function `Compute()`. |
| 42 | + |
| 43 | +``` |
| 44 | +#include "tensorflow/core/framework/op_kernel.h" |
| 45 | +
|
| 46 | +void AddOneKernelLauncher(const Tensor* t_in, const int n, Tensor* t_out); |
| 47 | +
|
| 48 | +class AddOneOp : public OpKernel { |
| 49 | +public: |
| 50 | + explicit AddOneOp(OpKernelConstruction* context) : OpKernel(context) {} |
| 51 | +
|
| 52 | + void Compute(OpKernelContext* context) override { |
| 53 | + // Tensore in input |
| 54 | + const Tensor& input_tensor = context->input(0); |
| 55 | + auto input = input_tensor.flat<int32>(); |
| 56 | +
|
| 57 | + // Tensore in output |
| 58 | + Tensor* output_tensor = NULL; |
| 59 | + OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); |
| 60 | + auto output = output_tensor->flat<int32>(); |
| 61 | +
|
| 62 | +#if GOOGLE_CUDA |
| 63 | + AddOneKernelLauncher(input, input.size(), output); |
| 64 | +#else |
| 65 | + const int N = input.size(); |
| 66 | + for (int i = 0; i < N; i++) output(i) += 1; |
| 67 | +#endif |
| 68 | + if (N > 0) output(0) = input(0); |
| 69 | + } |
| 70 | +}; |
| 71 | +``` |
| 72 | +Add the Register kernel build, |
| 73 | +``` |
| 74 | +REGISTER_KERNEL_BUILDER(Name("AddOne").Device(DEVICE_CPU), AddOneOp); |
| 75 | +``` |
| 76 | +Save below code in C++ `.cc` file, |
| 77 | + |
| 78 | +#### Using C++ compiler to compile the op |
| 79 | +Assuming you have g++ installed, here is the sequence of commands you can use to compile your op into a dynamic library. |
| 80 | +``` |
| 81 | +TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') ) |
| 82 | +TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') ) |
| 83 | +g++ -std=c++14 -shared add_one.cc -o add_one.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2 |
| 84 | +``` |
| 85 | +After below steps, we can get a TensorFlow custom op library `add_one.so`. |
| 86 | + |
| 87 | + |
| 88 | +### Convert the Operator to ONNX |
| 89 | +To be able to use this custom ONNX operator for inference, we need to add our custom operator to an inference engine. If the operator can be conbinded with exsiting [ONNX standard operators](https://github.com/onnx/onnx/blob/main/docs/Operators.md). The case will be easier: |
| 90 | + |
| 91 | +1- using [--load_op_libraries](https://github.com/onnx/tensorflow-onnx#--load_op_libraries) in conversion command or `tf.load_op_library()` method in code to load the TensorFlow custom ops library. |
| 92 | + |
| 93 | +2- implement the op handler, registered it with the `@tf_op` decorator. Those handlers will be registered via the decorator on load of the module. [Here](https://github.com/onnx/tensorflow-onnx/tree/main/tf2onnx/onnx_opset) are examples of TensorFlow op hander implementations. |
| 94 | + |
| 95 | +``` |
| 96 | +import numpy as np |
| 97 | +import tensorflow as tf |
| 98 | +import tf2onnx |
| 99 | +import onnx |
| 100 | +import os |
| 101 | +from tf2onnx import utils |
| 102 | +from tf2onnx.handler import tf_op |
| 103 | +from tf2onnx.tf_loader import tf_placeholder |
| 104 | +
|
| 105 | +
|
| 106 | +DIR_PATH = os.path.realpath(os.path.dirname(__file__)) |
| 107 | +saved_model_path = os.path.join(DIR_PATH, "model.onnx") |
| 108 | +tf_library_path = os.path.join(DIR_PATH, "add_one.so") |
| 109 | +
|
| 110 | +
|
| 111 | +@tf_op("AddOne", onnx_op="Add") |
| 112 | +class AddOne: |
| 113 | + @classmethod |
| 114 | + def version_1(cls, ctx, node, **kwargs): |
| 115 | + node_shape = ctx.get_shape(node.input[0]) |
| 116 | + const_one = ctx.make_const(utils.make_name("const_one"), np.ones(node_shape, dtype = np.int32)).output[0] |
| 117 | + node.input.append(const_one) |
| 118 | +
|
| 119 | +@tf.function |
| 120 | +def func(x): |
| 121 | + AddOne = tf.load_op_library(tf_library_path) |
| 122 | + x_ = AddOne.add_one(x) |
| 123 | + output = tf.identity(x_, name="output") |
| 124 | + return output |
| 125 | +
|
| 126 | +spec = [tf.TensorSpec(shape=(2, 3), dtype=tf.int32, name="input")] |
| 127 | +
|
| 128 | +onnx_model, _ = tf2onnx.convert.from_function(func, input_signature=spec, opset=15) |
| 129 | +
|
| 130 | +with open(saved_model_path, "wb") as f: |
| 131 | + f.write(onnx_model.SerializeToString()) |
| 132 | +
|
| 133 | +onnx_model = onnx.load(saved_model_path) |
| 134 | +onnx.checker.check_model(onnx_model) |
| 135 | +``` |
| 136 | + |
| 137 | +3- Run in ONNXRuntime, using `InferenceSession` to do inference and get the result. |
| 138 | +``` |
| 139 | +import onnxruntime as ort |
| 140 | +input = np.arange(6).reshape(2,3).astype(np.int32) |
| 141 | +ort_session = ort.InferenceSession(saved_model_path) |
| 142 | +ort_inputs = {ort_session.get_inputs()[0].name: input} |
| 143 | +
|
| 144 | +ort_outs = ort_session.run(None, ort_inputs) |
| 145 | +print("input:", input, "\nAddOne ort_outs:", ort_outs) |
| 146 | +``` |
| 147 | + |
| 148 | + |
| 149 | +If the operator can not using existing ONNX standard operators only, you need to go to [implement the operator in ONNX Runtime](https://github.com/onnx/tutorials/blob/master/PyTorchCustomOperator/README.md#implement-the-operator-in-onnx-runtime). |
| 150 | + |
| 151 | +### References: |
| 152 | +1- [Create an custom TensorFlow op](https://www.tensorflow.org/guide/create_op) |
| 153 | + |
| 154 | +2- [ONNX Runtime: Adding a New Op](https://onnxruntime.ai/docs/reference/operators/add-custom-op.html#register-a-custom-operator) |
| 155 | + |
| 156 | +3- [PyTorch Custom Operators export to ONNX](https://github.com/onnx/tutorials/blob/master/PyTorchCustomOperator/README.md) |
0 commit comments