diff --git a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/BotSharp.Plugin.RoutingSpeeder.csproj b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/BotSharp.Plugin.RoutingSpeeder.csproj index fb145e8df..07118fd55 100644 --- a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/BotSharp.Plugin.RoutingSpeeder.csproj +++ b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/BotSharp.Plugin.RoutingSpeeder.csproj @@ -8,6 +8,7 @@ + diff --git a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Controllers/RoutingSpeederController.cs b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Controllers/RoutingSpeederController.cs new file mode 100644 index 000000000..67cb0c401 --- /dev/null +++ b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Controllers/RoutingSpeederController.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using System.Threading.Tasks; +using BotSharp.Plugin.RoutingSpeeder.Providers; +using BotSharp.Plugin.RoutingSpeeder.Providers.Models; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Extensions.DependencyInjection; + +namespace BotSharp.Plugin.RoutingSpeeder.Controllers; + +[AllowAnonymous] +public class RoutingSpeederController : ControllerBase +{ + private readonly IServiceProvider _service; + public RoutingSpeederController(IServiceProvider service) + { + _service = service; + } + + [HttpPost("/routing-speeder/classifier/train")] + public IActionResult TrainIntentClassifier(TrainingParams trainingParams) + { + var intentClassifier = _service.GetRequiredService(); + intentClassifier.InitClassifer(trainingParams.Inference); + intentClassifier.Train(trainingParams); + return Ok(intentClassifier.Labels); + } + +} diff --git a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/IntentClassifier.cs b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/IntentClassifier.cs index 3c3a33f25..440bebf0b 100644 --- a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/IntentClassifier.cs +++ b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/IntentClassifier.cs @@ -11,31 +11,46 @@ using System.Text.RegularExpressions; using BotSharp.Plugin.RoutingSpeeder.Settings; using BotSharp.Abstraction.MLTasks; +using BotSharp.Abstraction.Knowledges.Settings; using BotSharp.Plugin.RoutingSpeeder.Providers.Models; using Microsoft.Extensions.DependencyInjection; using System.Linq; using Tensorflow.Keras; -using BotSharp.Abstraction.Knowledges.Settings; using System.Numerics; using Newtonsoft.Json; using Tensorflow.Keras.Layers; using BotSharp.Abstraction.Agents; +using BotSharp.Abstraction.Knowledges; namespace BotSharp.Plugin.RoutingSpeeder.Providers; public class IntentClassifier { private readonly IServiceProvider _services; + private KnowledgeBaseSettings _knowledgeBaseSettings; Model _model; public Model model => _model; private bool _isModelReady; public bool isModelReady => _isModelReady; private ClassifierSetting _settings; - public IntentClassifier(IServiceProvider services, ClassifierSetting settings) + private string[] _labels; + + public string[] Labels => GetLabels(); + + private int _numLabels + { + get + { + return Labels.Length; + } + } + + public IntentClassifier(IServiceProvider services, ClassifierSetting settings, KnowledgeBaseSettings knowledgeBaseSettings) { _services = services; _settings = settings; + _knowledgeBaseSettings = knowledgeBaseSettings; } private void Reset() @@ -50,17 +65,16 @@ private void Build() { return; } - - var vector = _services.GetRequiredService(); - var labels = GetLabels(); + var vector = _services.GetServices() + .FirstOrDefault(x => x.GetType().FullName.EndsWith(_knowledgeBaseSettings.TextEmbedding)); var layers = new List { keras.layers.InputLayer((vector.Dimension), name: "Input"), keras.layers.Dense(256, activation:"relu"), keras.layers.Dense(256, activation:"relu"), - keras.layers.Dense(labels.Length, activation: keras.activations.Softmax) + keras.layers.Dense(_numLabels, activation: keras.activations.Softmax) }; _model = keras.Sequential(layers); @@ -90,7 +104,7 @@ private void Fit(NDArray x, NDArray y, TrainingParams trainingParams) var callbacks = new List() { earlyStop }; - var weights = LoadWeights(); + var weights = LoadWeights(trainingParams.Inference); _model.fit(x, y, batch_size: trainingParams.BatchSize, @@ -104,42 +118,27 @@ private void Fit(NDArray x, NDArray y, TrainingParams trainingParams) _isModelReady = true; } - public string LoadWeights() + public string LoadWeights(bool inference = true) { var agentService = _services.CreateScope().ServiceProvider.GetRequiredService(); var weightsFile = Path.Combine(agentService.GetDataDir(), _settings.MODEL_DIR, $"intent-classifier.h5"); - if (File.Exists(weightsFile)) + + if (File.Exists(weightsFile) && inference) { _model.load_weights(weightsFile); _isModelReady = true; Console.WriteLine($"Successfully load the weights!"); + } else { - Console.WriteLine("No available weights."); + var logInfo = inference ? "No available weights." : "Will implement model training process and write trained weights into local"; + Console.WriteLine(logInfo); } return weightsFile; } - public (NDArray x, NDArray y) Vectorize(List items) - { - var vector = _services.GetRequiredService(); - - var x = np.zeros((items.Count, vector.Dimension), dtype: np.float32); - var y = np.zeros((items.Count, 1), dtype: np.float32); - - for (int i = 0; i < items.Count; i++) - { - x[i] = vector.GetVector(TextClean(items[i].text)); - if (_settings.LabelMappingDict.ContainsKey(items[i].label)) - { - y[i] = _settings.LabelMappingDict[items[i].label]; - } - } - return (x, y); - } - public NDArray GetTextEmbedding(string text) { var knowledgeSettings = _services.GetRequiredService(); @@ -164,10 +163,10 @@ public NDArray GetTextEmbedding(string text) var vector = _services.GetRequiredService(); - var vectorList = new List(); var labelList = new List(); + foreach (var filePath in GetFiles()) { var texts = File.ReadAllLines(filePath, Encoding.UTF8).Select(x => TextClean(x)).ToList(); @@ -192,19 +191,24 @@ public NDArray GetTextEmbedding(string text) return (x, y); } - public string[] GetFiles() + public string[] GetFiles(string prefix = "intent") { var agentService = _services.CreateScope().ServiceProvider.GetRequiredService(); string rootDirectory = Path.Combine(agentService.GetDataDir(), _settings.RAW_DATA_DIR); - return Directory.GetFiles(rootDirectory).OrderBy(x => x).ToArray(); + return Directory.GetFiles(rootDirectory).Where(x => Path.GetFileNameWithoutExtension(x).StartsWith(prefix)).OrderBy(x => x).ToArray(); } public string[] GetLabels() { - var agentService = _services.CreateScope().ServiceProvider.GetRequiredService(); - string rootDirectory = Path.Combine(agentService.GetDataDir(), _settings.MODEL_DIR, _settings.LABEL_FILE_NAME); - var labelText = File.ReadAllLines(rootDirectory); - return labelText.OrderBy(x => x).ToArray(); + if (_labels == null) + { + var agentService = _services.CreateScope().ServiceProvider.GetRequiredService(); + string rootDirectory = Path.Combine(agentService.GetDataDir(), _settings.MODEL_DIR, _settings.LABEL_FILE_NAME); + var labelText = File.ReadAllLines(rootDirectory); + _labels = labelText.OrderBy(x => x).ToArray(); + } + + return _labels; } public string TextClean(string text) @@ -235,24 +239,22 @@ public string Predict(NDArray vector, float confidenceScore = 0.9f) return string.Empty; } - var prediction = GetLabels()[probLabel[0]]; + var prediction = _labels[probLabel[0]]; return prediction; } - public void InitClassifer() + public void InitClassifer(bool inference = true) { Reset(); Build(); - LoadWeights(); + LoadWeights(inference); } - public void Train() + public void Train(TrainingParams trainingParams) { - var trainingParams = new TrainingParams(); Reset(); (var x, var y) = PrepareLoadData(); Build(); Fit(x, y, trainingParams); - } } diff --git a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/Models/TrainingParams.cs b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/Models/TrainingParams.cs index f3c822ac1..4cd9829c9 100644 --- a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/Models/TrainingParams.cs +++ b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/Models/TrainingParams.cs @@ -10,4 +10,5 @@ public class TrainingParams public int Epochs { get; set; } = 10; public int BatchSize { get; set; } = 16; public float LearningRate { get; set; } = 1.0e-4f; + public bool Inference { get; set; } = false; } diff --git a/src/WebStarter/data/models/intent-classifier.h5 b/src/WebStarter/data/models/intent-classifier.h5 index 13f2ebeda..e4be7ee59 100644 Binary files a/src/WebStarter/data/models/intent-classifier.h5 and b/src/WebStarter/data/models/intent-classifier.h5 differ