Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,22 @@ public static void add_checkpoint_values_check(TrackableObjectGraph object_graph
// object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i);
// }
}

/// <summary>
/// Traverse the object graph and list all accessible objects.
/// </summary>
/// <param name="object_graph_view"></param>
public static IList<Trackable> list_objects(ObjectGraphView graph_view)
{
return objects_ids_and_slot_variables_and_paths(graph_view).Item1;
}

internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list)
{
return full_list.TakeWhile(x =>
{
var saveables = x.gather_saveables_for_checkpoint();
return saveables is not null && saveables.Count > 0;
});
}
}
98 changes: 98 additions & 0 deletions src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow.Checkpoint
{
public class CheckpointReader : IDisposable
{
private IntPtr _reader;
public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; }
public Dictionary<string, Shape> VariableToShapeMap { get; set; }

public CheckpointReader(string filename)
{
Status status = new Status();
_reader = c_api.TF_NewCheckpointReader(filename, status.Handle);
status.Check(true);
ReadAllShapeAndType();
}

public int HasTensor(string name)
{
return c_api.TF_CheckpointReaderHasTensor(_reader, name);
}

/// <summary>
/// Get the variable name.
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
public string GetVariable(int index)
{
return c_api.TF_CheckpointReaderGetVariable(_reader, index);
}

public int Size()
{
return c_api.TF_CheckpointReaderSize(_reader);
}

public TF_DataType GetVariableDataType(string name)
{
return c_api.TF_CheckpointReaderGetVariableDataType(_reader, name);
}

public Shape GetVariableShape(string name)
{
// TODO(Rinne): Change it to a constant.
int num_dims = GetVariableNumDims(name);
long[] dims = new long[num_dims];
Status status = new Status();
c_api.TF_CheckpointReaderGetVariableShape(_reader, name, dims, num_dims, status.Handle);
status.Check(true);
return new Shape(dims);
}

public int GetVariableNumDims(string name)
{
return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name);
}

public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
{
Status status = new Status();
var tensor = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle);
status.Check(true);
var shape = GetVariableShape(name);
if(dtype == TF_DataType.DtInvalid)
{
dtype = GetVariableDataType(name);
}
return new Tensor(tensor);
}

private void ReadAllShapeAndType()
{
VariableToDataTypeMap = new Dictionary<string, TF_DataType>();
VariableToShapeMap = new Dictionary<string, Shape>();
int size = Size();
for(int i = 0; i < size; i++)
{
var name = GetVariable(i);
var shape = GetVariableShape(name);
var dtype = GetVariableDataType(name);
VariableToDataTypeMap[name] = dtype;
VariableToShapeMap[name] = shape;
}
}

public void Dispose()
{
c_api.TF_DeleteCheckpointReader(_reader);
}
}
}
6 changes: 3 additions & 3 deletions src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@ public static (IList<MySaveableObject>, object?) generate_saveable_objects(
{
var name = factory_data.name;
var key = factory_data.checkpoint_key;
var maybe_saveable = factory_data.factory;
var maybe_saveable = saveable_object_util.create_saveable_object(name, key, factory_data.factory);

// TODO: oneflow python has a process with callable `saveable_factory`.
// TODO: tensorflow python has a process with callable `saveable_factory`.
List<MySaveableObject> saveables = new();
if (maybe_saveable.TryGet<MySaveableObject>(out var s))
{
Expand Down Expand Up @@ -217,7 +217,7 @@ public static (IList<MySaveableObject>, object?) generate_saveable_objects(

public record class CheckpointFactoryData
(
Maybe<BaseResourceVariable, MySaveableObject> factory,
Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory,
string name,
string checkpoint_key
);
29 changes: 29 additions & 0 deletions src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Runtime.InteropServices;

namespace Tensorflow
{
public unsafe partial class c_api
{
[DllImport(TensorFlowLibName)]
internal static extern IntPtr TF_NewCheckpointReader(string filename, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
internal static extern void TF_DeleteCheckpointReader(IntPtr reader);
[DllImport(TensorFlowLibName)]
internal static extern int TF_CheckpointReaderHasTensor(IntPtr reader, string name);
[DllImport(TensorFlowLibName)]
internal static extern string TF_CheckpointReaderGetVariable(IntPtr reader, int index);
[DllImport(TensorFlowLibName)]
internal static extern int TF_CheckpointReaderSize(IntPtr reader);
[DllImport(TensorFlowLibName)]
internal static extern TF_DataType TF_CheckpointReaderGetVariableDataType(IntPtr reader, string name);
[DllImport(TensorFlowLibName)]
internal static extern void TF_CheckpointReaderGetVariableShape(IntPtr reader, string name, long[] dims, int num_dims, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
internal static extern int TF_CheckpointReaderGetVariableNumDims(IntPtr reader, string name);
[DllImport(TensorFlowLibName)]
internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(IntPtr reader, string name, SafeStatusHandle status);
}
}
Loading