From f3b3d8be65f8d037dd456d6380bb93d2e888b53c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9CWanglongzhi2001=E2=80=9D?= <“583087864@qq.com”>
Date: Fri, 28 Jul 2023 12:42:11 +0800
Subject: [PATCH] fix: add the momentum parameter's implementation of SGD

---
 src/TensorFlowNET.Core/Keras/IOptimizerApi.cs |  2 +-
 .../Training/gen_training_ops.cs              |  4 ++++
 .../Optimizers/OptimizerApi.cs                |  4 ++--
 src/TensorFlowNET.Keras/Optimizers/SGD.cs     | 19 ++++++++++++++++++-
 4 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/src/TensorFlowNET.Core/Keras/IOptimizerApi.cs b/src/TensorFlowNET.Core/Keras/IOptimizerApi.cs
index d0d3a74f1..19e3a7b8c 100644
--- a/src/TensorFlowNET.Core/Keras/IOptimizerApi.cs
+++ b/src/TensorFlowNET.Core/Keras/IOptimizerApi.cs
@@ -63,6 +63,6 @@ IOptimizer RMSprop(float learning_rate = 0.001f,
                            bool centered = false,
                            string name = "RMSprop");
 
-        IOptimizer SGD(float learning_rate);
+        IOptimizer SGD(float learning_rate, float momentum);
     }
 }
diff --git a/src/TensorFlowNET.Core/Training/gen_training_ops.cs b/src/TensorFlowNET.Core/Training/gen_training_ops.cs
index abe85a141..df7dd9e65 100644
--- a/src/TensorFlowNET.Core/Training/gen_training_ops.cs
+++ b/src/TensorFlowNET.Core/Training/gen_training_ops.cs
@@ -51,5 +51,9 @@ public static Tensor apply_gradient_descent(IVariableV1 var, Tensor alpha, Tenso
         public static Tensor resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null)
             => tf.Context.ExecuteOp("ResourceApplyGradientDescent", name,
                 new ExecuteOpArgs(var, alpha, delta).SetAttributes(new { use_locking }));
+
+        public static Tensor resource_apply_keras_momentum(Tensor var, Tensor accum, Tensor lr, Tensor grad, Tensor momentum, bool use_locking = false, bool use_nesterov = false, string name = null)
+            => tf.Context.ExecuteOp("ResourceApplyKerasMomentum", name,
+                new ExecuteOpArgs(var, accum, lr, grad, momentum).SetAttributes(new { use_locking, use_nesterov }));
     }
 }
diff --git a/src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs b/src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs
index 280694268..affd43a4f 100644
--- a/src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs
+++ b/src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs
@@ -71,7 +71,7 @@ public IOptimizer RMSprop(float learning_rate = 0.001f,
                 Name = name
             });
 
-        public IOptimizer SGD(float learning_rate)
-            => new SGD(learning_rate);
+        public IOptimizer SGD(float learning_rate, float momentum)
+            => new SGD(learning_rate, momentum);
     }
 }
diff --git a/src/TensorFlowNET.Keras/Optimizers/SGD.cs b/src/TensorFlowNET.Keras/Optimizers/SGD.cs
index f97f4b15f..1d9ceb810 100644
--- a/src/TensorFlowNET.Keras/Optimizers/SGD.cs
+++ b/src/TensorFlowNET.Keras/Optimizers/SGD.cs
@@ -22,6 +22,8 @@ public SGD(float learning_rate,
             _set_hyper("decay", decay);
 
             _momentum = momentum > 0;
+            if (momentum < 0 || momentum > 1)
+                throw new ValueError($"momentum must be a number between 0 and 1, got {momentum}.");
 
             _set_hyper("momentum", momentum);
 
@@ -30,6 +32,13 @@ public SGD(float learning_rate,
 #pragma warning restore CS1717 // Assignment made to same variable
         }
 
+        protected override void _create_slots(IVariableV1[] var_list)
+        {
+            if (_momentum)
+                foreach (var var in var_list)
+                    add_slot(var, "momentum");
+        }
+
         protected override void _prepare_local(DeviceDType device_dtype,
             Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
         {
@@ -43,7 +52,15 @@ protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad,
         {
             if (_momentum)
             {
-                throw new NotImplementedException("_resource_apply_dense");
+                var momentum_var = get_slot(var, "momentum");
+                return gen_training_ops.resource_apply_keras_momentum(
+                    var.Handle,
+                    momentum_var.Handle,
+                    _get_hyper("learning_rate", var.dtype),
+                    grad,
+                    _get_hyper("momentum", var.dtype),
+                    use_locking: _use_locking,
+                    use_nesterov: nesterov);
             }
 
             var device_dtype = _apply_state.Keys.FirstOrDefault(x => x.Device == var.Device && x.DType == var.dtype.as_base_dtype());
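
Reviewer note (not part of the patch): a minimal usage sketch of the new surface, assuming the usual TensorFlow.NET entry point (using static Tensorflow.KerasApi;) and that ResourceApplyKerasMomentum follows the TensorFlow op semantics summarized in the comments. Treat it as an illustration of the changed API, not as code from this change.

    using static Tensorflow.KerasApi;

    // After this patch the Keras-style factory takes an explicit momentum
    // argument; the SGD constructor rejects values outside [0, 1], and any
    // momentum > 0 creates a per-variable "momentum" slot and routes dense
    // updates through ResourceApplyKerasMomentum instead of plain gradient
    // descent.
    var optimizer = keras.optimizers.SGD(learning_rate: 0.01f, momentum: 0.9f);

    // Per variable, each step then effectively computes:
    //   accum = momentum * accum - lr * grad
    //   var   = var + accum
    // and, when use_nesterov is set:
    //   var   = var + momentum * accum - lr * grad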