|
| 1 | +syntax = "proto3"; |
| 2 | + |
| 3 | +package tensorflow; |
| 4 | +option cc_enable_arenas = true; |
| 5 | +option java_outer_classname = "FunctionProtos"; |
| 6 | +option java_multiple_files = true; |
| 7 | +option java_package = "org.tensorflow.framework"; |
| 8 | + |
| 9 | +import "attr_value.proto"; |
| 10 | +import "node_def.proto"; |
| 11 | +import "op_def.proto"; |
| 12 | + |
| 13 | +// A library is a set of named functions. |
| 14 | +message FunctionDefLibrary { |
| 15 | + repeated FunctionDef function = 1; |
| 16 | + repeated GradientDef gradient = 2; |
| 17 | +} |
| 18 | + |
| 19 | +// A function can be instantiated when the runtime can bind every attr |
| 20 | +// with a value. When a GraphDef has a call to a function, it must |
| 21 | +// have binding for every attr defined in the signature. |
| 22 | +// |
| 23 | +// TODO(zhifengc): |
| 24 | +// * device spec, etc. |
| 25 | +message FunctionDef { |
| 26 | + // The definition of the function's name, arguments, return values, |
| 27 | + // attrs etc. |
| 28 | + OpDef signature = 1; |
| 29 | + |
| 30 | + // Attributes specific to this function definition. |
| 31 | + map<string, AttrValue> attr = 5; |
| 32 | + |
| 33 | + // NOTE: field id 2 deleted on Jan 11, 2016, GraphDef version 21. |
| 34 | + |
| 35 | + // In both of the following fields, there is the need to specify an |
| 36 | + // output that is used as either the input to another node (in |
| 37 | + // `node_def`) or as a return value of the function (in `ret`). |
| 38 | + // Unlike the NodeDefs in GraphDef, we need to be able to specify a |
| 39 | + // list in some cases (instead of just single outputs). Also, we |
| 40 | + // need to be able to deal with lists of unknown length (so the |
| 41 | + // output index may not be known at function definition time). So |
| 42 | + // we use the following format instead: |
| 43 | + // * "fun_in" where "fun_in" is the name of a function input arg in |
| 44 | + // the `signature` field above. This represents that input, whether |
| 45 | + // it is a single tensor or a list. |
| 46 | + // * "fun_in:0" gives the first element of a function input arg (a |
| 47 | + // non-list input is considered a list of length 1 for these |
| 48 | + // purposes). |
| 49 | + // * "node:out" where "node" is the name of a node in `node_def` and |
| 50 | + // "out" is the name one of its op's output arguments (the name |
| 51 | + // comes from the OpDef of the node's op). This represents that |
| 52 | + // node's output, whether it is a single tensor or a list. |
| 53 | + // Note: We enforce that an op's output arguments are never |
| 54 | + // renamed in the backwards-compatibility test. |
| 55 | + // * "node:out:0" gives the first element of a node output arg (a |
| 56 | + // non-list output is considered a list of length 1 for these |
| 57 | + // purposes). |
| 58 | + // |
| 59 | + // NOT CURRENTLY SUPPORTED (but may be in the future): |
| 60 | + // * "node:out:-1" gives last element in a node output list |
| 61 | + // * "node:out:1:" gives a list with all but the first element in a |
| 62 | + // node output list |
| 63 | + // * "node:out::-1" gives a list with all but the last element in a |
| 64 | + // node output list |
| 65 | + |
| 66 | + // The body of the function. Unlike the NodeDefs in a GraphDef, attrs |
| 67 | + // may have values of type `placeholder` and the `input` field uses |
| 68 | + // the "output" format above. |
| 69 | + |
| 70 | + // By convention, "op" in node_def is resolved by consulting with a |
| 71 | + // user-defined library first. If not resolved, "func" is assumed to |
| 72 | + // be a builtin op. |
| 73 | + repeated NodeDef node_def = 3; |
| 74 | + |
| 75 | + // A mapping from the output arg names from `signature` to the |
| 76 | + // outputs from `node_def` that should be returned by the function. |
| 77 | + map<string, string> ret = 4; |
| 78 | +} |
| 79 | + |
| 80 | +// GradientDef defines the gradient function of a function defined in |
| 81 | +// a function library. |
| 82 | +// |
| 83 | +// A gradient function g (specified by gradient_func) for a function f |
| 84 | +// (specified by function_name) must follow the following: |
| 85 | +// |
| 86 | +// The function 'f' must be a numerical function which takes N inputs |
| 87 | +// and produces M outputs. Its gradient function 'g', which is a |
| 88 | +// function taking N + M inputs and produces N outputs. |
| 89 | +// |
| 90 | +// I.e. if we have |
| 91 | +// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), |
| 92 | +// then, g is |
| 93 | +// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, |
| 94 | +// dL/dy1, dL/dy2, ..., dL/dy_M), |
| 95 | +// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the |
| 96 | +// loss function). dL/dx_i is the partial derivative of L with respect |
| 97 | +// to x_i. |
| 98 | +message GradientDef { |
| 99 | + string function_name = 1; // The function name. |
| 100 | + string gradient_func = 2; // The gradient function's name. |
| 101 | +} |
0 commit comments