Skip to content

Commit fd1eb40

Browse files
committed
Partially support the backward of loaded function model.
1 parent e0b1e64 commit fd1eb40

File tree

65 files changed

+1886
-255
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+1886
-255
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Runtime.CompilerServices;
4+
using System.Text;
5+
6+
namespace Tensorflow.Common.Extensions
7+
{
8+
public static class DictionaryExtension
9+
{
10+
public static void Deconstruct<T1, T2>(this KeyValuePair<T1, T2> pair, out T1 first, out T2 second)
11+
{
12+
first = pair.Key;
13+
second = pair.Value;
14+
}
15+
public static void Update<T1, T2>(this Dictionary<T1, T2> dic, IDictionary<T1, T2> other)
16+
{
17+
foreach(var (key, value) in other)
18+
{
19+
dic[key] = value;
20+
}
21+
}
22+
public static T2 GetOrDefault<T1, T2>(this Dictionary<T1, T2> dic, T1 key, T2 defaultValue)
23+
{
24+
if (dic.ContainsKey(key))
25+
{
26+
return dic[key];
27+
}
28+
return defaultValue;
29+
}
30+
}
31+
}

src/TensorFlowNET.Core/APIs/tf.gradients.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace Tensorflow
2121
{
2222
public partial class tensorflow
2323
{
24-
GradientTape _tapeSet;
24+
internal GradientTape _tapeSet;
2525

2626
/// <summary>
2727
/// Record operations for automatic differentiation.

src/TensorFlowNET.Core/APIs/tf.tensor.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using Tensorflow.Operations;
18+
1719
namespace Tensorflow
1820
{
1921
public partial class tensorflow
@@ -79,5 +81,10 @@ public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
7981
num_split: num_split,
8082
axis: axis,
8183
name: name);
84+
85+
public Tensor ensure_shape(Tensor x, Shape shape, string name = null)
86+
{
87+
return gen_ops.ensure_shape(x, shape, name);
88+
}
8289
}
8390
}

src/TensorFlowNET.Core/Attributes/c_api.ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public partial class c_api
6161
public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value);
6262

6363
[DllImport(TensorFlowLibName)]
64-
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status);
64+
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, ulong proto_len, SafeStatusHandle status);
6565

6666
/// <summary>
6767
/// Set `num_dims` to -1 to represent "unknown rank".

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License.
2222
using System.Diagnostics;
2323
using System.IO;
2424
using System.Linq;
25+
using Tensorflow.Operations;
2526

2627
namespace Tensorflow
2728
{

src/TensorFlowNET.Core/Buffers/Buffer.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@ public unsafe byte[] ToArray()
107107
}
108108
}
109109

110+
public void Release()
111+
{
112+
_handle.Dispose();
113+
_handle = null;
114+
}
115+
110116
public override string ToString()
111117
=> $"0x{_handle.DangerousGetHandle():x16}";
112118

src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ public static IList<Trackable> list_objects(ObjectGraphView graph_view)
161161

162162
internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list)
163163
{
164-
return full_list.TakeWhile(x =>
164+
return full_list.Where(x =>
165165
{
166166
var saveables = x.gather_saveables_for_checkpoint();
167167
return saveables is not null && saveables.Count > 0;

src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ private static TrackableObjectGraph fill_object_graph_proto(IList<TrackableData>
109109
TrackableObjectGraph.Types.TrackableObject trackable_object = new();
110110
trackable_object.SlotVariables.AddRange(td.slot_variable_proto);
111111
trackable_object.Children.AddRange(td.children_proto);
112+
object_graph_proto.Nodes.Add(trackable_object);
112113
}
113114
return object_graph_proto;
114115
}

src/TensorFlowNET.Core/Contexts/Context.Config.cs

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using Google.Protobuf;
1718
using System;
1819
using System.Diagnostics;
1920
using System.Linq;
21+
using Tensorflow.Common.Extensions;
2022

2123
namespace Tensorflow.Contexts
2224
{
@@ -25,12 +27,93 @@ namespace Tensorflow.Contexts
2527
/// </summary>
2628
public sealed partial class Context
2729
{
28-
public ConfigProto Config { get; set; } = new ConfigProto
30+
protected Device.PhysicalDevice[] _physical_devices;
31+
protected Dictionary<Device.PhysicalDevice, int> _physical_device_to_index;
32+
ConfigProto _config;
33+
public ConfigProto Config
2934
{
30-
GpuOptions = new GPUOptions
35+
get
3136
{
37+
_initialize_physical_devices();
38+
39+
var config = new ConfigProto();
40+
if(_config is not null)
41+
{
42+
config.MergeFrom(_config);
43+
}
44+
config.LogDevicePlacement = _log_device_placement;
45+
46+
config.DeviceCount["CPU"] = 0;
47+
config.DeviceCount["GPU"] = 0;
48+
foreach(var dev in _physical_devices)
49+
{
50+
if (config.DeviceCount.ContainsKey(dev.DeviceType))
51+
{
52+
config.DeviceCount[dev.DeviceType] += 1;
53+
}
54+
else
55+
{
56+
config.DeviceCount[dev.DeviceType] = 1;
57+
}
58+
}
59+
60+
var gpu_options = _compute_gpu_options();
61+
config.GpuOptions = GPUOptions.Parser.ParseFrom(gpu_options.ToByteArray());
62+
63+
return config;
64+
}
65+
set
66+
{
67+
_config = value;
68+
}
69+
}
70+
71+
protected void _initialize_physical_devices(bool reinitialize = false)
72+
{
73+
if(!reinitialize && _physical_devices is not null)
74+
{
75+
return;
76+
}
77+
var devs = list_physical_devices();
78+
_physical_devices = devs.Select(d => new Device.PhysicalDevice()
79+
{
80+
DeviceName = d.DeviceName,
81+
DeviceType = d.DeviceType
82+
}).ToArray();
83+
_physical_device_to_index = _physical_devices.Select((p, i) => new KeyValuePair<Device.PhysicalDevice, int>(p, i))
84+
.ToDictionary(x => x.Key, x => x.Value);
85+
86+
_import_config();
87+
}
88+
89+
protected void _import_config()
90+
{
91+
if(_config is null)
92+
{
93+
return;
94+
}
95+
if(!_config.DeviceCount.TryGetValue("CPU", out var num_cpus))
96+
{
97+
num_cpus = 1;
98+
}
99+
if(num_cpus != 1)
100+
{
101+
// TODO(Rinne): implement it.
32102
}
33-
};
103+
104+
var gpus = _physical_devices.Where(d => d.DeviceType == "GPU");
105+
if(gpus.Count() == 0)
106+
{
107+
return;
108+
}
109+
110+
if(!_config.DeviceCount.TryGetValue("GPU", out var gpu_count))
111+
{
112+
gpu_count = 0;
113+
}
114+
115+
// TODO(Rinne): implement it.
116+
}
34117

35118
ConfigProto MergeConfig()
36119
{

src/TensorFlowNET.Core/Contexts/Context.cs

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,26 @@ public sealed partial class Context
3838
public string ScopeName { get; set; } = "";
3939
bool initialized = false;
4040
ContextSwitchStack context_switches;
41-
public FunctionCallOptions FunctionCallOptions { get; }
41+
protected FunctionCallOptions _function_call_options;
42+
public FunctionCallOptions FunctionCallOptions
43+
{
44+
get
45+
{
46+
if(_function_call_options is null)
47+
{
48+
var config = Config;
49+
_function_call_options = new FunctionCallOptions()
50+
{
51+
Config = config
52+
};
53+
}
54+
return _function_call_options;
55+
}
56+
set
57+
{
58+
_function_call_options = value;
59+
}
60+
}
4261

4362
SafeContextHandle _handle;
4463

@@ -62,7 +81,6 @@ public void ensure_initialized()
6281
if (initialized)
6382
return;
6483

65-
Config = MergeConfig();
6684
FunctionCallOptions.Config = Config;
6785
var config_str = Config.ToByteArray();
6886
var opts = new ContextOptions();
@@ -167,11 +185,29 @@ public bool has_function(string name)
167185
return c_api.TFE_ContextHasFunction(_handle, name);
168186
}
169187

188+
public void add_function(SafeFuncGraphHandle fn)
189+
{
190+
ensure_initialized();
191+
Status status = new();
192+
c_api.TFE_ContextAddFunction(_handle, fn, status);
193+
status.Check(true);
194+
}
195+
196+
public void remove_function(string name)
197+
{
198+
ensure_initialized();
199+
Status status = new();
200+
c_api.TFE_ContextRemoveFunction(_handle, name, status);
201+
status.Check(true);
202+
}
203+
170204
public void add_function_def(FunctionDef fdef)
171205
{
172206
ensure_initialized();
173-
var fdef_string = fdef.ToString();
174-
c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, fdef_string.Length);
207+
var fdef_string = fdef.ToByteArray();
208+
Status status = new Status();
209+
c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, (ulong)fdef_string.Length, status);
210+
status.Check(true);
175211
}
176212

177213
public void restore_mode()

src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@ namespace Tensorflow.Contexts
99
public class FunctionCallOptions
1010
{
1111
public ConfigProto Config { get; set; }
12+
public string ExecutorType { get; set; }
1213

13-
public string config_proto_serialized()
14+
public ByteString config_proto_serialized()
1415
{
15-
return Config.ToByteString().ToStringUtf8();
16+
return Config.ToByteString();
1617
}
1718
}
1819
}

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
using System;
1818
using System.Linq;
1919
using Tensorflow.Contexts;
20+
using Tensorflow.Functions;
2021
using static Tensorflow.Binding;
2122

2223
namespace Tensorflow.Eager
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Operations;
5+
6+
namespace Tensorflow.Eager
7+
{
8+
internal static class backprop_util
9+
{
10+
// TODO: add quantized_dtypes (after being supported).
11+
private static HashSet<TF_DataType> _trainable_dtypes = new HashSet<TF_DataType>(new TF_DataType[]
12+
{
13+
dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128,
14+
dtypes.resource, dtypes.variant, TF_DataType.TF_BFLOAT16
15+
});
16+
public static bool IsTrainable(Tensor tensor)
17+
{
18+
var dtype = _DTypeFromTensor(tensor);
19+
return _trainable_dtypes.Contains(dtype);
20+
}
21+
public static bool IsTrainable(TF_DataType dtype)
22+
{
23+
return _trainable_dtypes.Contains(dtype);
24+
}
25+
26+
private static TF_DataType _DTypeFromTensor(Tensor tensor)
27+
{
28+
var dtype = tensor.dtype;
29+
if(dtype.as_base_dtype() == TF_DataType.TF_VARIANT)
30+
{
31+
CppShapeInferenceResult.Types.HandleData handle_data;
32+
if (tensor is EagerTensor)
33+
{
34+
handle_data = tensor.HandleData;
35+
}
36+
else
37+
{
38+
handle_data = handle_data_util.get_resource_handle_data(tensor);
39+
}
40+
if(handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null &&
41+
handle_data.ShapeAndType.Count > 0)
42+
{
43+
var first_type = handle_data.ShapeAndType[0].Dtype;
44+
if(first_type != DataType.DtInvalid && handle_data.ShapeAndType.All(x => x.Dtype == first_type))
45+
{
46+
return first_type.as_tf_dtype();
47+
}
48+
}
49+
}
50+
return dtype;
51+
}
52+
}
53+
}

src/TensorFlowNET.Core/Eager/c_api.eager.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public partial class c_api
3131
public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status);
3232

3333
[DllImport(TensorFlowLibName)]
34-
public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, string serialized_function_def, int size);
34+
public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, byte[] serialized_function_def, ulong size, SafeStatusHandle status);
3535

3636
[DllImport(TensorFlowLibName)]
3737
public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy);
@@ -280,7 +280,7 @@ public static void TFE_Execute(SafeEagerOpHandle op, SafeEagerTensorHandle[] ret
280280
public static extern void TFE_OpSetAttrIntList(SafeEagerOpHandle op, string attr_name, long[] values, int num_values);
281281

282282
[DllImport(TensorFlowLibName)]
283-
public static extern void TFE_OpSetAttrValueProto(SafeEagerOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status);
283+
public static extern void TFE_OpSetAttrValueProto(IntPtr op, string attr_name, IntPtr proto, ulong proto_len, SafeStatusHandle status);
284284

285285
/// <summary>
286286
///

src/TensorFlowNET.Core/Framework/Models/ScopedTFFunction.cs

Lines changed: 0 additions & 6 deletions
This file was deleted.

0 commit comments

Comments
 (0)