diff --git a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/BotSharp.Plugin.RoutingSpeeder.csproj b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/BotSharp.Plugin.RoutingSpeeder.csproj
index fb145e8df..07118fd55 100644
--- a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/BotSharp.Plugin.RoutingSpeeder.csproj
+++ b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/BotSharp.Plugin.RoutingSpeeder.csproj
@@ -8,6 +8,7 @@
+
diff --git a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Controllers/RoutingSpeederController.cs b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Controllers/RoutingSpeederController.cs
new file mode 100644
index 000000000..67cb0c401
--- /dev/null
+++ b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Controllers/RoutingSpeederController.cs
@@ -0,0 +1,32 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Text;
+using System.Threading.Tasks;
+using BotSharp.Plugin.RoutingSpeeder.Providers;
+using BotSharp.Plugin.RoutingSpeeder.Providers.Models;
+using Microsoft.AspNetCore.Authorization;
+using Microsoft.AspNetCore.Mvc;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace BotSharp.Plugin.RoutingSpeeder.Controllers;
+
+[AllowAnonymous]
+public class RoutingSpeederController : ControllerBase
+{
+ private readonly IServiceProvider _service;
+ public RoutingSpeederController(IServiceProvider service)
+ {
+ _service = service;
+ }
+
+ [HttpPost("/routing-speeder/classifier/train")]
+ public IActionResult TrainIntentClassifier(TrainingParams trainingParams)
+ {
+ var intentClassifier = _service.GetRequiredService();
+ intentClassifier.InitClassifer(trainingParams.Inference);
+ intentClassifier.Train(trainingParams);
+ return Ok(intentClassifier.Labels);
+ }
+
+}
diff --git a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/IntentClassifier.cs b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/IntentClassifier.cs
index 3c3a33f25..440bebf0b 100644
--- a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/IntentClassifier.cs
+++ b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/IntentClassifier.cs
@@ -11,31 +11,46 @@
using System.Text.RegularExpressions;
using BotSharp.Plugin.RoutingSpeeder.Settings;
using BotSharp.Abstraction.MLTasks;
+using BotSharp.Abstraction.Knowledges.Settings;
using BotSharp.Plugin.RoutingSpeeder.Providers.Models;
using Microsoft.Extensions.DependencyInjection;
using System.Linq;
using Tensorflow.Keras;
-using BotSharp.Abstraction.Knowledges.Settings;
using System.Numerics;
using Newtonsoft.Json;
using Tensorflow.Keras.Layers;
using BotSharp.Abstraction.Agents;
+using BotSharp.Abstraction.Knowledges;
namespace BotSharp.Plugin.RoutingSpeeder.Providers;
public class IntentClassifier
{
private readonly IServiceProvider _services;
+ private KnowledgeBaseSettings _knowledgeBaseSettings;
Model _model;
public Model model => _model;
private bool _isModelReady;
public bool isModelReady => _isModelReady;
private ClassifierSetting _settings;
- public IntentClassifier(IServiceProvider services, ClassifierSetting settings)
+ private string[] _labels;
+
+ public string[] Labels => GetLabels();
+
+ private int _numLabels
+ {
+ get
+ {
+ return Labels.Length;
+ }
+ }
+
+ public IntentClassifier(IServiceProvider services, ClassifierSetting settings, KnowledgeBaseSettings knowledgeBaseSettings)
{
_services = services;
_settings = settings;
+ _knowledgeBaseSettings = knowledgeBaseSettings;
}
private void Reset()
@@ -50,17 +65,16 @@ private void Build()
{
return;
}
-
- var vector = _services.GetRequiredService();
- var labels = GetLabels();
+ var vector = _services.GetServices()
+ .FirstOrDefault(x => x.GetType().FullName.EndsWith(_knowledgeBaseSettings.TextEmbedding));
var layers = new List
{
keras.layers.InputLayer((vector.Dimension), name: "Input"),
keras.layers.Dense(256, activation:"relu"),
keras.layers.Dense(256, activation:"relu"),
- keras.layers.Dense(labels.Length, activation: keras.activations.Softmax)
+ keras.layers.Dense(_numLabels, activation: keras.activations.Softmax)
};
_model = keras.Sequential(layers);
@@ -90,7 +104,7 @@ private void Fit(NDArray x, NDArray y, TrainingParams trainingParams)
var callbacks = new List() { earlyStop };
- var weights = LoadWeights();
+ var weights = LoadWeights(trainingParams.Inference);
_model.fit(x, y,
batch_size: trainingParams.BatchSize,
@@ -104,42 +118,27 @@ private void Fit(NDArray x, NDArray y, TrainingParams trainingParams)
_isModelReady = true;
}
- public string LoadWeights()
+ public string LoadWeights(bool inference = true)
{
var agentService = _services.CreateScope().ServiceProvider.GetRequiredService();
var weightsFile = Path.Combine(agentService.GetDataDir(), _settings.MODEL_DIR, $"intent-classifier.h5");
- if (File.Exists(weightsFile))
+
+ if (File.Exists(weightsFile) && inference)
{
_model.load_weights(weightsFile);
_isModelReady = true;
Console.WriteLine($"Successfully load the weights!");
+
}
else
{
- Console.WriteLine("No available weights.");
+ var logInfo = inference ? "No available weights." : "Will implement model training process and write trained weights into local";
+ Console.WriteLine(logInfo);
}
return weightsFile;
}
- public (NDArray x, NDArray y) Vectorize(List items)
- {
- var vector = _services.GetRequiredService();
-
- var x = np.zeros((items.Count, vector.Dimension), dtype: np.float32);
- var y = np.zeros((items.Count, 1), dtype: np.float32);
-
- for (int i = 0; i < items.Count; i++)
- {
- x[i] = vector.GetVector(TextClean(items[i].text));
- if (_settings.LabelMappingDict.ContainsKey(items[i].label))
- {
- y[i] = _settings.LabelMappingDict[items[i].label];
- }
- }
- return (x, y);
- }
-
public NDArray GetTextEmbedding(string text)
{
var knowledgeSettings = _services.GetRequiredService();
@@ -164,10 +163,10 @@ public NDArray GetTextEmbedding(string text)
var vector = _services.GetRequiredService();
-
var vectorList = new List();
var labelList = new List();
+
foreach (var filePath in GetFiles())
{
var texts = File.ReadAllLines(filePath, Encoding.UTF8).Select(x => TextClean(x)).ToList();
@@ -192,19 +191,24 @@ public NDArray GetTextEmbedding(string text)
return (x, y);
}
- public string[] GetFiles()
+ public string[] GetFiles(string prefix = "intent")
{
var agentService = _services.CreateScope().ServiceProvider.GetRequiredService();
string rootDirectory = Path.Combine(agentService.GetDataDir(), _settings.RAW_DATA_DIR);
- return Directory.GetFiles(rootDirectory).OrderBy(x => x).ToArray();
+ return Directory.GetFiles(rootDirectory).Where(x => Path.GetFileNameWithoutExtension(x).StartsWith(prefix)).OrderBy(x => x).ToArray();
}
public string[] GetLabels()
{
- var agentService = _services.CreateScope().ServiceProvider.GetRequiredService();
- string rootDirectory = Path.Combine(agentService.GetDataDir(), _settings.MODEL_DIR, _settings.LABEL_FILE_NAME);
- var labelText = File.ReadAllLines(rootDirectory);
- return labelText.OrderBy(x => x).ToArray();
+ if (_labels == null)
+ {
+ var agentService = _services.CreateScope().ServiceProvider.GetRequiredService();
+ string rootDirectory = Path.Combine(agentService.GetDataDir(), _settings.MODEL_DIR, _settings.LABEL_FILE_NAME);
+ var labelText = File.ReadAllLines(rootDirectory);
+ _labels = labelText.OrderBy(x => x).ToArray();
+ }
+
+ return _labels;
}
public string TextClean(string text)
@@ -235,24 +239,22 @@ public string Predict(NDArray vector, float confidenceScore = 0.9f)
return string.Empty;
}
- var prediction = GetLabels()[probLabel[0]];
+ var prediction = _labels[probLabel[0]];
return prediction;
}
- public void InitClassifer()
+ public void InitClassifer(bool inference = true)
{
Reset();
Build();
- LoadWeights();
+ LoadWeights(inference);
}
- public void Train()
+ public void Train(TrainingParams trainingParams)
{
- var trainingParams = new TrainingParams();
Reset();
(var x, var y) = PrepareLoadData();
Build();
Fit(x, y, trainingParams);
-
}
}
diff --git a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/Models/TrainingParams.cs b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/Models/TrainingParams.cs
index f3c822ac1..4cd9829c9 100644
--- a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/Models/TrainingParams.cs
+++ b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/Models/TrainingParams.cs
@@ -10,4 +10,5 @@ public class TrainingParams
public int Epochs { get; set; } = 10;
public int BatchSize { get; set; } = 16;
public float LearningRate { get; set; } = 1.0e-4f;
+ public bool Inference { get; set; } = false;
}
diff --git a/src/WebStarter/data/models/intent-classifier.h5 b/src/WebStarter/data/models/intent-classifier.h5
index 13f2ebeda..e4be7ee59 100644
Binary files a/src/WebStarter/data/models/intent-classifier.h5 and b/src/WebStarter/data/models/intent-classifier.h5 differ