Skip to content

Commit 1903700

Browse files
committed
Merge branch 'master' of github.com:AsakusaRinne/TensorFlow.NET into support_function_load
2 parents 2f62caa + a075bba commit 1903700

File tree

19 files changed

+537
-60
lines changed

19 files changed

+537
-60
lines changed

src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,11 @@ public static (IList<MySaveableObject>, IDictionary<string, IDictionary<string,
5454
var g = to_graph.as_default();
5555
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
5656
object_map, call_with_mapped_captures, saveables_cache);
57-
tf.device("/cpu:0");
58-
var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
57+
var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ =>
58+
{
59+
// TODO(Rinne): locate the error that causes transferring TF_STRING to this function throws an exception.
60+
return constant_op.constant(graph_proto.ToByteArray());
61+
});
5962
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
6063
g.Exit();
6164
return (named_saveable_objects, registered_savers);
@@ -66,8 +69,10 @@ public static (IList<MySaveableObject>, IDictionary<string, IDictionary<string,
6669
{
6770
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
6871
object_map, call_with_mapped_captures, saveables_cache);
69-
tf.device("/cpu:0");
70-
var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING);
72+
var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ =>
73+
{
74+
return constant_op.constant(graph_proto.ToString());
75+
});
7176
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
7277
return (named_saveable_objects, registered_savers);
7378
}

src/TensorFlowNET.Core/Checkpoint/checkpoint.cs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,10 @@ public TrackableSaver(ObjectGraphView graph_view)
5959

6060
if(object_graph_tensor is null)
6161
{
62-
tf.device("/cpu:0");
63-
object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
62+
tf_with(ops.device("/cpu:0"), _ =>
63+
{
64+
object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
65+
});
6466
}
6567
else
6668
{
@@ -232,22 +234,26 @@ public LoadStatus restore(string? save_path, CheckpointOptions? options = null)
232234
Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING);
233235

234236
Dictionary<Tensor, string> file_prefix_feed_dict;
235-
Tensor file_prefix_tensor;
237+
Tensor file_prefix_tensor = null;
236238
if (graph_building)
237239
{
238240
if(_file_prefix_placeholder is null)
239241
{
240-
tf.device("/cpu:0");
241-
_file_prefix_placeholder = constant_op.constant("model");
242+
_file_prefix_placeholder = tf_with(ops.device("/cpu:0"), _ =>
243+
{
244+
return constant_op.constant("model");
245+
});
242246
}
243247
file_prefix_tensor = _file_prefix_placeholder;
244248
file_prefix_feed_dict = new();
245249
file_prefix_feed_dict[_file_prefix_placeholder] = save_path;
246250
}
247251
else
248252
{
249-
tf.device("/cpu:0");
250-
file_prefix_tensor = constant_op.constant(save_path);
253+
file_prefix_tensor = tf_with(ops.device("/cpu:0"), _ =>
254+
{
255+
return constant_op.constant(save_path);
256+
});
251257
file_prefix_feed_dict = null;
252258
}
253259
TrackableObjectGraph object_graph_proto = new();

src/TensorFlowNET.Core/Checkpoint/functional_saver.cs

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,11 @@ public IDictionary<string, IDictionary<string, Tensor>> restore(Tensor file_pref
117117

118118
string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!;
119119

120-
// tf python has code `with ops.device(restore_device):` here.
121-
tf.device(restore_device); // may be risky.
122-
var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());
120+
Tensor[] restored_tensors = null;
121+
tf_with(ops.device(restore_device), _ =>
122+
{
123+
restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());
124+
});
123125

124126
Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new();
125127
int idx = 0;
@@ -243,11 +245,14 @@ public Operation save(Tensor file_prefix, CheckpointOptions? options= null)
243245
options = new CheckpointOptions();
244246
}
245247

246-
tf.device("CPU"); // may be risky.
247-
var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")),
248+
Tensor tmp_checkpoint_prefix = null;
249+
tf_with(ops.device("CPU"), _ =>
250+
{
251+
var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")),
248252
constant_op.constant(".part"), constant_op.constant("_temp/part"));
249-
var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix });
250-
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x));
253+
tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix });
254+
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x));
255+
});
251256

252257
Operation save_fn()
253258
{
@@ -269,16 +274,24 @@ Operation save_fn()
269274
var saver = pair.Value;
270275
last_device = device;
271276
// skip the extra process of device name because of lack of API.
272-
tf.device(device);
273-
var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor);
277+
Tensor shard_prefix = null;
278+
tf_with(ops.device(device), _ =>
279+
{
280+
shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor);
281+
});
274282
saved_prefixes.Add(shard_prefix);
275-
sharded_saves.Add(saver.save(shard_prefix, options));
283+
tf_with(ops.device(device), _ =>
284+
{
285+
sharded_saves.Add(saver.save(shard_prefix, options));
286+
});
276287
}
277288
using (var controller = ops.control_dependencies(sharded_saves.ToArray()))
278289
{
279290
string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device;
280-
tf.device(merge_device);
281-
return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true);
291+
return tf_with(ops.device(merge_device), _ =>
292+
{
293+
return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true);
294+
});
282295
}
283296
}
284297

@@ -312,8 +325,9 @@ IDictionary<string, Operation> restore_func()
312325
{
313326
var device = single_saver.Key;
314327
var saver = single_saver.Value;
315-
tf.device(device);
316-
var restored_tensor_dict = saver.restore(file_prefix, options);
328+
tf_with(ops.device(device), _ =>
329+
{
330+
var restored_tensor_dict = saver.restore(file_prefix, options);
317331

318332
foreach(var pair in restored_tensor_dict)
319333
{
@@ -405,21 +419,25 @@ public SaverDef to_proto()
405419
private Tensor _traced_save(Tensor file_prefix)
406420
{
407421
var save_op = save(file_prefix);
408-
tf.device("cpu:0");
409-
using (ops.control_dependencies(new object[]{ save_op }))
422+
return tf_with(ops.device("cpu:0"), _ =>
410423
{
411-
return array_ops.identity(file_prefix);
412-
}
424+
return tf_with(ops.control_dependencies(new object[] { save_op }), __ =>
425+
{
426+
return array_ops.identity(file_prefix);
427+
});
428+
});
413429
}
414430

415431
private Tensor _traced_restore(Tensor file_prefix)
416432
{
417433
var restore_op = restore(file_prefix);
418-
tf.device("cpu:0");
419-
using (ops.control_dependencies(restore_op.Values.ToArray()))
434+
return tf_with(ops.device("cpu:0"), _ =>
420435
{
421-
return array_ops.identity(file_prefix);
422-
}
436+
return tf_with(ops.control_dependencies(restore_op.Values.ToArray()), __ =>
437+
{
438+
return array_ops.identity(file_prefix);
439+
});
440+
});
423441
}
424442

425443
public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false)

src/TensorFlowNET.Core/Contexts/Context.Device.cs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
using static Tensorflow.Binding;
2222
using Google.Protobuf;
2323
using Tensorflow.Device;
24+
using Tensorflow.Exceptions;
2425
using System.Collections.Generic;
2526

2627
namespace Tensorflow.Contexts
@@ -30,10 +31,30 @@ namespace Tensorflow.Contexts
3031
/// </summary>
3132
public sealed partial class Context
3233
{
34+
internal static Dictionary<(string, string), (string, DeviceSpec)> _device_parsing_cache = new();
35+
internal List<LogicalDevice> _logical_devices = null;
36+
internal List<string> _context_devices = null;
37+
3338
ContextDevicePlacementPolicy _device_policy;
3439
bool _log_device_placement;
40+
int _num_gpus;
3541
Dictionary<PhysicalDevice, bool> _memory_growth_map = new Dictionary<PhysicalDevice, bool>();
3642

43+
public string DeviceName { get; set; } = "";
44+
public DeviceSpec DeviceSpec { get; set; } = null;
45+
46+
/// <summary>
/// Device names cached for this context.
/// </summary>
/// <exception cref="AssertionError">
/// Thrown when the cached device list has not been populated yet.
/// </exception>
internal List<string> Devices
    => _context_devices ?? throw new AssertionError("Context must be initialized first.");
57+
3758
public void log_device_placement(bool enable)
3859
{
3960
if (_handle != null)
@@ -89,5 +110,57 @@ public PhysicalDevice[] list_physical_devices(string device_type = null)
89110

90111
return results.ToArray();
91112
}
113+
114+
/// <summary>
/// Creates a device scope bound to this context for the given device name.
/// </summary>
/// <param name="name">Device name handed to the new <see cref="EagerDeviceContext"/>.</param>
/// <returns>A new <see cref="EagerDeviceContext"/> wrapping this context.</returns>
public EagerDeviceContext device(string name) => new EagerDeviceContext(this, name);
118+
119+
/// <summary>
/// Overwrites the current device name and device spec of this context.
/// </summary>
/// <param name="device_name">Value stored in <see cref="DeviceName"/>.</param>
/// <param name="device_spec">Value stored in <see cref="DeviceSpec"/>.</param>
internal void _set_device(string device_name, DeviceSpec device_spec)
    => (DeviceName, DeviceSpec) = (device_name, device_spec);
124+
125+
/// <summary>
/// Enumerates the devices reported by the native eager context and caches them on this instance.
/// </summary>
/// <remarks>
/// Fills <c>_context_devices</c> with canonical device names, fills <c>_logical_devices</c> with one
/// <see cref="LogicalDevice"/> per device, and recomputes <c>_num_gpus</c>. Both lists are assigned in
/// the <c>finally</c> block, so they are non-null (possibly partial) even if enumeration throws.
/// </remarks>
internal void _initialize_logical_devices()
{
    List<LogicalDevice> logical_devices = new();
    List<string> context_devices = new();
    Status status = new();
    var device_list = c_api.TFE_ContextListDevices(_handle, status);
    status.Check(true);
    // NOTE(review): device_list is never explicitly released in this method; confirm whether the
    // returned handle frees itself or whether a delete call is required.
    try
    {
        this._num_gpus = 0;
        // These are fixed placeholders here, so a GPU is counted only when its spec has
        // Job == null and Task == -1 (presumably the state produced by the localhost reset below).
        string current_job = null;
        int current_task = -1;
        // The list does not change while iterating, so query its size once.
        int device_count = c_api.TF_DeviceListCount(device_list);
        for (int i = 0; i < device_count; i++)
        {
            var dev_name = c_api.TF_DeviceListName(device_list, i, status);
            status.Check(true);
            context_devices.Add(DeviceUtils.canonical_name(dev_name));

            var spec = DeviceSpec.from_string(dev_name);
            if (spec.Job == "localhost")
            {
                spec = spec.replace(job: null, replica: -1, task: -1);
            }
            logical_devices.Add(new LogicalDevice(spec.ToString(), spec.DeviceType));

            var dev_type_memory = c_api.TF_DeviceListType(device_list, i, status);
            // Validate the call before touching the returned memory.
            status.Check(true);
            var dev_type = c_api.StringPiece(dev_type_memory);
            if (dev_type == "GPU" && spec.Job == current_job && spec.Task == current_task)
            {
                _num_gpus++;
            }
        }
    }
    finally
    {
        _logical_devices = logical_devices;
        _context_devices = context_devices;
    }
}
92163
}
164+
165+
/// <summary>
/// Immutable pair of a device name string and its device type string.
/// </summary>
public record class LogicalDevice(string name, string device_type);
93166
}

src/TensorFlowNET.Core/Contexts/Context.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ public sealed partial class Context
3434
public const int EAGER_MODE = 1;
3535

3636
int defaultExecutionMode = EAGER_MODE;
37-
public string DeviceName { get; set; } = "";
3837
public string ScopeName { get; set; } = "";
3938
bool initialized = false;
4039
ContextSwitchStack context_switches;
@@ -81,6 +80,9 @@ public void ensure_initialized()
8180
if (initialized)
8281
return;
8382

83+
Debug.Assert(_context_devices is null);
84+
85+
Config = MergeConfig();
8486
FunctionCallOptions.Config = Config;
8587
var config_str = Config.ToByteArray();
8688
var opts = new ContextOptions();
@@ -90,6 +92,7 @@ public void ensure_initialized()
9092
c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts, _device_policy);
9193
_handle = c_api.TFE_NewContext(opts, status);
9294
status.Check(true);
95+
_initialize_logical_devices();
9396
initialized = true;
9497
}
9598

@@ -228,6 +231,7 @@ public void reset_context()
228231
{
229232
c_api.TFE_ContextClearCaches(_handle);
230233
}
234+
_device_parsing_cache.Clear();
231235
}
232236

233237
public static implicit operator SafeContextHandle(Context ctx)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Device;
5+
6+
namespace Tensorflow.Contexts
7+
{
8+
/// <summary>
/// Scope object that switches the device of a <see cref="Context"/> when entered and
/// restores the previously active device when exited.
/// </summary>
public class EagerDeviceContext : ITensorFlowObject
{
    private readonly Context _ctx;
    private readonly string _device_name;

    // One entry per active __enter__ call:
    // (previous device name, previous device spec, spec installed by that call).
    private readonly Stack<(string, DeviceSpec, DeviceSpec)> _stack;

    /// <summary>
    /// Creates a device scope for <paramref name="ctx"/>.
    /// </summary>
    /// <param name="ctx">Context whose device is changed while the scope is active.</param>
    /// <param name="device_name">Requested device name; may be null.</param>
    public EagerDeviceContext(Context ctx, string device_name)
    {
        _ctx = ctx;
        _device_name = device_name;
        _stack = new Stack<(string, DeviceSpec, DeviceSpec)>();
    }

    /// <summary>
    /// Resolves the requested device against the context's current device, installs the
    /// result on the context, and remembers the previous device for <see cref="__exit__"/>.
    /// </summary>
    public void __enter__()
    {
        var ctx = _ctx;
        var old_device_name = ctx.DeviceName;
        var old_device_spec = ctx.DeviceSpec;
        var new_device_name = _device_name;
        var cache_key = (old_device_name, new_device_name);
        DeviceSpec new_device_spec;

        // Single lookup instead of ContainsKey + indexer.
        if (Context._device_parsing_cache.TryGetValue(cache_key, out var cached))
        {
            (new_device_name, new_device_spec) = cached;
        }
        else
        {
            if (new_device_name is not null)
            {
                var device_spec = DeviceSpec.from_string(new_device_name);
                if (!string.IsNullOrEmpty(old_device_name))
                {
                    // Start from a copy of the current spec so the merge does not touch it.
                    new_device_spec = new DeviceSpec(old_device_spec);
                }
                else
                {
                    ctx.ensure_initialized();
                    new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]);
                }
                new_device_spec = new_device_spec.make_merged_spec(device_spec);
            }
            else
            {
                // _context_devices is only populated once the context is initialized;
                // without this call a fresh context would dereference null here.
                ctx.ensure_initialized();
                new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]);
            }
            new_device_name = new_device_spec.ToString();
            Context._device_parsing_cache[cache_key] = (new_device_name, new_device_spec);
        }

        ctx._set_device(new_device_name, new_device_spec);
        _stack.Push((old_device_name, old_device_spec, new_device_spec));
    }

    /// <summary>
    /// Restores the device that was active before the matching <see cref="__enter__"/> call.
    /// </summary>
    public void __exit__()
    {
        var (old_device_name, old_device_spec, _) = _stack.Pop();
        _ctx._set_device(old_device_name, old_device_spec);
    }

    /// <summary>
    /// No-op: this class holds no resources of its own to release.
    /// </summary>
    public void Dispose()
    {
    }
}
71+
}

0 commit comments

Comments
 (0)