diff --git a/src/TensorFlowNET.Core/Extensions/DictionaryExtension.cs b/src/TensorFlowNET.Core/Common/Extensions/DictionaryExtension.cs similarity index 100% rename from src/TensorFlowNET.Core/Extensions/DictionaryExtension.cs rename to src/TensorFlowNET.Core/Common/Extensions/DictionaryExtension.cs diff --git a/src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs b/src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs similarity index 80% rename from src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs rename to src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs index 2e758dbf1..6ceba445a 100644 --- a/src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs +++ b/src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs @@ -3,16 +3,16 @@ using System.Collections.Generic; using System.Text; -namespace Tensorflow.Extensions +namespace Tensorflow.Common.Extensions { public static class JObjectExtensions { public static T? TryGetOrReturnNull(this JObject obj, string key) { var res = obj[key]; - if(res is null) + if (res is null) { - return default(T); + return default; } else { diff --git a/src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs b/src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs new file mode 100644 index 000000000..6cf62e7b8 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow.Common.Extensions +{ + public static class LinqExtensions + { +#if NETSTANDARD2_0 + public static IEnumerable TakeLast(this IEnumerable sequence, int count) + { + return sequence.Skip(sequence.Count() - count); + } + + public static IEnumerable SkipLast(this IEnumerable sequence, int count) + { + return sequence.Take(sequence.Count() - count); + } +#endif + public static Tensors ToTensors(this IEnumerable tensors) + { + return new Tensors(tensors); + } + + public static void Deconstruct(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third) + { + first = values.Item1; + second = values.Item2; + third = values.Item3; + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs b/src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs new file mode 100644 index 000000000..76bdd6133 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Types; + +namespace Tensorflow.Common.Extensions +{ + public static class NestExtensions + { + public static Tensors ToTensors(this INestable tensors) + { + return new Tensors(tensors.AsNest()); + } + + public static Tensors? ToTensors(this Nest tensors) + { + return Tensors.FromNest(tensors); + } + + /// + /// If the nested object is already a nested type, this function could reduce it. + /// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`. + /// + /// + /// + /// + /// + public static Nest ReduceTo(this INestStructure input) where TIn: INestStructure + { + return Nest.ReduceFrom(input); + } + } +} diff --git a/src/TensorFlowNET.Core/Extensions/OneofExtension.cs b/src/TensorFlowNET.Core/Common/Extensions/OneofExtension.cs similarity index 100% rename from src/TensorFlowNET.Core/Extensions/OneofExtension.cs rename to src/TensorFlowNET.Core/Common/Extensions/OneofExtension.cs diff --git a/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs b/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs new file mode 100644 index 000000000..e05d3deb3 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs @@ -0,0 +1,130 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; + +namespace Tensorflow.Common.Types +{ + public class GeneralizedTensorShape: IEnumerable, INestStructure, INestable + { + public TensorShapeConfig[] Shapes { get; set; } + /// + /// create a single-dim generalized Tensor shape. + /// + /// + public GeneralizedTensorShape(int dim) + { + Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } }; + } + + public GeneralizedTensorShape(Shape shape) + { + Shapes = new TensorShapeConfig[] { shape }; + } + + public GeneralizedTensorShape(TensorShapeConfig shape) + { + Shapes = new TensorShapeConfig[] { shape }; + } + + public GeneralizedTensorShape(TensorShapeConfig[] shapes) + { + Shapes = shapes; + } + + public GeneralizedTensorShape(IEnumerable shape) + { + Shapes = shape.Select(x => (TensorShapeConfig)x).ToArray(); + } + + public Shape ToSingleShape() + { + if (Shapes.Length != 1) + { + throw new ValueError("The generalized shape contains more than 1 dim."); + } + var shape_config = Shapes[0]; + Debug.Assert(shape_config is not null); + return new Shape(shape_config.Items.Select(x => x is null ? -1 : x.Value).ToArray()); + } + + public long ToNumber() + { + if(Shapes.Length != 1 || Shapes[0].Items.Length != 1) + { + throw new ValueError("The generalized shape contains more than 1 dim."); + } + var res = Shapes[0].Items[0]; + return res is null ? -1 : res.Value; + } + + public Shape[] ToShapeArray() + { + return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray(); + } + + public IEnumerable Flatten() + { + List result = new List(); + foreach(var shapeConfig in Shapes) + { + result.AddRange(shapeConfig.Items); + } + return result; + } + public INestStructure MapStructure(Func func) + { + List> lists = new(); + foreach(var shapeConfig in Shapes) + { + lists.Add(new Nest(shapeConfig.Items.Select(x => new Nest(func(x))))); + } + return new Nest(lists); + } + + public Nest AsNest() + { + Nest DealWithSingleShape(TensorShapeConfig config) + { + if (config.Items.Length == 0) + { + return Nest.Empty; + } + else if (config.Items.Length == 1) + { + return new Nest(config.Items[0]); + } + else + { + return new Nest(config.Items.Select(x => new Nest(x))); + } + } + + if(Shapes.Length == 0) + { + return Nest.Empty; + } + else if(Shapes.Length == 1) + { + return DealWithSingleShape(Shapes[0]); + } + else + { + return new Nest(Shapes.Select(s => DealWithSingleShape(s))); + } + } + + public IEnumerator GetEnumerator() + { + foreach (var shape in Shapes) + { + yield return shape.Items; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/INest.cs b/src/TensorFlowNET.Core/Common/Types/INest.cs new file mode 100644 index 000000000..001141ddc --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/INest.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + /// + /// This interface indicates that a class may have a nested structure and provide + /// methods to manipulate with the structure. + /// + public interface INestStructure: INestable + { + /// + /// Flatten the Nestable object. Node that if the object contains only one value, + /// it will be flattened to an enumerable with one element. + /// + /// + IEnumerable Flatten(); + /// + /// Construct a new object with the same nested structure. + /// + /// + /// + /// + INestStructure MapStructure(Func func); + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/INestable.cs b/src/TensorFlowNET.Core/Common/Types/INestable.cs new file mode 100644 index 000000000..7ce49f85a --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/INestable.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + public interface INestable + { + Nest AsNest(); + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/IOptionalArgs.cs b/src/TensorFlowNET.Core/Common/Types/IOptionalArgs.cs new file mode 100644 index 000000000..427e71aaa --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/IOptionalArgs.cs @@ -0,0 +1,21 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + /// + /// This interface is used when some corresponding python methods have optional args. + /// For example, `Keras.Layer.Apply` generally takes three args as the inputs, while + /// `Keras.Layer.RNN` takes more. Then when calling RNN, you should add `RnnOptionalArgs` + /// as the parameter of the method. + /// + public interface IOptionalArgs + { + /// + /// The identifier of the class. It is not an argument but only something to + /// separate different OptionalArgs. + /// + string Identifier { get; } + } +} diff --git a/src/TensorFlowNET.Core/Extensions/NamedTuple.cs b/src/TensorFlowNET.Core/Common/Types/NamedTuple.cs similarity index 100% rename from src/TensorFlowNET.Core/Extensions/NamedTuple.cs rename to src/TensorFlowNET.Core/Common/Types/NamedTuple.cs diff --git a/src/TensorFlowNET.Core/Common/Types/Nest.Static.cs b/src/TensorFlowNET.Core/Common/Types/Nest.Static.cs new file mode 100644 index 000000000..b67d11f42 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/Nest.Static.cs @@ -0,0 +1,62 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + public static class Nest + { + /// + /// Pack the flat items to a nested sequence by the template. + /// + /// + /// + /// + /// + public static Nest PackSequenceAs(INestable template, T[] flatItems) + { + return template.AsNest().PackSequence(flatItems); + } + + /// + /// Pack the flat items to a nested sequence by the template. + /// + /// + /// + /// + /// + public static Nest PackSequenceAs(INestable template, List flatItems) + { + return template.AsNest().PackSequence(flatItems.ToArray()); + } + + /// + /// Flatten the nested object. + /// + /// + /// + /// + public static IEnumerable Flatten(INestable nestedObject) + { + return nestedObject.AsNest().Flatten(); + } + + /// + /// Map the structure with specified function. + /// + /// + /// + /// + /// + /// + public static INestStructure MapStructure(Func func, INestable nestedObject) + { + return nestedObject.AsNest().MapStructure(func); + } + + public static bool IsNested(INestable obj) + { + return obj.AsNest().IsNested(); + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/Nest.cs b/src/TensorFlowNET.Core/Common/Types/Nest.cs new file mode 100644 index 000000000..84a60402e --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/Nest.cs @@ -0,0 +1,458 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Extensions; + +namespace Tensorflow.Common.Types +{ + public enum NestType + { + Empty, + Node, + List, + Dictionary + } + + /// + /// A nested structure which may inclulde value, list and dictionary. + /// Note that dictionary does not ensure the data order. When using it as IEnumerable, + /// its order is depth-first. + /// + /// + public class Nest : INestStructure, IEnumerable + { + private static readonly Nest _empty = new Nest() + { + NestType = NestType.Empty, + }; + public static Nest Empty => _empty; + public NestType NestType { get; protected set; } + public string? Name { get; set; } + public T? Value { get; protected set; } + public List>? ListValue { get; protected set; } + public Dictionary>? DictValue { get; protected set; } + + protected Nest() { } + + public Nest(T value, string? name = null) + { + Value = value; + Name = name; + NestType = NestType.Node; + } + + public Nest(IEnumerable> values, string? name = null) + { + ListValue = values.ToList(); + Name = name; + NestType = NestType.List; + } + + public Nest(Dictionary> value, string? name = null) + { + DictValue = value; + Name = name; + NestType = NestType.Dictionary; + } + + public Nest(Nest other) + { + NestType = other.NestType; + Value = other.Value; + DictValue = other.DictValue; + ListValue = other.ListValue; + Name = other.Name; + } + + public virtual IEnumerable Flatten() + { + return FlattenInternal(this); + } + public virtual INestStructure MapStructure(Func func) + { + return MapStructureInternal(func); + } + + /// + /// Pack the flat items to a nested sequence by the template. + /// + /// + /// + public virtual Nest PackSequence(T[] flatItems) + { + if(flatItems.Length == 0) + { + return Nest.Empty; + } + int index = 0; + return PackSequenceInternal(this, flatItems, ref index); + } + + private static Nest PackSequenceInternal(Nest template, T[] flatItems, ref int index) + { + if(template.NestType == NestType.Node) + { + if(index >= flatItems.Length) + { + throw new InvalidArgumentError("The template and flat items are not matched."); + } + return new Nest(flatItems[index++]); + } + else if(template.NestType == NestType.List) + { + List> nestedObjects = new List>(); + for (int i = 0; i < template.ListValue!.Count; i++) + { + nestedObjects.Add(PackSequenceInternal(template.ListValue![i], flatItems, ref index)); + } + return new Nest(nestedObjects); + } + else if(template.NestType == NestType.Node) + { + Dictionary> dict = new Dictionary>(); + foreach(var (key, value) in template.DictValue!) + { + dict[key] = PackSequenceInternal(value, flatItems, ref index); + } + return new Nest(dict); + } + // Consider Empty as invalid type. + throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node."); + } + + public virtual Nest AsNest() + { + return this; + } + + public virtual Nest MergeWith(Nest? other) + { + if(other is null || other == Nest.Empty) + { + return this; + } + if(this == Nest.Empty) + { + return other; + } + if(NestType == NestType.Node && other.NestType == NestType.Node) + { + return new Nest(new Nest[] { this, other }); + } + else if(NestType == NestType.List && other.NestType == NestType.List) + { + return new Nest(this.ListValue!.Concat(other.ListValue!)); + } + else if(NestType == NestType.Dictionary && other.NestType == NestType.Dictionary) + { + return new Nest(this.DictValue!.Concat(other.DictValue!).ToDictionary(x => x.Key, x => x.Value)); + } + else + { + return new Nest(new Nest[] { this, other }); + } + } + + /// + /// To see if the nested object is really nested. Despite being called `Nest`, sometimes it's actually not + /// nested. For example, [1, 2, 3] is not nested, while [1, [2, 3]] is nested. + /// + /// + public bool IsNested() + { + if(NestType is NestType.Empty or NestType.Node) + { + return false; + } + else if(NestType is NestType.List) + { + foreach(var item in ListValue!) + { + if(item.NestType is NestType.List or NestType.Dictionary) + { + return true; + } + } + return false; + } + else + { + foreach (var item in DictValue!.Values) + { + if (item.NestType is NestType.List or NestType.Dictionary) + { + return true; + } + } + return false; + } + } + + [Obsolete("The indexer of Tensors is not encouraged because it leads to unclear meanings.")] + public T this[int index] + { + get + { + bool success = FindInternal(this, index, out var result); + if (success) + { + return result; + } + else + { + throw new IndexOutOfRangeException(); + } + } + set + { + bool success = SetInternal(this, index, value); + if (!success) + { + throw new IndexOutOfRangeException(); + } + } + } + + /// + /// If the existing nested structure if of type `Nest[INestStructure[T]]`, we can reduce it + /// to `Nest[T]`. + /// + /// + /// + /// + public static Nest ReduceFrom(INestStructure input) where TOut: INestStructure + { + var nested = input.AsNest(); + return ReduceInternal(nested); + } + + private static Nest ReduceInternal(Nest node) where TOut : INestStructure + { + if(node.NestType == NestType.Empty) + { + return Nest.Empty; + } + else if(node.NestType == NestType.Node) + { + return node.Value!.AsNest(); + } + else if(node.NestType == NestType.List) + { + return new Nest(node.ListValue!.Select(x => ReduceInternal(x))); + } + else // Dictionary type + { + return new Nest(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value))); + } + } + + private static bool FindInternal(Nest node, int index, out T? result) + { + if (node.NestType == NestType.Node) + { + if(index == 0) + { + result = node.Value!; + return true; + } + result = default(T); + return false; + } + else if (node.NestType == NestType.List) + { + foreach (var item in node.ListValue!) + { + if(index == 0) + { + return FindInternal(item, index, out result); + } + index--; + } + result = default(T); + return false; + } + else if(node.NestType == NestType.Dictionary) + { + foreach (var item in node.DictValue!.Values) + { + if (index == 0) + { + return FindInternal(item, index, out result); + } + index--; + } + result = default(T); + return false; + } + else + { + result = default(T); + return false; + } + } + + private static bool SetInternal(Nest node, int index, T newValue) + { + if (node.NestType == NestType.Node) + { + if (index == 0) + { + node.Value = newValue; + return true; + } + return false; + } + else if (node.NestType == NestType.List) + { + foreach (var item in node.ListValue!) + { + if (index == 0) + { + return SetInternal(item, index, newValue); + } + index--; + } + return false; + } + else if (node.NestType == NestType.Dictionary) + { + foreach (var item in node.DictValue!.Values) + { + if (index == 0) + { + return SetInternal(item, index, newValue); + } + index--; + } + return false; + } + else + { + return false; + } + } + + private static IEnumerable FlattenInternal(Nest node) + { + if (node.NestType == NestType.Node) + { + yield return node.Value!; + } + else if (node.NestType == NestType.List) + { + foreach (var item in node.ListValue!) + { + foreach(var val in FlattenInternal(item)) + { + yield return val; + } + } + } + else if (node.NestType == NestType.Dictionary) + { + foreach (var item in node.DictValue!.Values) + { + foreach (var val in FlattenInternal(item)) + { + yield return val; + } + } + } + } + + private Nest MapStructureInternal(Func func) + { + if (NestType == NestType.Node) + { + return new Nest(func(Value!)); + } + else if (NestType == NestType.List) + { + List> outs = new List>(); + foreach (var item in ListValue!) + { + outs.Add(item.MapStructureInternal(func)); + } + return new Nest(outs); + } + else if (NestType == NestType.Dictionary) + { + Dictionary> outs = new Dictionary>(); + foreach (var (key, value) in DictValue!) + { + outs.Add(key, value.MapStructureInternal(func)); + } + return new Nest(outs); + } + else + { + return Nest.Empty; + } + } + + public IEnumerator GetEnumerator() + { + return Flatten().GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public override string ToString() + { + StringBuilder sb = new StringBuilder(); + sb.Append("("); + WriteString(this, sb); + sb.Append(")"); + return sb.ToString(); + } + + private static void WriteString(Nest node, StringBuilder sb) + { + if (!string.IsNullOrEmpty(node.Name)) + { + sb.Append($"{node.Name}: "); + } + if (node.NestType == NestType.Node) + { + sb.Append(node.Value!.ToString()); + } + else if (node.NestType == NestType.List) + { + sb.Append("["); + for(int i = 0; i < node.ListValue!.Count; i++) + { + WriteString(node.ListValue![i], sb); + if(i != node.ListValue!.Count - 1) + { + sb.Append(", "); + } + } + sb.Append("]"); + } + else if (node.NestType == NestType.Dictionary) + { + sb.Append("{"); + int count = node.DictValue!.Count; + int i = 0; + foreach (var (key, value) in node.DictValue!) + { + sb.Append($"{key}: "); + WriteString(value, sb); + if (i != count - 1) + { + sb.Append(", "); + } + i++; + } + sb.Append("}"); + } + else + { + sb.Append(""); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/NestDictionary.cs b/src/TensorFlowNET.Core/Common/Types/NestDictionary.cs new file mode 100644 index 000000000..554ca526d --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/NestDictionary.cs @@ -0,0 +1,99 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + public class NestDictionary : INestStructure, IDictionary where TKey : notnull + { + public IDictionary Value { get; set; } + public NestDictionary(IDictionary dict) + { + Value = dict; + } + public IEnumerable Flatten() + { + return Value.Select(x => x.Value); + } + public INestStructure MapStructure(Func func) + { + return new NestList(Value.Select(x => func(x.Value))); + } + + public Nest AsNest() + { + return new Nest(Value.Values.Select(x => new Nest(x))); + } + + // Required IDictionary members + public int Count => Value.Count; + + public bool IsReadOnly => Value.IsReadOnly; + + public ICollection Keys => Value.Keys; + + public ICollection Values => Value.Values; + + public void Add(TKey key, TValue value) + { + Value.Add(key, value); + } + + public void Add(KeyValuePair item) + { + Value.Add(item); + } + + public void Clear() + { + Value.Clear(); + } + + public bool Contains(KeyValuePair item) + { + return Value.Contains(item); + } + + public bool ContainsKey(TKey key) + { + return Value.ContainsKey(key); + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + Value.CopyTo(array, arrayIndex); + } + + public IEnumerator> GetEnumerator() + { + return Value.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public bool Remove(TKey key) + { + return Value.Remove(key); + } + + public bool Remove(KeyValuePair item) + { + return Value.Remove(item); + } + + public bool TryGetValue(TKey key, out TValue value) + { + return Value.TryGetValue(key, out value); + } + + // Optional IDictionary members + public TValue this[TKey key] + { + get => Value[key]; + set => Value[key] = value; + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/NestList.cs b/src/TensorFlowNET.Core/Common/Types/NestList.cs new file mode 100644 index 000000000..082187188 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/NestList.cs @@ -0,0 +1,43 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + /// + /// The implementation of a list that support nest structure, in which the depth is 1. + /// + /// + public sealed class NestList : INestStructure, IEnumerable + { + public List Value { get; set; } + public NestList(IEnumerable values) + { + Value = new List(values); + } + public IEnumerable Flatten() + { + return Value; + } + public INestStructure MapStructure(Func func) + { + return new NestList(Value.Select(x => func(x))); + } + + public Nest AsNest() + { + return new Nest(Value.Select(x => new Nest(x))); + } + + // Enumerator implementation + public IEnumerator GetEnumerator() + { + return Value.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/NestNode.cs b/src/TensorFlowNET.Core/Common/Types/NestNode.cs new file mode 100644 index 000000000..1dad421d9 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/NestNode.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + /// + /// A nested structure with only one element. + /// + /// + public class NestNode : INestStructure + { + public T Value { get; set; } + public NestNode(T value) + { + Value = value; + } + public IEnumerable Flatten() + { + yield return Value; + } + public INestStructure MapStructure(Func func) + { + return new NestNode(func(Value)); + } + + public Nest AsNest() + { + return new Nest(Value); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs b/src/TensorFlowNET.Core/Common/Types/TensorShapeConfig.cs similarity index 95% rename from src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs rename to src/TensorFlowNET.Core/Common/Types/TensorShapeConfig.cs index 7abcfde26..a36930eca 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs +++ b/src/TensorFlowNET.Core/Common/Types/TensorShapeConfig.cs @@ -3,7 +3,7 @@ using System.Collections.Generic; using System.Linq; -namespace Tensorflow.Keras.Saving +namespace Tensorflow.Common.Types { public class TensorShapeConfig { diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs index 2585592c1..ed5a1d6dd 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs @@ -1,17 +1,15 @@ using Newtonsoft.Json; using System.Collections.Generic; +using Tensorflow.Keras.Layers.Rnn; namespace Tensorflow.Keras.ArgsDefinition.Rnn { + // TODO(Rinne): add regularizers. public class RNNArgs : AutoSerializeLayerArgs { - public interface IRnnArgCell : ILayer - { - object state_size { get; } - } [JsonProperty("cell")] // TODO: the cell should be serialized with `serialize_keras_object`. - public IRnnArgCell Cell { get; set; } = null; + public IRnnCell Cell { get; set; } = null; [JsonProperty("return_sequences")] public bool ReturnSequences { get; set; } = false; [JsonProperty("return_state")] @@ -34,6 +32,9 @@ public interface IRnnArgCell : ILayer public IInitializer KernelInitializer { get; set; } public IInitializer RecurrentInitializer { get; set; } public IInitializer BiasInitializer { get; set; } + public float Dropout { get; set; } = .0f; + public bool ZeroOutputForMask { get; set; } = false; + public float RecurrentDropout { get; set; } = .0f; // kernel_regularizer=None, // recurrent_regularizer=None, diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs new file mode 100644 index 000000000..64b500bba --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.ArgsDefinition.Rnn +{ + public class RnnOptionalArgs: IOptionalArgs + { + public string Identifier => "Rnn"; + public Tensor Mask { get; set; } = null; + public Tensors Constants { get; set; } = null; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs new file mode 100644 index 000000000..1dfcbe9cf --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs @@ -0,0 +1,29 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition.Rnn +{ + public class SimpleRNNCellArgs: AutoSerializeLayerArgs + { + [JsonProperty("units")] + public int Units { get; set; } + // TODO(Rinne): lack of initialized value of Activation. Merging keras + // into tf.net could resolve it. + [JsonProperty("activation")] + public Activation Activation { get; set; } + [JsonProperty("use_bias")] + public bool UseBias { get; set; } = true; + [JsonProperty("dropout")] + public float Dropout { get; set; } = .0f; + [JsonProperty("recurrent_dropout")] + public float RecurrentDropout { get; set; } = .0f; + [JsonProperty("kernel_initializer")] + public IInitializer KernelInitializer { get; set; } + [JsonProperty("recurrent_initializer")] + public IInitializer RecurrentInitializer { get; set; } + [JsonProperty("bias_initializer")] + public IInitializer BiasInitializer { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index f76693945..e94c8bf10 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -1,4 +1,5 @@ -using Tensorflow.Keras.Engine; +using Tensorflow.Common.Types; +using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; using Tensorflow.NumPy; using Tensorflow.Training; @@ -14,7 +15,7 @@ public interface ILayer: IWithTrackable, IKerasConfigable List Layers { get; } List InboundNodes { get; } List OutboundNodes { get; } - Tensors Apply(Tensors inputs, Tensor state = null, bool training = false); + Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null); List TrainableVariables { get; } List TrainableWeights { get; } List NonTrainableWeights { get; } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs new file mode 100644 index 000000000..d12ed1ad6 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers.Rnn +{ + public interface IRnnCell: ILayer + { + GeneralizedTensorShape StateSize { get; } + GeneralizedTensorShape OutputSize { get; } + bool IsTFRnnCell { get; } + /// + /// Whether the optional RNN args are supported when appying the layer. + /// In other words, whether `Apply` is overwrited with process of `RnnOptionalArgs`. + /// + bool SupportOptionalArgs { get; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs new file mode 100644 index 000000000..e73244a51 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Layers.Rnn +{ + public interface IStackedRnnCells : IRnnCell + { + int Count { get; } + IRnnCell this[int idx] { get; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs index 1a4245bf2..3a21db9d2 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Saving.Json { diff --git a/src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs b/src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs index d91d3161d..ea6fe976f 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs @@ -6,6 +6,7 @@ using System.Diagnostics; using OneOf.Types; using Tensorflow.Keras.Saving.Json; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Saving { diff --git a/src/TensorFlowNET.Core/NumPy/Axis.cs b/src/TensorFlowNET.Core/NumPy/Axis.cs index 976c764f2..7a3ecbf10 100644 --- a/src/TensorFlowNET.Core/NumPy/Axis.cs +++ b/src/TensorFlowNET.Core/NumPy/Axis.cs @@ -74,8 +74,3 @@ public override string ToString() => IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})"; } } - -namespace System.Runtime.CompilerServices -{ - internal static class IsExternalInit { } -} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs index 492047c9f..88673bb5e 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs @@ -53,7 +53,7 @@ private Tensor _generate_init_val(Shape shape, TF_DataType dtype) // Compute the qr factorization var (q, r) = tf.linalg.qr(a, full_matrices: false); // Make Q uniform - var d = tf.linalg.tensor_diag_part(r); + var d = tf.linalg.tensor_diag_part(r.Single); q *= tf.sign(d); if (num_rows < num_cols) diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs index d3592514d..b2cda952e 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs @@ -11,6 +11,7 @@ namespace Tensorflow /// Basic LSTM recurrent network cell. /// The implementation is based on: http://arxiv.org/abs/1409.2329. /// + [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] public class BasicLstmCell : LayerRnnCell { int _num_units; diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs index 17d51363f..3308aebb7 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs @@ -20,6 +20,7 @@ limitations under the License. namespace Tensorflow { + [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] public class BasicRnnCell : LayerRnnCell { int _num_units; diff --git a/src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs index 7394cb7f9..65de4fe90 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs @@ -19,6 +19,7 @@ limitations under the License. namespace Tensorflow { + [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] public class LayerRnnCell : RnnCell { protected InputSpec inputSpec; diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index ecc9ca116..26646b76a 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -16,10 +16,12 @@ limitations under the License. using System; using System.Collections.Generic; +using Tensorflow.Common.Types; using Tensorflow.Keras; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers.Rnn; using Tensorflow.Keras.Saving; using Tensorflow.NumPy; using Tensorflow.Operations; @@ -50,7 +52,8 @@ namespace Tensorflow /// matching structure of Tensors having shape `[batch_size].concatenate(s)` /// for each `s` in `self.batch_size`. /// - public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell + [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] + public abstract class RnnCell : ILayer, IRnnCell { /// /// Attribute that indicates whether the cell is a TF RNN cell, due the slight @@ -142,7 +145,7 @@ private Tensor _zero_state_tensors(object state_size, Tensor batch_size, TF_Data throw new NotImplementedException("_zero_state_tensors"); } - public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false) + public Tensors Apply(Tensors inputs, Tensors state = null, bool is_training = false, IOptionalArgs? optional_args = null) { throw new NotImplementedException(); } @@ -173,5 +176,14 @@ public void adapt(Tensor data, int? batch_size = null, int? steps = null) { throw new NotImplementedException(); } + + public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null) + { + throw new NotImplementedException(); + } + public GeneralizedTensorShape StateSize => throw new NotImplementedException(); + public GeneralizedTensorShape OutputSize => throw new NotImplementedException(); + public bool IsTFRnnCell => throw new NotImplementedException(); + public bool SupportOptionalArgs => throw new NotImplementedException(); } } diff --git a/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs b/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs index cf1b50af6..ed65a08d7 100644 --- a/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs +++ b/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs @@ -17,6 +17,7 @@ limitations under the License. using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Eager; using Tensorflow.Framework; using static Tensorflow.Binding; @@ -48,6 +49,7 @@ public class _EagerTensorArray : TensorArray public override Tensor flow => _flow; bool _clear_after_read; List _tensor_array; + List _previous_read_indices; public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false, bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, @@ -61,16 +63,20 @@ public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = fal _dtype = dtype.as_base_dtype(); _dynamic_size = dynamic_size; _clear_after_read = clear_after_read; - _tensor_array = new List(); + _tensor_array = Enumerable.Repeat(null, size.numpy()).ToList(); + _previous_read_indices = new(); } public override TensorArray unstack(Tensor value, string name = null) { - return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate + var tensors = array_ops.unstack(value, name: name); + if(tensors.Length > _tensor_array.Count && !_dynamic_size) { - var num_elements = array_ops.shape(value)[0]; - return scatter(indices: math_ops.range(0, num_elements), value: value, name: name); - }); + throw new ValueError($"Cannot unstack {tensors.Length} tensors into a TensorArray of static size {_tensor_array.Count}"); + } + _tensor_array = tensors.ToList(); + // TODO(Rinne): revise the implementation. Here we should return `parent()`. + return this; } public TensorArray scatter(Tensor indices, Tensor value, string name = null) @@ -116,9 +122,19 @@ public void _maybe_colocate_with(Tensor value) _colocate_with.Add(value); } + private Tensor _maybe_zero(int ix) + { + var val = _tensor_array[ix]; + if(val is null) + { + val = _tensor_array[ix] = array_ops.zeros(_element_shape, _dtype); + } + return val; + } + public override Tensor read(T index, string name = null) { - int index_int = -1; + int index_int; if (index is int int_index) index_int = int_index; else if (index is Tensor tensor_index) @@ -126,27 +142,75 @@ public override Tensor read(T index, string name = null) else throw new ValueError(""); + if(index_int >= _tensor_array.Count) + { + throw new OutOfRangeError($"Tried to read from index {index_int} but array size is: {_tensor_array.Count} "); + } + + var res = _tensor_array[index_int]; + if(res is null) + { + if (_previous_read_indices.Contains(index_int)) + { + throw new InvalidArgumentError($"Could not read index {index_int} twice because it was cleared after " + + $"a previous read (perhaps try setting clear_after_read = false?)"); + } + else + { + res = _maybe_zero(index_int); + } + } + if (_clear_after_read) { _tensor_array[index_int] = null; + _previous_read_indices.Add(index_int); } - - return _tensor_array[index_int]; + return res; } public override TensorArray write(Tensor index, Tensor value, string name = null) { - if (_infer_shape) - _element_shape = _element_shape.merge_with(value.shape); - _tensor_array.add(value); - return this; + int index_int; + if(index is EagerTensor eager) + { + return write(eager.numpy(), value, name); + } + throw new InvalidArgumentError("The index is supposed to be an EagerTensor"); } public override TensorArray write(int index, T value, string name = null) { - var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); - var index_tensor = ops.convert_to_tensor(index, name: "index"); - return write(index_tensor, value_tensor, name: name); + int size = _tensor_array.Count; + if(index >= size) + { + if (!_dynamic_size) + { + throw new OutOfRangeError($"Tried to write to index {index} but array is not resizeable and size " + + $"is: {size} "); + } + _tensor_array.AddRange(Enumerable.Repeat(null, index - size + 1)); + } + + Tensor tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + + if(_dtype != tensor.dtype) + { + throw new InvalidArgumentError($"TensorArray dtype is {_dtype.as_python_name()} but Op is " + + $"trying to write dtype {tensor.dtype.as_python_name()} "); + } + + if (!_element_shape.is_compatible_with(tensor.shape)) + { + throw new ValueError($"Incompatible shape for value ({tensor.shape}), expected ({_element_shape})"); + } + + if (_infer_shape) + { + _element_shape = _element_shape.merge_with(tensor.shape); + } + _tensor_array[index] = tensor; + return this; } private Tensor size(string name = null) @@ -156,11 +220,26 @@ private Tensor size(string name = null) public override Tensor stack(string name = null) { - ops.colocate_with(_handle); - return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate + if(_tensor_array.Count > 0) { - return gather(math_ops.range(0, size()), name: name); - }); + for(int i = 0; i < _tensor_array.Count; i++) + { + _maybe_zero(i); + } + } + if(_tensor_array.Count == 0 && _element_shape.IsFullyDefined) + { + return ops.convert_to_tensor(new Shape(new long[] { 0 }.Concat(_element_shape.dims).ToArray()), name: name, dtype: _dtype); + } + else + { + return ops.convert_to_tensor(_tensor_array, name: name, dtype: _dtype); + } + //ops.colocate_with(_handle); + //return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate + //{ + // return gather(math_ops.range(0, size()), name: name); + //}); } public override Tensor gather(Tensor indices, string name = null) diff --git a/src/TensorFlowNET.Core/Operations/logging_ops.cs b/src/TensorFlowNET.Core/Operations/logging_ops.cs index e38e60b5b..3303cadc3 100644 --- a/src/TensorFlowNET.Core/Operations/logging_ops.cs +++ b/src/TensorFlowNET.Core/Operations/logging_ops.cs @@ -30,7 +30,7 @@ public Tensor print_v2(Tensor input, string output_stream = "stderr", string end name: name); return tf.Context.ExecuteOp("PrintV2", name, new ExecuteOpArgs(formatted_string) - .SetAttributes(new { output_stream, end })); + .SetAttributes(new { output_stream, end })).SingleOrNull; } } } diff --git a/src/TensorFlowNET.Core/Operations/sort_ops.cs b/src/TensorFlowNET.Core/Operations/sort_ops.cs index 34b903230..db38a073b 100644 --- a/src/TensorFlowNET.Core/Operations/sort_ops.cs +++ b/src/TensorFlowNET.Core/Operations/sort_ops.cs @@ -44,7 +44,7 @@ public static Tensor argsort(Tensor values, Axis axis = null, string direction = { sorted = true })); - return indices; + return indices.Single; } public static Tensor sort(Tensor values, Axis axis, string direction = "ASCENDING", string? name = null) diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index 09f5b0770..b08b2e2b7 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -114,4 +114,9 @@ https://tensorflownet.readthedocs.io + + + + + diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs index d063ee39f..cba8f9541 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensors.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -3,6 +3,7 @@ using System.Collections; using System.Collections.Generic; using System.Linq; +using Tensorflow.Common.Types; namespace Tensorflow { @@ -13,157 +14,231 @@ namespace Tensorflow /// and Tensor[] from Tensors implicitily. /// It works for tuple and scalar as well. /// - public class Tensors : IEnumerable, IDisposable + public sealed class Tensors : Nest, IDisposable { - List items = new List(); - - public TF_DataType dtype => items.First().dtype; - public Shape shape => items.First().shape; - public int rank => items.First().rank; - public Graph graph => items.First().graph; + public TF_DataType dtype => this.First().dtype; + public Shape shape => this.First().shape; + public int rank => this.First().rank; + public Graph graph => this.First().graph; public bool IsList { get; set; } - public int Length => items.Count(); + public int Length => this.Count(); + /// + /// Return a Tensor if `Tensors` has only one tensor, otherwise throw an exception. + /// + public Tensor Single + { + get + { + if (Length != 1) + { + throw new ValueError("Tensors with more than one tensor cannot be " + + "implicitly converted to Tensor."); + } + return this.First(); + } + } - public Tensor this[int index] + /// + /// Return a Tensor if `Tensors` has only one tensor, and return null when `Tensors` is empty, + /// otherwise throw an exception. + /// + public Tensor? SingleOrNull { - get => items[index]; - set => items[index] = value; + get + { + if (Length > 1) + { + throw new ValueError($"Tensors with {Length} tensor cannot be " + + "implicitly converted to Tensor."); + } + return this.FirstOrDefault(); + } } public Tensor this[params string[] slices] - => items.First()[slices]; - public Tensors(params Tensor[] tensors) + => this.First()[slices]; + + public Tensors(Tensor tensor) : base(tensor) + { + + } + + private Tensors(Nest nested) : base(nested) + { + + } + + public Tensors(params Tensor[] tensors): base(tensors.Select(x => new Nest(x))) + { + + } + + public Tensors(IEnumerable tensors): base(tensors.Select(x => new Nest(x))) { - items.AddRange(tensors); + } - public Tensors(IEnumerable tensors) + public Tensors(NDArray nd): base(ops.convert_to_tensor(nd)) { - items.AddRange(tensors); + } - public Tensors(NDArray nd) + public bool IsSingle() { - items.Add(ops.convert_to_tensor(nd)); + return Length == 1; } - public IEnumerator GetEnumerator() + public new Tensors MergeWith(Nest? other) { - foreach (var tensor in items) - yield return tensor; + return FromNest(base.MergeWith(other)); } + [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " + + "a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")] public void Add(Tensor tensor) - => items.Add(tensor); + { + if(NestType == NestType.Dictionary) + { + throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); + } + else if(NestType == NestType.Node) + { + NestType = NestType.List; + ListValue = new() { new Nest(Value), new Nest(tensor) }; + Value = null; + } + else + { + ListValue.Add(new Nest(tensor)); + } + } + [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " + + "some tensors to `Tensors`, creating a new instance with your newly added tensors is a better choice.")] public void AddRange(IEnumerable tensors) - => items.AddRange(tensors); + { + if (NestType == NestType.Dictionary) + { + throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); + } + else if (NestType == NestType.Node) + { + NestType = NestType.List; + ListValue = new() { new Nest(Value) }; + ListValue.AddRange(tensors.Select(x => new Nest(x))); + Value = null; + } + else + { + ListValue.AddRange(tensors.Select(x => new Nest(x))); + } + } + [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to insert " + + "a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")] public void Insert(int index, Tensor tensor) - => items.Insert(index, tensor); - - IEnumerator IEnumerable.GetEnumerator() - => GetEnumerator(); + { + if (NestType == NestType.List) + { + ListValue.Insert(index, new Nest(tensor)); + } + else if(NestType == NestType.Node) + { + NestType = NestType.List; + ListValue = new() { new Nest(Value) }; + ListValue.Insert(index, new Nest(tensor)); + Value = null; + } + else + { + throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); + } + } public string[] StringData() { - EnsureSingleTensor(this, "nnumpy"); - return this[0].StringData(); + return Single.StringData(); } public string StringData(int index) { - EnsureSingleTensor(this, "nnumpy"); - return this[0].StringData(index); + return Single.StringData(index); } public NDArray numpy() { - EnsureSingleTensor(this, "nnumpy"); - return this[0].numpy(); + return Single.numpy(); } + [Obsolete] public T[] ToArray() where T: unmanaged { - EnsureSingleTensor(this, $"ToArray<{typeof(T)}>"); - return this[0].ToArray(); + return Single.ToArray(); } #region Explicit Conversions public unsafe static explicit operator bool(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to bool"); - return (bool)tensor[0]; + return (bool)tensor.Single; } public unsafe static explicit operator sbyte(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to sbyte"); - return (sbyte)tensor[0]; + return (sbyte)tensor.Single; } public unsafe static explicit operator byte(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to byte"); - return (byte)tensor[0]; + return (byte)tensor.Single; } public unsafe static explicit operator ushort(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to ushort"); - return (ushort)tensor[0]; + return (ushort)tensor.Single; } public unsafe static explicit operator short(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to short"); - return (short)tensor[0]; + return (short)tensor.Single; } public unsafe static explicit operator int(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to int"); - return (int)tensor[0]; + return (int)tensor.Single; } public unsafe static explicit operator uint(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to uint"); - return (uint)tensor[0]; + return (uint)tensor.Single; } public unsafe static explicit operator long(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to long"); - return (long)tensor[0]; + return (long)tensor.Single; } public unsafe static explicit operator ulong(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to ulong"); - return (ulong)tensor[0]; + return (ulong)tensor.Single; } public unsafe static explicit operator float(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to byte"); - return (byte)tensor[0]; + return (byte)tensor.Single; } public unsafe static explicit operator double(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to double"); - return (double)tensor[0]; + return (double)tensor.Single; } public unsafe static explicit operator string(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to string"); - return (string)tensor[0]; + return (string)tensor.Single; } public static explicit operator object[](Tensors tensors) - => tensors.items.ToArray(); + => tensors.Flatten().ToArray(); #endregion #region Implicit Conversions @@ -183,56 +258,44 @@ public static implicit operator Tensors(Tensor[] tensors) public static implicit operator Tensors(List tensors) => new Tensors(tensors.ToArray()); - public static implicit operator Tensor(Tensors tensors) - => tensors.FirstOrDefault(); + public static implicit operator Tensor(Tensors? tensors) + => tensors?.SingleOrNull; public static implicit operator Tensor[](Tensors tensors) - => tensors.items.ToArray(); - + => tensors.Flatten().ToArray(); #endregion - public void Deconstruct(out Tensor a, out Tensor b) + public static Tensors? FromNest(Nest nested) { - a = items[0]; - b = items[1]; + if(nested == Nest.Empty) + { + return null; + } + return new Tensors(nested); } - private static void EnsureSingleTensor(Tensors tensors, string methodnName) + public void Deconstruct(out Tensor a, out Tensors? b) { - if(tensors.Length == 0) - { - throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains no Tensor."); - } - else if(tensors.Length > 1) - { - throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor."); - } + a = this.First(); + b = Length == 1? null : new Tensors(this.Skip(1)); } public override string ToString() { - if(items.Count == 1) + if(Length == 1) { - return items[0].ToString(); + return this.First().ToString(); } else { - StringBuilder sb = new StringBuilder(); - sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n"); - for(int i = 0; i < items.Count; i++) - { - var tensor = items[i]; - sb.Append($"Tensor {i}({tensor.name}): {tensor.ToString()}\n"); - } - sb.Append("]\n"); - return sb.ToString(); + return $"Totally {Length} tensors: {base.ToString()}"; } } public void Dispose() { - foreach (var item in items) - item.Dispose(); + foreach (var tensor in this) + tensor.Dispose(); } } } diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index eb94f4d05..3ba3ce78b 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -36,6 +36,7 @@ namespace Tensorflow.Util // (np.array([3, 4]), tf.constant([3, 4])))` // + [Obsolete] public static class nest { diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index 80403ad6a..30b73e82f 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -20,8 +20,11 @@ limitations under the License. using System.Collections.Generic; using Tensorflow.Functions; using Tensorflow.Graphs; +using Tensorflow.Common.Extensions; using static Tensorflow.Binding; using static Tensorflow.Graphs.SubGraphUtility; +using Tensorflow.Util; +using Tensorflow.Common.Types; namespace Tensorflow.Keras { @@ -450,5 +453,535 @@ public Tensor conv2d_transpose(Tensor x, return x; } + + public (Tensors, Tensors, Tensors) rnn( + Func step_function, // args:inputs, states, return:output, new_states + Tensors inputs, // inputs is a tuple of tensors (one per input sequence) + Tensors initial_states, + bool go_backwards = false, + Tensor? mask = null, + Tensors? constants = null, + bool unroll = false, + Tensors? input_length = null, // An integer or a 1-D Tensor,depending on whether the time dimension is fixed-length or not + bool time_major = false, + bool zero_output_for_mask = false, + bool return_all_outputs = true) + { + + Tensor swap_batch_timestep(Tensor input_t) + { + var axes = Enumerable.Range(0, input_t.rank).ToArray(); + axes[0] = 1; + axes[1] = 0; + return tf.transpose(input_t, axes); + } + + if (!time_major) + { + inputs = Nest.MapStructure(swap_batch_timestep, inputs).ToTensors(); + } + + var flatted_inptus = Nest.Flatten(inputs).ToList(); + var first_flatted_input = flatted_inptus[0]; + var time_steps = first_flatted_input.shape[0]; + var batch = first_flatted_input.shape[1]; + var time_steps_t = (int)first_flatted_input.shape[0]; + + foreach (var input_ in flatted_inptus) + { + input_.shape.with_rank_at_least(3); + } + + if (mask != null) + { + if (mask.dtype != TF_DataType.TF_BOOL) + { + mask = tf.cast(mask, TF_DataType.TF_BOOL); + } + + if (mask.rank == 2) + { + mask = tf.expand_dims(mask, -1); + } + + if (!time_major) + { + mask = swap_batch_timestep(mask); + } + + } + + // tf.where needs its condition tensor to be the same shape as its two + // result tensors, but in our case the condition (mask) tensor is + // (nsamples, 1), and inputs are (nsamples, ndimensions) or even more. + // So we need to broadcast the mask to match the shape of inputs. + // That's what the tile call does, it just repeats the mask along its + // second dimension n times. + + Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1) + { + if (!mask_t.IsSingle()) + { + throw new ValueError($"mask_t is expected to be tensor, but got {mask_t}"); + } + + if (!input_t.IsSingle()) + { + throw new ValueError($"input_t is expected to be tensor, but got {input_t}"); + } + + var rank_diff = input_t.rank - mask_t.rank; + for (int i = 0; i < rank_diff; i++) + { + mask_t = tf.expand_dims(mask_t, -1); + } + var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().ToList().GetRange(fixed_dim, input_t.rank)); + return tf.tile(mask_t, multiples); + } + + Tensors outputs = new Tensors(); + Tensors output_time_zero = new Tensors(); + Tensors last_output = new Tensors(); + Tensors new_states = new Tensors(); + if (unroll) + { + if (time_steps == 0) + { + throw new ValueError("Unrolling requires a fixed number of timesteps."); + } + + // Process the input tensors. The input tensor need to be split on the + // time_step dim, and reverse if go_backwards is True. In the case of + // nested input, the input is flattened and then transformed + // individually. The result of this will be a tuple of lists, each of + // the item in tuple is list of the tensor with shape (batch, feature) + + + // TODO(Wanglongzhi2001),step_func接受的第二个参数为List,但是最后却用的tuple + //var states = Tuple.Create(initial_states); + var states = initial_states; + + var successive_states = new Tensors(); + var successive_outputs = new Tensors(); + + // Process the input tensors. The input tensor need to be split on the + // time_step dim, and reverse if go_backwards is True. In the case of + // nested input, the input is flattened and then transformed + // individually. The result of this will be a tuple of lists, each of + // the item in tuple is list of the tensor with shape (batch, feature) + + + + + Tensors _process_single_input_t(Tensor input_t) + { + var unstaked_input_t = array_ops.unstack(input_t); // unstack for time_step dim + if (go_backwards) + { + unstaked_input_t = unstaked_input_t.Reverse().ToArray(); + } + return unstaked_input_t; + } + + // TODO(Wanglongzhi2001) + Tensors processed_input; + if (!inputs.IsSingle()) + { + processed_input = inputs.MapStructure(_process_single_input_t).ReduceTo().ToTensors(); + } + else + { + processed_input = _process_single_input_t(inputs); + } + + object _get_input_tensor(int time) + { + List inp = new List(); + foreach (var t_ in processed_input) + { + inp.Add(t_[time]); + } + return Nest.PackSequenceAs(inputs, inp); + } + + if (mask != null) + { + var mask_list = tf.unstack(mask); + if (go_backwards) + { + mask_list.Reverse(); + } + + for (int i = 0; i < time_steps; i++) + { + // TODO(Wanglongzhi2001),deal with _get_input_tensor + var inp = _get_input_tensor(i); + var mask_t = mask_list[i]; + // TODO + var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants)); + + var tiled_mask_t = _expand_mask(mask_t, output); + + Tensors prev_output; + if (successive_outputs == null) + { + prev_output = tf.zeros_like(output); + } + else + { + prev_output = successive_outputs[successive_outputs.Length - 1]; + } + + output = tf.where(tiled_mask_t, output, prev_output); + + var flat_states = Nest.Flatten(states).ToList(); + var flat_new_states = Nest.Flatten(newStates).ToList(); + + var tiledMaskT = flat_states + .Select(s => _expand_mask(mask_t, s)) + .ToArray(); + var tuple = Tuple.Create(tiledMaskT); + + List flat_final_states = new List(); + foreach (var (m, s, ps) in zip(tiled_mask_t.ToList(), flat_new_states, flat_states)) + { + flat_final_states.Add(tf.where(m, s, ps)); + } + + states = Nest.PackSequenceAs(states, flat_final_states).ToTensors(); + if (return_all_outputs) + { + successive_outputs.Add(output); + successive_states.Add(states); + } + else + { + successive_outputs = new Tensors { output }; + successive_states = new Tensors { states }; + } + + } + last_output = successive_outputs[successive_outputs.Length - 1]; + new_states = successive_states[successive_states.Length - 1]; + outputs = tf.stack(successive_outputs); + + if (zero_output_for_mask) + { + last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output)); + outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs)); + } + else // mask is null + { + for (int i = 0; i < time_steps; i++) + { + var inp = _get_input_tensor(i); + var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants)); + states = newStates; + + if (return_all_outputs) + { + successive_outputs.Add(output); + successive_states.Add(newStates); + } + else + { + successive_outputs = new Tensors { output }; + successive_states = new Tensors { newStates }; + } + } + last_output = successive_outputs[successive_outputs.Length - 1]; + new_states = successive_states[successive_states.Length - 1]; + outputs = tf.stack(successive_outputs); + } + } + } + else // unroll == false + { + var states = initial_states; + // Create input tensor array, if the inputs is nested tensors, then it + // will be flattened first, and tensor array will be created one per + // flattened tensor. + var input_ta = new List(); + for (int i = 0; i < flatted_inptus.Count; i++) + { + input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_steps_t)); + } + + foreach(var (ta, input_) in zip(input_ta, flatted_inptus)) + { + if (!go_backwards) + { + ta.unstack(input_); + } + else + { + ta.unstack(reverse(input_, 0)); + } + } + + // Get the time(0) input and compute the output for that, the output will + // be used to determine the dtype of output tensor array. Don't read from + // input_ta due to TensorArray clear_after_read default to True. + var inps = new Tensors(); + foreach (var inp in flatted_inptus) + { + inps.Add(inp[0]); + } + var input_time_zero = Nest.PackSequenceAs(inputs, inps).ToTensors(); + + // output_time_zero is used to determine the cell output shape and its + // dtype. the value is discarded. + (output_time_zero, _) = step_function((Tensor)input_time_zero, + constants is null ? initial_states : initial_states.MergeWith(constants)); + + int output_ta_size = return_all_outputs ? time_steps_t : 1; + var output_ta = new List(); + for (int i = 0; i < output_time_zero.ToList().Count; i++) + { + var Out = output_time_zero.ToList()[i]; + output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape)); + } + + var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); + + + + Func? masking_fn; + Func? compute_masked_output = null; + if (mask != null) + { + if (go_backwards) + { + mask = tf.reverse(mask, axis: new[] { 0 }); + } + var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_steps_t); + mask_ta = mask_ta.unstack(mask); + + masking_fn = (time) => + { + return mask_ta.read(time); + }; + + compute_masked_output = (mask_t, flat_out, flat_mask) => + { + var tiled_mask_t = new Tensors(); + foreach (var o in flat_out) + { + tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank)); + } + + Tensors res = new Tensors(); + foreach (var (m, o, fm) in zip(tiled_mask_t.ToList(), flat_out.ToList(), flat_mask.ToList())) + { + res.Add(tf.where(m, o, fm)); + } + return res; + }; + } + // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)? + else if (input_length is Tensor) + { + if (go_backwards) + { + var max_len = tf.reduce_max(input_length, axis: 0); + var rev_input_length = tf.subtract(max_len - 1, input_length); + + masking_fn = (time) => + { + return tf.less(rev_input_length, time); + }; + } + else + { + masking_fn = (time) => + { + return tf.greater(input_length, time); + }; + } + + compute_masked_output = (mask_t, flat_out, flat_mask) => + { + var res = new List(); + foreach (var (o, zo) in zip(flat_out, flat_mask)) + { + res.Add(tf.where(mask_t, o, zo)); + } + return res; + }; + } + else + { + masking_fn = null; + } + + Func cond = (time) => (time < time_steps_t); + int parallel_iterations = 32; + if (masking_fn != null) + { + // Mask for the T output will be base on the output of T - 1. In the + // case T = 0, a zero filled tensor will be used. + var flat_zero_output = new Tensors(); + foreach (var o in Nest.Flatten(output_time_zero)) + { + flat_zero_output.Add(tf.zeros_like(o)); + } + + var prev_output = flat_zero_output; + var output_ta_t = output_ta; + Tensor _step(Tensor time) + { + /* + RNN step function. + Args: + time: Current timestep value. + output_ta_t: TensorArray. + prev_output: tuple of outputs from time - 1. + *states: List of states. + Returns: + Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` + */ + + var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); + // maybe set shape + // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type + var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); + var mask_t = masking_fn(time); + var (output, new_states_internal) = step_function(current_input, states.MergeWith(constants)); + // mask output + var flat_output = Nest.Flatten(output).ToList(); + + var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList(); + + // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type + var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); + + // mask states + var flat_state = states.ToList(); + var flat_new_state = new_states_internal.ToList(); + + foreach (var (state, new_state) in zip(flat_state, flat_new_state)) + { + if (new_state is Tensor) + { + new_state.shape = state.shape; + } + } + + var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); + new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors(); + + var ta_index_to_write = return_all_outputs ? time : tf.constant(0); + // TODO(Wanglongzhi2001),deal with zip output_ta_t + foreach (var (ta, Out) in zip(output_ta_t, flat_new_output)) + { + output_ta_t.Add(ta.write(ta_index_to_write, Out)); + } + + new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); + + output_ta = output_ta_t; + new_states = new_states_internal; + return time + 1; + + } + var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations); + } + else + { + var output_ta_t = output_ta; + new_states = states; + Tensor _step(Tensor time) + { + var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); + // maybe set shape + // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type + var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); + var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants)); + var flat_state = new_states.Flatten().ToList(); + var flat_new_state = new_states_internal.Flatten().ToList(); + foreach (var (state, new_state) in zip(flat_state, flat_new_state)) + { + if (new_state is Tensor) + { + new_state.shape = state.shape; + } + } + var flat_output = Nest.Flatten(output); + var ta_index_to_write = return_all_outputs ? time : tf.constant(0); + output_ta_t = zip(output_ta_t, flat_output).Select(item => + { + var (ta, out_) = item; + return ta.write(ta_index_to_write, out_); + }).ToList(); + + new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); + output_ta = output_ta_t; + new_states = new_states_internal; + return time + 1; + } + var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations); + } + //Tensors outputs = new Tensors(); + foreach (var o in output_ta) + { + outputs.Add(o.stack()); + } + foreach (var o in outputs) + { + last_output.Add(o[-1]); + } + outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors(); + last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors(); + + } + + Func set_shape; + set_shape = (output_) => + { + if (output_ is Tensor) + { + var shape = output_.shape.as_int_list(); + if (return_all_outputs) + { + shape[0] = (int)time_steps; + } + else + { + shape[0] = 1; + } + shape[1] = (int)batch; + output_.shape = shape; + } + return output_; + }; + + outputs = Nest.MapStructure(set_shape, outputs).ToTensors(); + if (!time_major) + { + outputs = Nest.MapStructure(swap_batch_timestep, outputs).ToTensors(); + } + return (last_output, outputs, new_states); + + } + + public Tensor reverse(Tensor input, int axis) + { + return reverse(input, new int[] { axis }); + } + + public Tensor reverse(Tensor input, int[] axes) + { + return tf.reverse(input, axes); + } + + public Tensor maybe_convert_to_ragged(bool is_ragged_output, Tensor output, int nested_row_lengths, bool go_backwards = false) + { + if (!is_ragged_output) + { + return output; + } + + throw new NotImplementedException("Not implemented currently, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index e768bd0bd..7347585f8 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Common.Types; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Keras.Utils; @@ -81,7 +82,7 @@ protected void _init_graph_network(Tensors inputs, Tensors outputs) } else { - _buildInputShape = new Saving.TensorShapeConfig(); + _buildInputShape = new TensorShapeConfig(); } if (outputs.Any(x => x.KerasHistory == null)) @@ -325,7 +326,7 @@ void BuildMapHelper(Tensor tensor, nodes_in_decreasing_depth.append(node); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { var tensor_dict = new Dictionary>(); // map input values diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs index c04304580..a0358f074 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs @@ -1,4 +1,5 @@ using System.Threading; +using Tensorflow.Common.Types; using static Tensorflow.Binding; namespace Tensorflow.Keras.Engine @@ -8,11 +9,11 @@ public partial class Layer /// /// Wraps `call`, applying pre- and post-processing steps. /// - /// + /// /// /// /// - public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false) + public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null) { if (callContext.Value == null) callContext.Value = new CallContext(); @@ -30,7 +31,7 @@ public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false) if (!built) MaybeBuild(inputs); - var outputs = Call(inputs, state: state, training: training); + var outputs = Call(inputs, state: states, training: training); // memory leak // _set_connectivity_metadata_(inputs, outputs); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 5942efd92..2f758a850 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -32,7 +32,7 @@ limitations under the License. using static Tensorflow.Binding; using Tensorflow.Framework; using Tensorflow.Sessions; - +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Engine { @@ -332,7 +332,7 @@ private Tensor compute_mask(Tensor inputs, Tensor mask = null) /// /// /// - protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected virtual Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { if(ReplacedCall is not null) { diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index 83702b23a..7b35d5477 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -1,8 +1,8 @@ using System.Diagnostics; +using Tensorflow.Common.Types; using Tensorflow.Framework.Models; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Losses; -using Tensorflow.Keras.Saving; using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Keras.Utils; using Tensorflow.Train; diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs index 278747515..6a468ad27 100644 --- a/src/TensorFlowNET.Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs @@ -21,6 +21,7 @@ limitations under the License. using Tensorflow.Keras.Layers; using Tensorflow.Keras.Utils; using static Tensorflow.KerasApi; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Engine { @@ -143,7 +144,7 @@ public void add(ILayer layer) } } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { if (!_has_explicit_input_shape) { diff --git a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs index 739c0d56f..23f36c862 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Common.Types; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; @@ -29,7 +30,7 @@ public override void build(KerasShapesWrapper input_shape) base.build(input_shape); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { Tensor output = inputs; output = tf.where(output > 0f, output, diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs index 17636302f..81fefb314 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs @@ -4,7 +4,7 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; -using static Tensorflow.Binding; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers { public class Exponential : Layer @@ -17,7 +17,7 @@ public override void build(KerasShapesWrapper input_shape) { base.build(input_shape); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { Tensor output = inputs; return tf.exp(output); diff --git a/src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs b/src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs index b498d1b94..e0f91380b 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs @@ -3,6 +3,7 @@ using System.Text; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Common.Types; using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers { @@ -10,7 +11,7 @@ public class HardSigmoid : Layer { public HardSigmoid ( LayerArgs args ) : base(args) { // hard sigmoid has no arguments } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { + protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null ) { Tensor x = inputs; return tf.clip_by_value( tf.add(tf.multiply(x, 0.2f), 0.5f), 0f, 1f); diff --git a/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs b/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs index 1fbbf4eaf..cfbd0186d 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs @@ -3,6 +3,7 @@ using System.Text; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Common.Types; using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers @@ -19,7 +20,7 @@ public LeakyReLu(LeakyReLuArgs args) : base(args) this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { return tf.nn.leaky_relu(inputs, alpha: alpha); } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs index 53101fbb4..2e943d5f7 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Common.Types; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; @@ -22,7 +23,7 @@ public override void build(KerasShapesWrapper input_shape) { } base.build(input_shape); } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { + protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { Tensor output = inputs; return tf.where(output > 0f, tf.multiply(scale, output), diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs b/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs index 3ffae27f6..d018128d5 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Common.Types; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using static Tensorflow.Binding; @@ -11,8 +12,8 @@ public class Softmax : Layer { public Softmax ( SoftmaxArgs args ) : base(args) { axis = args.axis; } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { - Tensor x = inputs.Length == 2 ? inputs + ((1.0 - tf.cast(inputs[1], inputs.dtype)) * 1e-9) + protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { + Tensor x = inputs.Length == 2 ? inputs[0] + ((1.0 - tf.cast(inputs[1], inputs.dtype)) * 1e-9) : inputs; Tensor e = tf.exp(tf.sub(x, tf.reduce_max(x, axis: this.axis, keepdims: true))); Tensor s = tf.reduce_sum(e, axis: this.axis, keepdims: true); diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs b/src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs index e82b01982..1e6c59b42 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Common.Types; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using static Tensorflow.Binding; @@ -10,7 +11,7 @@ public class Softplus : Layer { public Softplus ( LayerArgs args ) : base(args) { // Softplus has no arguments } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { + protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { Tensor x = inputs; return tf.log( tf.add(tf.exp(x), 1f)); diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs b/src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs index 59329fd44..5ad33e99d 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Common.Types; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using static Tensorflow.Binding; @@ -10,7 +11,7 @@ public class Softsign : Layer { public Softsign ( LayerArgs args ) : base(args) { // Softsign has no arguments } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { + protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { Tensor x = inputs; // x / (abs(x) + 1) return tf.div(x, tf.add(1f, tf.abs(x))); diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Swish.cs b/src/TensorFlowNET.Keras/Layers/Activation/Swish.cs index 1dcb92b31..ed0d105a6 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Swish.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Swish.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Common.Types; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using static Tensorflow.Binding; @@ -10,7 +11,7 @@ public class Swish : Layer { public Swish ( LayerArgs args ) : base(args) { // Swish has no arguments } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { + protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { Tensor x = inputs; // x / (1 + exp(-x)) diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs b/src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs index 99b803942..7e90cf9d8 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs @@ -3,6 +3,7 @@ using System.Text; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Common.Types; using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers @@ -13,7 +14,7 @@ public Tanh(LayerArgs args) : base(args) { // Tanh has no arguments } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { Tensor x = inputs; diff --git a/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs index 1348e19cf..19b292727 100644 --- a/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs +++ b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; /// /// Base class for attention layers that can be used in sequence DNN/CNN models. @@ -114,7 +115,7 @@ public virtual Tensor _calculate_scores(Tensor query, Tensor key) => return (tf.linalg.einsum("bij,bjk->bik", (weights, value)), weights); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { Tensors _inp; Tensors _mask = null; diff --git a/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs index 701724d5b..75dd4a41a 100644 --- a/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs +++ b/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs @@ -6,6 +6,7 @@ using static Tensorflow.KerasApi; using System; using System.Linq; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers { @@ -252,7 +253,7 @@ public Tensors _compute_attention( return (attention_output, attention_scores); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { Tensors _inp; Tensor _mask = null; @@ -349,7 +350,7 @@ protected Tensors call(Tensors inputs, //} if (return_attention_scores) - return (attention_output, attention_scores); + return (attention_output, attention_scores.Single); return attention_output; } } diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs index bbd49acd2..94ad79141 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs @@ -20,6 +20,7 @@ limitations under the License. using Tensorflow.Keras.Utils; using static Tensorflow.KerasApi; using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers { @@ -83,7 +84,7 @@ public override void build(KerasShapesWrapper input_shape) _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { var inputs_shape = array_ops.shape(inputs); var batch_size = inputs_shape[0]; diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs index c575362c0..d8e00d520 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs @@ -17,6 +17,7 @@ limitations under the License. using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Common.Types; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; @@ -103,7 +104,7 @@ public override void build(KerasShapesWrapper input_shape) _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = false) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = false, IOptionalArgs? optional_args = null) { var outputs = _convolution_op.Apply(inputs, kernel.AsTensor()); if (use_bias) diff --git a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs index aa6617ddc..db5d626ed 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs @@ -18,6 +18,7 @@ limitations under the License. using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using Tensorflow.Common.Types; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; @@ -69,7 +70,7 @@ public override void build(KerasShapesWrapper input_shape) built = true; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { Tensor outputs = null; var rank = inputs.rank; diff --git a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs index fb604f77e..0cbd50846 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs @@ -7,6 +7,7 @@ using Tensorflow.Keras.Engine; using Tensorflow.Keras.ArgsDefinition.Core; using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers { @@ -189,7 +190,7 @@ public override Shape ComputeOutputShape(Shape input_shape) // return new dict(base_config.items().ToList() + config.items().ToList()); //} - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { var ret = tf.linalg.einsum(this.equation, (inputs, this.kernel.AsTensor())); if (this.bias != null) diff --git a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs index 9487a7d00..87b42bb7b 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs @@ -15,6 +15,7 @@ limitations under the License. ******************************************************************************/ using System.Linq; +using Tensorflow.Common.Types; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; @@ -66,7 +67,7 @@ public override void build(KerasShapesWrapper input_shape) _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { var dtype = inputs.dtype; if (dtype != tf.int32 && dtype != tf.int64) diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs index 7df654eeb..bcbb20d88 100644 --- a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs +++ b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs @@ -5,6 +5,7 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers { @@ -21,7 +22,7 @@ public override void build(KerasShapesWrapper input_shape) _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { return _merge_function(inputs); } diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs index d02d2509c..655581576 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs @@ -17,6 +17,7 @@ limitations under the License. using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Common.Types; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; @@ -146,7 +147,7 @@ bool _support_zero_size_input() return false; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { Tensor outputs = null; var training_tensor = training == null diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs index e90c04029..1898f24c8 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs @@ -17,6 +17,7 @@ limitations under the License. using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Common.Types; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; @@ -101,7 +102,7 @@ public override Shape ComputeOutputShape(Shape input_shape) return input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { Tensor outputs = null; var inputs_dtype = inputs.dtype.as_base_dtype(); diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs index a65154bf4..987b56bc4 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs @@ -14,6 +14,7 @@ You may obtain a copy of the License at limitations under the License. ******************************************************************************/ +using Tensorflow.Common.Types; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Saving; @@ -157,7 +158,7 @@ public override void adapt(Tensor data, int? batch_size = null, int? steps = nul base.adapt(data, batch_size: batch_size, steps: steps); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { if (_args.Invert) { diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs index d62fb63a4..ffaabec97 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers { @@ -12,7 +13,7 @@ public GlobalAveragePooling1D(Pooling1DArgs args) { } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { if (data_format == "channels_last") return math_ops.reduce_mean(inputs, 1, false); diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs index 000e4b8b9..e06665173 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers { @@ -12,7 +13,7 @@ public GlobalAveragePooling2D(Pooling2DArgs args) { } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { if (data_format == "channels_last") return math_ops.reduce_mean(inputs, (1, 2), false); diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs index 2de4671ca..15695e8a7 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers { @@ -12,7 +13,7 @@ public GlobalMaxPooling1D(Pooling1DArgs args) { } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { if (data_format == "channels_last") return math_ops.reduce_max(inputs, 1, false); diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs index b7e2c9452..76db858da 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers { @@ -12,7 +13,7 @@ public GlobalMaxPooling2D(Pooling2DArgs args) { } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { if (data_format == "channels_last") return math_ops.reduce_max(inputs, (1, 2), false); diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs index a2f4c51b6..81a340199 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs @@ -18,6 +18,7 @@ limitations under the License. using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Utils; +using Tensorflow.Common.Types; using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers @@ -36,7 +37,7 @@ public Pooling1D(Pooling1DArgs args) input_spec = new InputSpec(ndim: 3); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { int pad_axis = args.DataFormat == "channels_first" ? 2 : 3; inputs = tf.expand_dims(inputs, pad_axis); diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs index 270322559..f83f1e152 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs @@ -17,6 +17,7 @@ limitations under the License. using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Utils; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers { @@ -36,7 +37,7 @@ public Pooling2D(Pooling2DArgs args) input_spec = new InputSpec(ndim: 4); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { int[] pool_shape; int[] strides; diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs index 5620a916c..20d2a53d5 100644 --- a/src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs @@ -1,6 +1,6 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; - +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers { /// @@ -15,7 +15,7 @@ public CategoryEncoding(CategoryEncodingArgs args) : base(args) this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { var depth = args.NumTokens; var max_value = tf.reduce_max(inputs); diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs index 5fc581af9..7fa367eea 100644 --- a/src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs @@ -1,5 +1,6 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers { @@ -17,7 +18,7 @@ public Rescaling(RescalingArgs args) : base(args) this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { scale = constant_op.constant(args.Scale, args.DType); offset = constant_op.constant(args.Offset, args.DType); diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs index 603e2b071..081966ad4 100644 --- a/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs @@ -4,6 +4,7 @@ using System.Text; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers { @@ -19,7 +20,7 @@ public Resizing(ResizingArgs args) : base(args) this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { return image_ops_impl.resize_images_v2(inputs, new[] { args.Height, args.Width }, method: args.Interpolation); } diff --git a/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs b/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs index aa3a92a49..ada1851ce 100644 --- a/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs +++ b/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs @@ -1,4 +1,5 @@ -using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Utils; using static Tensorflow.Binding; @@ -15,7 +16,7 @@ public Dropout(DropoutArgs args) this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { if (training == null) training = false; diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs index 9ead15cb5..312854388 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs @@ -1,6 +1,8 @@ using Tensorflow.Keras.ArgsDefinition.Reshaping; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers.Reshaping { @@ -27,7 +29,7 @@ public override void build(KerasShapesWrapper input_shape) _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { Tensor output = inputs; if (output.rank != 3) diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs index 087d59a14..4a5c6eabc 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs @@ -1,6 +1,7 @@ using Tensorflow.Keras.ArgsDefinition.Reshaping; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers.Reshaping { @@ -21,7 +22,7 @@ public override void build(KerasShapesWrapper input_shape) built = true; _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { Tensor output = inputs; if (output.rank != 4) diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs index 04a1af600..83f86c6fc 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs @@ -1,6 +1,7 @@ using Tensorflow.Keras.ArgsDefinition.Reshaping; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers.Reshaping { @@ -21,7 +22,7 @@ public override void build(KerasShapesWrapper input_shape) _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { Tensor output = inputs; if (output.rank != 5) diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs index 539b5f624..a6192849d 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs @@ -1,5 +1,6 @@ using System; using System.Linq; +using Tensorflow.Common.Types; using Tensorflow.Framework; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; @@ -23,7 +24,7 @@ public Flatten(FlattenArgs args) _channels_first = args.DataFormat == "channels_first"; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { if (_channels_first) { diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs index e391775c8..7fdb816bf 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs @@ -6,6 +6,7 @@ using static Tensorflow.Binding; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers { public class Permute : Layer @@ -28,7 +29,7 @@ public override void build(KerasShapesWrapper input_shape) built = true; _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { Tensor outputs = inputs; return tf.transpose(outputs, new Axis(permute)); diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs index 92a772f34..4b3d30e29 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System; using System.Linq; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers { @@ -19,7 +20,7 @@ public Reshape(ReshapeArgs args) this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { var shapes = new List(); shapes.Add(array_ops.shape(inputs)[0]); diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs index 8314151f6..223f33d4f 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs @@ -6,6 +6,7 @@ using Tensorflow.Keras.Utils; using static Tensorflow.Binding; using static Tensorflow.KerasApi; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers { @@ -24,7 +25,7 @@ public UpSampling2D(UpSampling2DArgs args) : base(args) inputSpec = new InputSpec(ndim: 4); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { return keras.backend.resize_images(inputs, size[0], size[1], diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs index 7c87100a2..3b37dac46 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs @@ -2,6 +2,7 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Utils; +using Tensorflow.Common.Types; using static Tensorflow.KerasApi; namespace Tensorflow.Keras.Layers @@ -26,7 +27,7 @@ public ZeroPadding2D(ZeroPadding2DArgs args, string data_format = null) this.input_spec = new InputSpec(ndim: 4); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { return keras.backend.spatial_2d_padding(inputs, padding: padding, diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs b/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs new file mode 100644 index 000000000..21396853f --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs @@ -0,0 +1,85 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers.Rnn +{ + public abstract class DropoutRNNCellMixin: RnnCellBase + { + public float dropout; + public float recurrent_dropout; + // TODO(Rinne): deal with cache. + public DropoutRNNCellMixin(LayerArgs args): base(args) + { + + } + + public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) + { + if (dropout == 0f) + return null; + return _generate_dropout_mask( + tf.ones_like(input), + dropout, + training, + count); + } + + // Get the recurrent dropout mask for RNN cell. + public Tensors? get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) + { + if (dropout == 0f) + return null; + return _generate_dropout_mask( + tf.ones_like(input), + recurrent_dropout, + training, + count); + } + + public Tensors _create_dropout_mask(Tensors input, bool training, int count = 1) + { + return _generate_dropout_mask( + tf.ones_like(input), + dropout, + training, + count); + } + + public Tensors _create_recurrent_dropout_mask(Tensors input, bool training, int count = 1) + { + return _generate_dropout_mask( + tf.ones_like(input), + recurrent_dropout, + training, + count); + } + + public Tensors _generate_dropout_mask(Tensor ones, float rate, bool training, int count = 1) + { + Tensors dropped_inputs() + { + DropoutArgs args = new DropoutArgs(); + args.Rate = rate; + var DropoutLayer = new Dropout(args); + var mask = DropoutLayer.Apply(ones, training: training); + return mask; + } + + if (count > 1) + { + Tensors results = new Tensors(); + for (int i = 0; i < count; i++) + { + results.Add(dropped_inputs()); + } + return results; + } + + return dropped_inputs(); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs index 59555e62b..1449c908e 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs @@ -1,6 +1,7 @@ using System.Linq; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers.Rnn { @@ -26,9 +27,9 @@ public LSTM(LSTMArgs args) : .ToArray(); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { - return base.Call(inputs, state: state, training: training); + return base.Call(inputs, initial_state: state, training: training); } } } diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs index 310e80574..ab4cef124 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -1,53 +1,468 @@ -using System; +using OneOf; +using System; using System.Collections.Generic; +using System.Reflection; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.Util; +using Tensorflow.Common.Extensions; +using System.Linq.Expressions; +using Tensorflow.Keras.Utils; +using Tensorflow.Common.Types; // from tensorflow.python.distribute import distribution_strategy_context as ds_context; namespace Tensorflow.Keras.Layers.Rnn { - public class RNN : Layer + /// + /// Base class for recurrent layers. + /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) + /// for details about the usage of RNN API. + /// + public class RNN : RnnBase { - private RNNArgs args; - private object input_spec = null; // or NoneValue?? - private object state_spec = null; - private object _states = null; - private object constants_spec = null; - private int _num_constants = 0; - protected IVariableV1 kernel; - protected IVariableV1 bias; - protected ILayer cell; + private RNNArgs _args; + private object _input_spec = null; // or NoneValue?? + private object _state_spec = null; + private Tensors _states = null; + private object _constants_spec = null; + private int _num_constants; + protected IVariableV1 _kernel; + protected IVariableV1 _bias; + protected IRnnCell _cell; + public RNN(RNNArgs args) : base(PreConstruct(args)) { - this.args = args; + _args = args; SupportsMasking = true; - // The input shape is unknown yet, it could have nested tensor inputs, and - // the input spec will be the list of specs for nested inputs, the structure - // of the input_spec will be the same as the input. + // if is StackedRnncell + _cell = args.Cell; + + // get input_shape + _args = PreConstruct(args); + + _num_constants = 0; + } + + // States is a tuple consist of cell states_size, like (cell1.state_size, cell2.state_size,...) + // state_size can be a single integer, can also be a list/tuple of integers, can also be TensorShape or a list/tuple of TensorShape + public Tensors States + { + get + { + if (_states == null) + { + // CHECK(Rinne): check if this is correct. + var nested = _cell.StateSize.MapStructure(x => null); + _states = nested.AsNest().ToTensors(); + } + return _states; + } + set { _states = value; } + } + + private OneOf> compute_output_shape(Shape input_shape) + { + var batch = input_shape[0]; + var time_step = input_shape[1]; + if (_args.TimeMajor) + { + (batch, time_step) = (time_step, batch); + } + + // state_size is a array of ints or a positive integer + var state_size = _cell.StateSize.ToSingleShape(); + + // TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor + Func _get_output_shape; + _get_output_shape = (flat_output_size) => + { + var output_dim = flat_output_size.as_int_list(); + Shape output_shape; + if (_args.ReturnSequences) + { + if (_args.TimeMajor) + { + output_shape = new Shape(new int[] { (int)time_step, (int)batch }.concat(output_dim)); + } + else + { + output_shape = new Shape(new int[] { (int)batch, (int)time_step }.concat(output_dim)); + + } + } + else + { + output_shape = new Shape(new int[] { (int)batch }.concat(output_dim)); + } + return output_shape; + }; + + Type type = _cell.GetType(); + PropertyInfo output_size_info = type.GetProperty("output_size"); + Shape output_shape; + if (output_size_info != null) + { + output_shape = nest.map_structure(_get_output_shape, _cell.OutputSize.ToSingleShape()); + // TODO(wanglongzhi2001),output_shape应该简单的就是一个元组还是一个Shape类型 + output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape); + } + else + { + output_shape = _get_output_shape(state_size); + } + + if (_args.ReturnState) + { + Func _get_state_shape; + _get_state_shape = (flat_state) => + { + var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list()); + return new Shape(state_shape); + }; + var state_shape = _get_state_shape(state_size); + + return new List { output_shape, state_shape }; + } + else + { + return output_shape; + } - //if(stateful) - //{ - // if (ds_context.has_strategy()) // ds_context???? - // { - // throw new Exception("RNNs with stateful=True not yet supported with tf.distribute.Strategy"); - // } - //} + } + + private Tensors compute_mask(Tensors inputs, Tensors mask) + { + // Time step masks must be the same for each input. + // This is because the mask for an RNN is of size [batch, time_steps, 1], + // and specifies which time steps should be skipped, and a time step + // must be skipped for all inputs. + + mask = nest.flatten(mask)[0]; + var output_mask = _args.ReturnSequences ? mask : null; + if (_args.ReturnState) + { + var state_mask = new List(); + for (int i = 0; i < len(States); i++) + { + state_mask.Add(null); + } + return new List { output_mask }.concat(state_mask); + } + else + { + return output_mask; + } } public override void build(KerasShapesWrapper input_shape) { - if (!cell.Built) + object get_input_spec(Shape shape) + { + var input_spec_shape = shape.as_int_list(); + + var (batch_index, time_step_index) = _args.TimeMajor ? (1, 0) : (0, 1); + if (!_args.Stateful) + { + input_spec_shape[batch_index] = -1; + } + input_spec_shape[time_step_index] = -1; + return new InputSpec(shape: input_spec_shape); + } + + Shape get_step_input_shape(Shape shape) + { + + // return shape[1:] if self.time_major else (shape[0],) + shape[2:] + if (_args.TimeMajor) + { + return shape.as_int_list().ToList().GetRange(1, shape.Length - 1).ToArray(); + } + else + { + return new int[] { shape.as_int_list()[0] }.concat(shape.as_int_list().ToList().GetRange(2, shape.Length - 2).ToArray()); + } + + + } + + object get_state_spec(Shape shape) + { + var state_spec_shape = shape.as_int_list(); + // append bacth dim + state_spec_shape = new int[] { -1 }.concat(state_spec_shape); + return new InputSpec(shape: state_spec_shape); + + } + + // Check whether the input shape contains any nested shapes. It could be + // (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from + // numpy inputs. + + + if (!_cell.Built) + { + _cell.build(input_shape); + } + } + + /// + /// + /// + /// + /// Binary tensor of shape [batch_size, timesteps] indicating whether a given timestep should be masked + /// + /// List of initial state tensors to be passed to the first call of the cell + /// List of constant tensors to be passed to the cell at each timestep + /// + /// + /// + protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs; + if(optional_args is not null && rnn_optional_args is null) + { + throw new ArgumentException("The optional args shhould be of type `RnnOptionalArgs`"); + } + Tensors? constants = rnn_optional_args?.Constants; + Tensors? mask = rnn_optional_args?.Mask; + //var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs); + // 暂时先不接受ragged tensor + int row_length = 0; // TODO(Rinne): support this param. + bool is_ragged_input = false; + _validate_args_if_ragged(is_ragged_input, mask); + + (inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants); + + _maybe_reset_cell_dropout_mask(_cell); + if (_cell is StackedRNNCells) + { + var stack_cell = _cell as StackedRNNCells; + foreach (var cell in stack_cell.Cells) + { + _maybe_reset_cell_dropout_mask(cell); + } + } + + if (mask != null) + { + // Time step masks must be the same for each input. + mask = mask.Flatten().First(); + } + + Shape input_shape; + if (!inputs.IsSingle()) + { + // In the case of nested input, use the first element for shape check + // input_shape = nest.flatten(inputs)[0].shape; + // TODO(Wanglongzhi2001) + input_shape = inputs.Flatten().First().shape; + } + else + { + input_shape = inputs.shape; + } + + var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1]; + + if (_args.Unroll && timesteps != null) + { + throw new ValueError( + "Cannot unroll a RNN if the " + + "time dimension is undefined. \n" + + "- If using a Sequential model, " + + "specify the time dimension by passing " + + "an `input_shape` or `batch_input_shape` " + + "argument to your first layer. If your " + + "first layer is an Embedding, you can " + + "also use the `input_length` argument.\n" + + "- If using the functional API, specify " + + "the time dimension by passing a `shape` " + + "or `batch_shape` argument to your Input layer." + ); + } + + // cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call) + Func step; + bool is_tf_rnn_cell = _cell.IsTFRnnCell; + if (constants is not null) + { + if (!_cell.SupportOptionalArgs) + { + throw new ValueError( + $"RNN cell {_cell} does not support constants." + + $"Received: constants={constants}"); + } + + step = (inputs, states) => + { + constants = new Tensors(states.TakeLast(_num_constants)); + states = new Tensors(states.SkipLast(_num_constants)); + states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states; + var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); + // TODO(Wanglongzhi2001),should cell_call_fn's return value be Tensors, Tensors? + return (output, new_states.Single); + }; + } + else + { + step = (inputs, states) => + { + states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states; + var (output, new_states) = _cell.Apply(inputs, states); + return (output, new_states.Single); + }; + } + + var (last_output, outputs, states) = keras.backend.rnn(step, + inputs, + initial_state, + constants: constants, + go_backwards: _args.GoBackwards, + mask: mask, + unroll: _args.Unroll, + input_length: row_length != null ? new Tensor(row_length) : new Tensor(timesteps), + time_major: _args.TimeMajor, + zero_output_for_mask: _args.ZeroOutputForMask, + return_all_outputs: _args.ReturnSequences); + + if (_args.Stateful) + { + throw new NotImplementedException("this argument havn't been developed."); + } + + Tensors output = new Tensors(); + if (_args.ReturnSequences) + { + // TODO(Rinne): add go_backwards parameter and revise the `row_length` param + output = keras.backend.maybe_convert_to_ragged(is_ragged_input, outputs, row_length, false); + } + else + { + output = last_output; + } + + if (_args.ReturnState) + { + foreach (var state in states) + { + output.Add(state); + } + return output; + } + else + { + return output; + } + } + + public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool training = false, IOptionalArgs? optional_args = null) + { + RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs; + if (optional_args is not null && rnn_optional_args is null) { - cell.build(input_shape); + throw new ArgumentException("The type of optional args should be `RnnOptionalArgs`."); } + Tensors? constants = rnn_optional_args?.Constants; + (inputs, initial_states, constants) = RnnUtils.standardize_args(inputs, initial_states, constants, _num_constants); + + if(initial_states is null && constants is null) + { + return base.Apply(inputs); + } + + // TODO(Rinne): implement it. + throw new NotImplementedException(); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + private (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensors inputs, Tensors initial_state, Tensors constants) { - return base.Call(inputs, state, training); + if (inputs.Length > 1) + { + if (_num_constants != 0) + { + initial_state = new Tensors(inputs.Skip(1)); + } + else + { + initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants)); + constants = new Tensors(inputs.TakeLast(_num_constants)); + } + if (len(initial_state) == 0) + initial_state = null; + inputs = inputs[0]; + } + + if (_args.Stateful) + { + if (initial_state != null) + { + var tmp = new Tensor[] { }; + foreach (var s in nest.flatten(States)) + { + tmp.add(tf.math.count_nonzero((Tensor)s)); + } + var non_zero_count = tf.add_n(tmp); + //initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state); + if ((int)non_zero_count.numpy() > 0) + { + initial_state = States; + } + } + else + { + initial_state = States; + } + + } + else if (initial_state is null) + { + initial_state = get_initial_state(inputs); + } + + if (initial_state.Length != States.Length) + { + throw new ValueError( + $"Layer {this} expects {States.Length} state(s), " + + $"but it received {initial_state.Length} " + + $"initial state(s). Input received: {inputs}"); + } + + return (inputs, initial_state, constants); + } + + private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask) + { + if (!is_ragged_input) + { + return; + } + + if (_args.Unroll) + { + throw new ValueError("The input received contains RaggedTensors and does " + + "not support unrolling. Disable unrolling by passing " + + "`unroll=False` in the RNN Layer constructor."); + } + if (mask != null) + { + throw new ValueError($"The mask that was passed in was {mask}, which " + + "cannot be applied to RaggedTensor inputs. Please " + + "make sure that there is no mask injected by upstream " + + "layers."); + } + + } + + void _maybe_reset_cell_dropout_mask(ILayer cell) + { + //if (cell is DropoutRNNCellMixin) + //{ + // cell.reset_dropout_mask(); + // cell.reset_recurrent_dropout_mask(); + //} } private static RNNArgs PreConstruct(RNNArgs args) @@ -77,60 +492,72 @@ private static RNNArgs PreConstruct(RNNArgs args) return args; } - public RNN New(LayerRnnCell cell, - bool return_sequences = false, - bool return_state = false, - bool go_backwards = false, - bool stateful = false, - bool unroll = false, - bool time_major = false) - => new RNN(new RNNArgs - { - Cell = cell, - ReturnSequences = return_sequences, - ReturnState = return_state, - GoBackwards = go_backwards, - Stateful = stateful, - Unroll = unroll, - TimeMajor = time_major - }); - - public RNN New(IList cell, - bool return_sequences = false, - bool return_state = false, - bool go_backwards = false, - bool stateful = false, - bool unroll = false, - bool time_major = false) - => new RNN(new RNNArgs - { - Cell = new StackedRNNCells(new StackedRNNCellsArgs { Cells = cell }), - ReturnSequences = return_sequences, - ReturnState = return_state, - GoBackwards = go_backwards, - Stateful = stateful, - Unroll = unroll, - TimeMajor = time_major - }); - - - protected Tensor get_initial_state(Tensor inputs) + public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = null) { - return _generate_zero_filled_state_for_cell(null, null); + throw new NotImplementedException(); } - Tensor _generate_zero_filled_state_for_cell(LSTMCell cell, Tensor batch_size) + // 好像不能cell不能传接口类型 + //public RNN New(IRnnArgCell cell, + // bool return_sequences = false, + // bool return_state = false, + // bool go_backwards = false, + // bool stateful = false, + // bool unroll = false, + // bool time_major = false) + // => new RNN(new RNNArgs + // { + // Cell = cell, + // ReturnSequences = return_sequences, + // ReturnState = return_state, + // GoBackwards = go_backwards, + // Stateful = stateful, + // Unroll = unroll, + // TimeMajor = time_major + // }); + + //public RNN New(List cell, + // bool return_sequences = false, + // bool return_state = false, + // bool go_backwards = false, + // bool stateful = false, + // bool unroll = false, + // bool time_major = false) + // => new RNN(new RNNArgs + // { + // Cell = cell, + // ReturnSequences = return_sequences, + // ReturnState = return_state, + // GoBackwards = go_backwards, + // Stateful = stateful, + // Unroll = unroll, + // TimeMajor = time_major + // }); + + + protected Tensors get_initial_state(Tensors inputs) { - throw new NotImplementedException(""); + var input = inputs[0]; + var input_shape = input.shape; + var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0]; + var dtype = input.dtype; + Tensors init_state; + if (_cell is RnnCellBase rnn_base_cell) + { + init_state = rnn_base_cell.GetInitialState(null, batch_size, dtype); + } + else + { + init_state = RnnUtils.generate_zero_filled_state(batch_size, _cell.StateSize, dtype); + } + + return init_state; } // Check whether the state_size contains multiple states. - public static bool _is_multiple_state(object state_size) + public static bool is_multiple_state(GeneralizedTensorShape state_size) { - var myIndexerProperty = state_size.GetType().GetProperty("Item"); - return myIndexerProperty != null - && myIndexerProperty.GetIndexParameters().Length == 1 - && !(state_size.GetType() == typeof(Shape)); + return state_size.Shapes.Length > 1; } } } diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RnnBase.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RnnBase.cs new file mode 100644 index 000000000..018b17780 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RnnBase.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers.Rnn +{ + public abstract class RnnBase: Layer + { + public RnnBase(LayerArgs args): base(args) { } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs new file mode 100644 index 000000000..751312e5d --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.ArgsDefinition.Rnn; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Layers.Rnn +{ + public abstract class RnnCellBase: Layer, IRnnCell + { + public RnnCellBase(LayerArgs args) : base(args) { } + public abstract GeneralizedTensorShape StateSize { get; } + public abstract GeneralizedTensorShape OutputSize { get; } + public abstract bool IsTFRnnCell { get; } + public abstract bool SupportOptionalArgs { get; } + public virtual Tensors GetInitialState(Tensors inputs, long batch_size, TF_DataType dtype) + { + return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs index 2d7aab70e..22d0e2770 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs @@ -10,18 +10,36 @@ namespace Tensorflow.Keras.Layers.Rnn public class SimpleRNN : RNN { SimpleRNNArgs args; - public SimpleRNN(SimpleRNNArgs args) : base(args) + public SimpleRNN(SimpleRNNArgs args) : base(CreateCellForArgs(args)) { this.args = args; } + private static SimpleRNNArgs CreateCellForArgs(SimpleRNNArgs args) + { + args.Cell = new SimpleRNNCell(new SimpleRNNCellArgs() + { + Units = args.Units, + Activation = args.Activation, + UseBias = args.UseBias, + KernelInitializer = args.KernelInitializer, + RecurrentInitializer = args.RecurrentInitializer, + BiasInitializer = args.BiasInitializer, + Dropout = args.Dropout, + RecurrentDropout = args.RecurrentDropout, + DType = args.DType, + Trainable = args.Trainable, + }); + return args; + } + public override void build(KerasShapesWrapper input_shape) { var single_shape = input_shape.ToSingleShape(); var input_dim = single_shape[-1]; _buildInputShape = input_shape; - kernel = add_weight("kernel", (single_shape[-1], args.Units), + _kernel = add_weight("kernel", (single_shape[-1], args.Units), initializer: args.KernelInitializer //regularizer = self.kernel_regularizer, //constraint = self.kernel_constraint, diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs index 46061b211..f0b2ed4d7 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs @@ -4,47 +4,114 @@ using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; +using Tensorflow.Common.Extensions; namespace Tensorflow.Keras.Layers.Rnn { - public class SimpleRNNCell : Layer + /// + /// Cell class for SimpleRNN. + /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) + /// for details about the usage of RNN API. + /// This class processes one step within the whole time sequence input, whereas + /// `tf.keras.layer.SimpleRNN` processes the whole sequence. + /// + public class SimpleRNNCell : DropoutRNNCellMixin { - SimpleRNNArgs args; - IVariableV1 kernel; - IVariableV1 recurrent_kernel; - IVariableV1 bias; + SimpleRNNCellArgs _args; + IVariableV1 _kernel; + IVariableV1 _recurrent_kernel; + IVariableV1 _bias; + GeneralizedTensorShape _state_size; + GeneralizedTensorShape _output_size; - public SimpleRNNCell(SimpleRNNArgs args) : base(args) + public override GeneralizedTensorShape StateSize => _state_size; + public override GeneralizedTensorShape OutputSize => _output_size; + public override bool IsTFRnnCell => true; + public override bool SupportOptionalArgs => false; + + public SimpleRNNCell(SimpleRNNCellArgs args) : base(args) { - this.args = args; + this._args = args; + if (args.Units <= 0) + { + throw new ValueError( + $"units must be a positive integer, got {args.Units}"); + } + this._args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout)); + this._args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout)); + _state_size = new GeneralizedTensorShape(args.Units); + _output_size = new GeneralizedTensorShape(args.Units); } public override void build(KerasShapesWrapper input_shape) { + // TODO(Rinne): add the cache. var single_shape = input_shape.ToSingleShape(); var input_dim = single_shape[-1]; - kernel = add_weight("kernel", (single_shape[-1], args.Units), - initializer: args.KernelInitializer + _kernel = add_weight("kernel", (single_shape[-1], _args.Units), + initializer: _args.KernelInitializer ); - recurrent_kernel = add_weight("recurrent_kernel", (args.Units, args.Units), - initializer: args.RecurrentInitializer + _recurrent_kernel = add_weight("recurrent_kernel", (_args.Units, _args.Units), + initializer: _args.RecurrentInitializer ); - if (args.UseBias) + if (_args.UseBias) { - bias = add_weight("bias", (args.Units), - initializer: args.BiasInitializer + _bias = add_weight("bias", (_args.Units), + initializer: _args.BiasInitializer ); } built = true; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + // TODO(Rinne): revise the trining param (with refactoring of the framework) + protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null) { - return base.Call(inputs, state, training); + // TODO(Rinne): check if it will have multiple tensors when not nested. + Tensors prev_output = Nest.IsNested(states) ? new Tensors(states[0]) : states; + var dp_mask = get_dropout_maskcell_for_cell(inputs, training.Value); + var rec_dp_mask = get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value); + + Tensor h; + if (dp_mask != null) + { + h = math_ops.matmul(math_ops.multiply(inputs.Single, dp_mask.Single), _kernel.AsTensor()); + } + else + { + h = math_ops.matmul(inputs, _kernel.AsTensor()); + } + + if (_bias != null) + { + h = tf.nn.bias_add(h, _bias); + } + + if (rec_dp_mask != null) + { + prev_output = math_ops.multiply(prev_output, rec_dp_mask); + } + + Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor()); + + if (_args.Activation != null) + { + output = _args.Activation.Apply(output); + } + if (Nest.IsNested(states)) + { + return new Nest(new List> { + new Nest(new List> { new Nest(output) }), new Nest(output) }) + .ToTensors(); + } + else + { + return new Tensors(output, output); + } } } } diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs index 20962df1f..0b92fd3cf 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.ComponentModel; +using Tensorflow.Common.Types; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; @@ -8,7 +9,7 @@ namespace Tensorflow.Keras.Layers.Rnn { - public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell + public class StackedRNNCells : Layer, IRnnCell { public IList Cells { get; set; } public bool reverse_state_order; @@ -51,7 +52,7 @@ public object output_size { return lastCell.output_size; } - else if (RNN._is_multiple_state(lastCell.state_size)) + else if (RNN.is_multiple_state(lastCell.StateSize)) { // return ((dynamic)Cells[-1].state_size)[0]; throw new NotImplementedException(""); @@ -162,5 +163,14 @@ public void from_config() // deserialize_layer(cell_config, custom_objects = custom_objects)) // return cls(cells, **config) } + + public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null) + { + throw new NotImplementedException(); + } + public GeneralizedTensorShape StateSize => throw new NotImplementedException(); + public GeneralizedTensorShape OutputSize => throw new NotImplementedException(); + public bool IsTFRnnCell => throw new NotImplementedException(); + public bool SupportOptionalArgs => throw new NotImplementedException(); } } diff --git a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs index 1ac4a277c..6dfec3196 100644 --- a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs @@ -10,6 +10,7 @@ using static Tensorflow.Binding; using Tensorflow.Functions; using System.Threading; +using Tensorflow.Common.Types; namespace Tensorflow.Keras.Layers { @@ -34,7 +35,7 @@ public TensorFlowOpLayer(TensorFlowOpLayerArgs args) built = true; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { if (tf.Context.executing_eagerly()) return DeFunCall(inputs); diff --git a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs index be6a49ec5..3c2f8a7be 100644 --- a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs +++ b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs @@ -304,7 +304,7 @@ private static Tensor _filter_top_k(Tensor x, int k) var NEG_INF = -1e10; var (_, top_k_idx) = tf.math.top_k(x, k, sorted: false); var top_k_mask = tf.reduce_sum( - tf.one_hot(top_k_idx, (int)x.shape[-1], axis: -1), axis: -2); + tf.one_hot(top_k_idx.Single, (int)x.shape[-1], axis: -1), axis: -2); return x * top_k_mask + NEG_INF * (1 - top_k_mask); } } diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs index fa19987b1..4acae4265 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs @@ -129,7 +129,7 @@ public IDatasetV2 timeseries_dataset_from_array(Tensor data, int sequence_length var indices = z.map(m => { var (i, positions) = m; - return tf.range(positions[i], positions[i] + sequence_length_tensor * sampling_rate_tensor, sampling_rate_tensor); + return tf.range(positions.Single[i], positions.Single[i] + sequence_length_tensor * sampling_rate_tensor, sampling_rate_tensor); }, num_parallel_calls: -1); var dataset = sequences_from_indices(data, indices, start_index, end_index); diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs index a26879e0c..396ad20eb 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs @@ -8,7 +8,7 @@ using System.Linq; using System.Reflection; using System.Text.RegularExpressions; -using Tensorflow.Extensions; +using Tensorflow.Common.Extensions; using Tensorflow.Framework.Models; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; diff --git a/src/TensorFlowNET.Keras/Utils/RnnUtils.cs b/src/TensorFlowNET.Keras/Utils/RnnUtils.cs new file mode 100644 index 000000000..3109eb77b --- /dev/null +++ b/src/TensorFlowNET.Keras/Utils/RnnUtils.cs @@ -0,0 +1,93 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using Tensorflow.Common.Types; +using Tensorflow.Keras.Layers.Rnn; +using Tensorflow.Common.Extensions; + +namespace Tensorflow.Keras.Utils +{ + internal static class RnnUtils + { + internal static Tensors generate_zero_filled_state(long batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype) + { + Func create_zeros; + create_zeros = (GeneralizedTensorShape unnested_state_size) => + { + var flat_dims = unnested_state_size.ToSingleShape().dims; + var init_state_size = new long[] { batch_size_tensor }.Concat(flat_dims).ToArray(); + return array_ops.zeros(new Shape(init_state_size), dtype: dtype); + }; + + // TODO(Rinne): map structure with nested tensors. + if(state_size.Shapes.Length > 1) + { + return new Tensors(state_size.ToShapeArray().Select(s => create_zeros(new GeneralizedTensorShape(s)))); + } + else + { + return create_zeros(state_size); + } + + } + + internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, long batch_size, TF_DataType dtype) + { + if (inputs != null) + { + batch_size = inputs.shape[0]; + dtype = inputs.dtype; + } + return generate_zero_filled_state(batch_size, cell.StateSize, dtype); + } + + /// + /// Standardizes `__call__` to a single list of tensor inputs. + /// + /// When running a model loaded from a file, the input tensors + /// `initial_state` and `constants` can be passed to `RNN.__call__()` as part + /// of `inputs` instead of by the dedicated keyword arguments.This method + /// makes sure the arguments are separated and that `initial_state` and + /// `constants` are lists of tensors(or None). + /// + /// Tensor or list/tuple of tensors. which may include constants + /// and initial states.In that case `num_constant` must be specified. + /// Tensor or list of tensors or None, initial states. + /// Tensor or list of tensors or None, constant tensors. + /// Expected number of constants (if constants are passed as + /// part of the `inputs` list. + /// + internal static (Tensors, Tensors, Tensors) standardize_args(Tensors inputs, Tensors initial_state, Tensors constants, int num_constants) + { + if(inputs.Length > 1) + { + // There are several situations here: + // In the graph mode, __call__ will be only called once. The initial_state + // and constants could be in inputs (from file loading). + // In the eager mode, __call__ will be called twice, once during + // rnn_layer(inputs=input_t, constants=c_t, ...), and second time will be + // model.fit/train_on_batch/predict with real np data. In the second case, + // the inputs will contain initial_state and constants as eager tensor. + // + // For either case, the real input is the first item in the list, which + // could be a nested structure itself. Then followed by initial_states, which + // could be a list of items, or list of list if the initial_state is complex + // structure, and finally followed by constants which is a flat list. + Debug.Assert(initial_state is null && constants is null); + if(num_constants > 0) + { + constants = inputs.TakeLast(num_constants).ToTensors(); + inputs = inputs.SkipLast(num_constants).ToTensors(); + } + if(inputs.Length > 1) + { + initial_state = inputs.Skip(1).ToTensors(); + inputs = inputs.Take(1).ToTensors(); + } + } + + return (inputs, initial_state, constants); + } + } +} diff --git a/src/TensorflowNET.Hub/KerasLayer.cs b/src/TensorflowNET.Hub/KerasLayer.cs index b9ca949bc..20d9851b1 100644 --- a/src/TensorflowNET.Hub/KerasLayer.cs +++ b/src/TensorflowNET.Hub/KerasLayer.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Common.Types; using Tensorflow.Keras.Engine; using Tensorflow.Train; using Tensorflow.Training; @@ -89,7 +90,7 @@ private void _setup_layer(bool trainable = false) } } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optionalArgs = null) { _check_trainability(); diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs index 3de337469..f4980b82d 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs @@ -144,17 +144,6 @@ public void EinsumDense() Assert.AreEqual(expected_output, actual_output); } - [TestMethod, Ignore("WIP")] - public void SimpleRNN() - { - var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32); - /*var simple_rnn = keras.layers.SimpleRNN(4); - var output = simple_rnn.Apply(inputs); - Assert.AreEqual((32, 4), output.shape);*/ - var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true); - var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs); - } - [TestMethod] public void Resizing() { diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs new file mode 100644 index 000000000..55663d41c --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs @@ -0,0 +1,28 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Tensorflow.NumPy; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.UnitTest.Layers +{ + [TestClass] + public class Rnn + { + [TestMethod] + public void SimpleRNN() + { + var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32); + /*var simple_rnn = keras.layers.SimpleRNN(4); + var output = simple_rnn.Apply(inputs); + Assert.AreEqual((32, 4), output.shape);*/ + var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true); + var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs); + Console.WriteLine(whole_sequence_output); + Console.WriteLine(final_state); + } + } +} diff --git a/tools/TensorFlowNET.Console/SimpleRnnTest.cs b/tools/TensorFlowNET.Console/SimpleRnnTest.cs index 9769eb655..ae6ebb8a8 100644 --- a/tools/TensorFlowNET.Console/SimpleRnnTest.cs +++ b/tools/TensorFlowNET.Console/SimpleRnnTest.cs @@ -20,7 +20,7 @@ public void Run() // whole_sequence_output has shape `[32, 10, 4]`. // final_state has shape `[32, 4]`. - var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs); + var (whole_sequence_output, final_states) = simple_rnn.Apply(inputs); } } }