Skip to content

Commit b265026

Browse files
bzhang5bzhang
andauthored
Tensorflow serializer (OAID#1056)
* tensorflow serializer version 1 * tf serializer version 1 * apply code-format changes * suppoer tensorflow mobilenet * suppoer tensorflow mobilenet * apply code-format changes * apply code-format changes * Update CMakeLists.txt Co-authored-by: bzhang <[email protected]> Co-authored-by: bzhang5 <[email protected]>
1 parent 2a06d8c commit b265026

14 files changed

+3408
-0
lines changed

tools/convert_tool/CMakeLists.txt

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,47 @@ list(APPEND CAFFE_SERIALIZER_SRCS ${CAFFE_PROTO_SRC})
6161
# NCNN
6262
file(GLOB_RECURSE NCNN_SERIALIZER_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/ncnn/*.cpp")
6363

64+
# TENSORFLOW
65+
file(GLOB_RECURSE TF_SERIALIZER_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/tensorflow/*.cpp")
66+
67+
list(APPEND TENGINE_LIB_SRCS ${serializer_src})
68+
69+
# the generated pb.cc
70+
set(TF_PROTO_SRC ${CMAKE_CURRENT_BINARY_DIR}/tensorflow/graph.pb.cc
71+
${CMAKE_CURRENT_BINARY_DIR}/tensorflow/function.pb.cc
72+
${CMAKE_CURRENT_BINARY_DIR}/tensorflow/node_def.pb.cc
73+
${CMAKE_CURRENT_BINARY_DIR}/tensorflow/op_def.pb.cc
74+
${CMAKE_CURRENT_BINARY_DIR}/tensorflow/attr_value.pb.cc
75+
${CMAKE_CURRENT_BINARY_DIR}/tensorflow/tensor.pb.cc
76+
${CMAKE_CURRENT_BINARY_DIR}/tensorflow/tensor_shape.pb.cc
77+
${CMAKE_CURRENT_BINARY_DIR}/tensorflow/types.pb.cc
78+
${CMAKE_CURRENT_BINARY_DIR}/tensorflow/versions.pb.cc
79+
${CMAKE_CURRENT_BINARY_DIR}/tensorflow/resource_handle.pb.cc)
80+
81+
set(TF_PROTO_PATH ${CMAKE_CURRENT_SOURCE_DIR}/tensorflow)
82+
set(TF_PROTO_OUT_PATH ${CMAKE_CURRENT_BINARY_DIR}/tensorflow)
83+
84+
ADD_CUSTOM_COMMAND(OUTPUT ${TF_PROTO_SRC}
85+
COMMAND mkdir -p ${TF_PROTO_OUT_PATH}
86+
COMMAND ${Protobuf_PROTOC_EXECUTABLE} --cpp_out=${TF_PROTO_OUT_PATH} --proto_path=${TF_PROTO_PATH} ${TF_PROTO_PATH}/graph.proto
87+
COMMAND ${Protobuf_PROTOC_EXECUTABLE} --cpp_out=${TF_PROTO_OUT_PATH} --proto_path=${TF_PROTO_PATH} ${TF_PROTO_PATH}/function.proto
88+
COMMAND ${Protobuf_PROTOC_EXECUTABLE} --cpp_out=${TF_PROTO_OUT_PATH} --proto_path=${TF_PROTO_PATH} ${TF_PROTO_PATH}/node_def.proto
89+
COMMAND ${Protobuf_PROTOC_EXECUTABLE} --cpp_out=${TF_PROTO_OUT_PATH} --proto_path=${TF_PROTO_PATH} ${TF_PROTO_PATH}/op_def.proto
90+
COMMAND ${Protobuf_PROTOC_EXECUTABLE} --cpp_out=${TF_PROTO_OUT_PATH} --proto_path=${TF_PROTO_PATH} ${TF_PROTO_PATH}/attr_value.proto
91+
COMMAND ${Protobuf_PROTOC_EXECUTABLE} --cpp_out=${TF_PROTO_OUT_PATH} --proto_path=${TF_PROTO_PATH} ${TF_PROTO_PATH}/tensor.proto
92+
COMMAND ${Protobuf_PROTOC_EXECUTABLE} --cpp_out=${TF_PROTO_OUT_PATH} --proto_path=${TF_PROTO_PATH} ${TF_PROTO_PATH}/tensor_shape.proto
93+
COMMAND ${Protobuf_PROTOC_EXECUTABLE} --cpp_out=${TF_PROTO_OUT_PATH} --proto_path=${TF_PROTO_PATH} ${TF_PROTO_PATH}/types.proto
94+
COMMAND ${Protobuf_PROTOC_EXECUTABLE} --cpp_out=${TF_PROTO_OUT_PATH} --proto_path=${TF_PROTO_PATH} ${TF_PROTO_PATH}/versions.proto
95+
COMMAND ${Protobuf_PROTOC_EXECUTABLE} --cpp_out=${TF_PROTO_OUT_PATH} --proto_path=${TF_PROTO_PATH} ${TF_PROTO_PATH}/resource_handle.proto
96+
#COMMAND mv ${TF_PROTO_OUT_PATH}/*.pb.h ${TF_PROTO_PATH}/../include/
97+
)
98+
99+
ADD_CUSTOM_TARGET(TF_SERIALIZER_TARGET DEPENDS ${TF_PROTO_OUT_PATH})
100+
101+
include_directories(${TF_PROTO_OUT_PATH})
102+
103+
list(APPEND TF_SERIALIZER_SRCS ${TF_PROTO_SRC})
104+
64105

65106
# SAVE GRAPH
66107
FILE(GLOB_RECURSE SAVE_GRAPH_SRCS "${CMAKE_SOURCE_DIR}/tools/save_graph/*.cpp" "${CMAKE_SOURCE_DIR}/tools/save_graph/*.c")
@@ -73,6 +114,7 @@ FILE(GLOB_RECURSE CONVERT_TOOL_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/convert_tool.cp
73114
list(APPEND CONVERT_TOOL_SRCS ${ONNX_SERIALIZER_SRCS})
74115
list(APPEND CONVERT_TOOL_SRCS ${CAFFE_SERIALIZER_SRCS})
75116
list(APPEND CONVERT_TOOL_SRCS ${NCNN_SERIALIZER_SRCS})
117+
list(APPEND CONVERT_TOOL_SRCS ${TF_SERIALIZER_SRCS})
76118
list(APPEND CONVERT_TOOL_SRCS ${SAVE_GRAPH_SRCS})
77119
list(APPEND CONVERT_TOOL_SRCS ${GRAPH_OPT_SRCS})
78120

tools/convert_tool/convert_tool.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "onnx/onnx2tengine.hpp"
3232
#include "caffe/caffe2tengine.hpp"
3333
#include "ncnn/ncnn2tengine.hpp"
34+
#include "tensorflow/tf2tengine.hpp"
3435
#include "utils/graph_optimizer/graph_opt.hpp"
3536

3637
const char* help_params = "[Convert Tools Info]: optional arguments:\n"
@@ -178,6 +179,11 @@ int main(int argc, char* argv[])
178179
ncnn_serializer n2t;
179180
graph = n2t.ncnn2tengine(model_file, proto_file);
180181
}
182+
else if (file_format == "tensorflow")
183+
{
184+
tensorflow_serializer tf2t;
185+
graph = tf2t.tensorflow2tengine(model_file);
186+
}
181187
else
182188
{
183189
fprintf(stderr, "Convert model failed: support onnx only...\n");
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
syntax = "proto3";
2+
3+
package tensorflow;
4+
option cc_enable_arenas = true;
5+
option java_outer_classname = "AttrValueProtos";
6+
option java_multiple_files = true;
7+
option java_package = "org.tensorflow.framework";
8+
9+
import "tensor.proto";
10+
import "tensor_shape.proto";
11+
import "types.proto";
12+
13+
// Protocol buffer representing the value for an attr used to configure an Op.
14+
// Comment indicates the corresponding attr type. Only the field matching the
15+
// attr type may be filled.
16+
message AttrValue {
17+
// LINT.IfChange
18+
message ListValue {
19+
repeated bytes s = 2; // "list(string)"
20+
repeated int64 i = 3 [packed = true]; // "list(int)"
21+
repeated float f = 4 [packed = true]; // "list(float)"
22+
repeated bool b = 5 [packed = true]; // "list(bool)"
23+
repeated DataType type = 6 [packed = true]; // "list(type)"
24+
repeated TensorShapeProto shape = 7; // "list(shape)"
25+
repeated TensorProto tensor = 8; // "list(tensor)"
26+
repeated NameAttrList func = 9; // "list(attr)"
27+
}
28+
// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc)
29+
30+
oneof value {
31+
bytes s = 2; // "string"
32+
int64 i = 3; // "int"
33+
float f = 4; // "float"
34+
bool b = 5; // "bool"
35+
DataType type = 6; // "type"
36+
TensorShapeProto shape = 7; // "shape"
37+
TensorProto tensor = 8; // "tensor"
38+
ListValue list = 1; // any "list(...)"
39+
40+
// "func" represents a function. func.name is a function's name or
41+
// a primitive op's name. func.attr.first is the name of an attr
42+
// defined for that function. func.attr.second is the value for
43+
// that attr in the instantiation.
44+
NameAttrList func = 10;
45+
46+
// This is a placeholder only used in nodes defined inside a
47+
// function. It indicates the attr value will be supplied when
48+
// the function is instantiated. For example, let us suppose a
49+
// node "N" in function "FN". "N" has an attr "A" with value
50+
// placeholder = "foo". When FN is instantiated with attr "foo"
51+
// set to "bar", the instantiated node N's attr A will have been
52+
// given the value "bar".
53+
string placeholder = 9;
54+
}
55+
}
56+
57+
// A list of attr names and their values. The whole list is attached
58+
// with a string name. E.g., MatMul[T=float].
59+
message NameAttrList {
60+
string name = 1;
61+
map<string, AttrValue> attr = 2;
62+
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
syntax = "proto3";
2+
3+
package tensorflow;
4+
option cc_enable_arenas = true;
5+
option java_outer_classname = "FunctionProtos";
6+
option java_multiple_files = true;
7+
option java_package = "org.tensorflow.framework";
8+
9+
import "attr_value.proto";
10+
import "node_def.proto";
11+
import "op_def.proto";
12+
13+
// A library is a set of named functions.
14+
message FunctionDefLibrary {
15+
repeated FunctionDef function = 1;
16+
repeated GradientDef gradient = 2;
17+
}
18+
19+
// A function can be instantiated when the runtime can bind every attr
20+
// with a value. When a GraphDef has a call to a function, it must
21+
// have binding for every attr defined in the signature.
22+
//
23+
// TODO(zhifengc):
24+
// * device spec, etc.
25+
message FunctionDef {
26+
// The definition of the function's name, arguments, return values,
27+
// attrs etc.
28+
OpDef signature = 1;
29+
30+
// Attributes specific to this function definition.
31+
map<string, AttrValue> attr = 5;
32+
33+
// NOTE: field id 2 deleted on Jan 11, 2016, GraphDef version 21.
34+
35+
// In both of the following fields, there is the need to specify an
36+
// output that is used as either the input to another node (in
37+
// `node_def`) or as a return value of the function (in `ret`).
38+
// Unlike the NodeDefs in GraphDef, we need to be able to specify a
39+
// list in some cases (instead of just single outputs). Also, we
40+
// need to be able to deal with lists of unknown length (so the
41+
// output index may not be known at function definition time). So
42+
// we use the following format instead:
43+
// * "fun_in" where "fun_in" is the name of a function input arg in
44+
// the `signature` field above. This represents that input, whether
45+
// it is a single tensor or a list.
46+
// * "fun_in:0" gives the first element of a function input arg (a
47+
// non-list input is considered a list of length 1 for these
48+
// purposes).
49+
// * "node:out" where "node" is the name of a node in `node_def` and
50+
// "out" is the name one of its op's output arguments (the name
51+
// comes from the OpDef of the node's op). This represents that
52+
// node's output, whether it is a single tensor or a list.
53+
// Note: We enforce that an op's output arguments are never
54+
// renamed in the backwards-compatibility test.
55+
// * "node:out:0" gives the first element of a node output arg (a
56+
// non-list output is considered a list of length 1 for these
57+
// purposes).
58+
//
59+
// NOT CURRENTLY SUPPORTED (but may be in the future):
60+
// * "node:out:-1" gives last element in a node output list
61+
// * "node:out:1:" gives a list with all but the first element in a
62+
// node output list
63+
// * "node:out::-1" gives a list with all but the last element in a
64+
// node output list
65+
66+
// The body of the function. Unlike the NodeDefs in a GraphDef, attrs
67+
// may have values of type `placeholder` and the `input` field uses
68+
// the "output" format above.
69+
70+
// By convention, "op" in node_def is resolved by consulting with a
71+
// user-defined library first. If not resolved, "func" is assumed to
72+
// be a builtin op.
73+
repeated NodeDef node_def = 3;
74+
75+
// A mapping from the output arg names from `signature` to the
76+
// outputs from `node_def` that should be returned by the function.
77+
map<string, string> ret = 4;
78+
}
79+
80+
// GradientDef defines the gradient function of a function defined in
81+
// a function library.
82+
//
83+
// A gradient function g (specified by gradient_func) for a function f
84+
// (specified by function_name) must follow the following:
85+
//
86+
// The function 'f' must be a numerical function which takes N inputs
87+
// and produces M outputs. Its gradient function 'g', which is a
88+
// function taking N + M inputs and produces N outputs.
89+
//
90+
// I.e. if we have
91+
// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
92+
// then, g is
93+
// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
94+
// dL/dy1, dL/dy2, ..., dL/dy_M),
95+
// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
96+
// loss function). dL/dx_i is the partial derivative of L with respect
97+
// to x_i.
98+
message GradientDef {
99+
string function_name = 1; // The function name.
100+
string gradient_func = 2; // The gradient function's name.
101+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
syntax = "proto3";
2+
3+
package tensorflow;
4+
option cc_enable_arenas = true;
5+
option java_outer_classname = "GraphProtos";
6+
option java_multiple_files = true;
7+
option java_package = "org.tensorflow.framework";
8+
9+
import "node_def.proto";
10+
import "function.proto";
11+
import "versions.proto";
12+
13+
// Represents the graph of operations
14+
message GraphDef {
15+
repeated NodeDef node = 1;
16+
17+
// Compatibility versions of the graph. See core/public/version.h for version
18+
// history. The GraphDef version is distinct from the TensorFlow version, and
19+
// each release of TensorFlow will support a range of GraphDef versions.
20+
VersionDef versions = 4;
21+
22+
// Deprecated single version field; use versions above instead. Since all
23+
// GraphDef changes before "versions" was introduced were forward
24+
// compatible, this field is entirely ignored.
25+
int32 version = 3 [deprecated = true];
26+
27+
// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
28+
//
29+
// "library" provides user-defined functions.
30+
//
31+
// Naming:
32+
// * library.function.name are in a flat namespace.
33+
// NOTE: We may need to change it to be hierarchical to support
34+
// different orgs. E.g.,
35+
// { "/google/nn", { ... }},
36+
// { "/google/vision", { ... }}
37+
// { "/org_foo/module_bar", { ... }}
38+
// map<string, FunctionDefLib> named_lib;
39+
// * If node[i].op is the name of one function in "library",
40+
// node[i] is deemed as a function call. Otherwise, node[i].op
41+
// must be a primitive operation supported by the runtime.
42+
//
43+
//
44+
// Function call semantics:
45+
//
46+
// * The callee may start execution as soon as some of its inputs
47+
// are ready. The caller may want to use Tuple() mechanism to
48+
// ensure all inputs are ready in the same time.
49+
//
50+
// * The consumer of return values may start executing as soon as
51+
// the return values the consumer depends on are ready. The
52+
// consumer may want to use Tuple() mechanism to ensure the
53+
// consumer does not start until all return values of the callee
54+
// function are ready.
55+
FunctionDefLibrary library = 2;
56+
};
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
syntax = "proto3";
2+
3+
package tensorflow;
4+
option cc_enable_arenas = true;
5+
option java_outer_classname = "NodeProto";
6+
option java_multiple_files = true;
7+
option java_package = "org.tensorflow.framework";
8+
9+
import "attr_value.proto";
10+
11+
message NodeDef {
12+
// The name given to this operator. Used for naming inputs,
13+
// logging, visualization, etc. Unique within a single GraphDef.
14+
// Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*".
15+
string name = 1;
16+
17+
// The operation name. There may be custom parameters in attrs.
18+
// Op names starting with an underscore are reserved for internal use.
19+
string op = 2;
20+
21+
// Each input is "node:src_output" with "node" being a string name and
22+
// "src_output" indicating which output tensor to use from "node". If
23+
// "src_output" is 0 the ":0" suffix can be omitted. Regular inputs
24+
// may optionally be followed by control inputs that have the format
25+
// "^node".
26+
repeated string input = 3;
27+
28+
// A (possibly partial) specification for the device on which this
29+
// node should be placed.
30+
// The expected syntax for this string is as follows:
31+
//
32+
// DEVICE_SPEC ::= PARTIAL_SPEC
33+
//
34+
// PARTIAL_SPEC ::= ("/" CONSTRAINT) *
35+
// CONSTRAINT ::= ("job:" JOB_NAME)
36+
// | ("replica:" [1-9][0-9]*)
37+
// | ("task:" [1-9][0-9]*)
38+
// | ( ("gpu" | "cpu") ":" ([1-9][0-9]* | "*") )
39+
//
40+
// Valid values for this string include:
41+
// * "/job:worker/replica:0/task:1/gpu:3" (full specification)
42+
// * "/job:worker/gpu:3" (partial specification)
43+
// * "" (no specification)
44+
//
45+
// If the constraints do not resolve to a single device (or if this
46+
// field is empty or not present), the runtime will attempt to
47+
// choose a device automatically.
48+
string device = 4;
49+
50+
// Operation-specific graph-construction-time configuration.
51+
// Note that this should include all attrs defined in the
52+
// corresponding OpDef, including those with a value matching
53+
// the default -- this allows the default to change and makes
54+
// NodeDefs easier to interpret on their own. However, if
55+
// an attr with a default is not specified in this list, the
56+
// default will be used.
57+
// The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and
58+
// one of the names from the corresponding OpDef's attr field).
59+
// The values must have a type matching the corresponding OpDef
60+
// attr's type field.
61+
// TODO(josh11b): Add some examples here showing best practices.
62+
map<string, AttrValue> attr = 5;
63+
};

0 commit comments

Comments
 (0)