Skip to content

solve RoutingSpeeder path issue #139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Diagnostics;
using System.Text;
using System.Threading.Tasks;
using BotSharp.Abstraction.Conversations.Models;
using BotSharp.Plugin.RoutingSpeeder.Providers;
using BotSharp.Plugin.RoutingSpeeder.Providers.Models;
using Microsoft.AspNetCore.Authorization;
Expand All @@ -24,9 +25,16 @@ public RoutingSpeederController(IServiceProvider service)
// Controller action: resolves the IntentClassifier from the service provider,
// trains it with the posted parameters, and returns the classifier's labels in an OK result.
public IActionResult TrainIntentClassifier(TrainingParams trainingParams)
{
var intentClassifier = _service.GetRequiredService<IntentClassifier>();
// NOTE(review): this is a diff view without +/- markers; the InitClassifer(...) call below
// looks like the removed line and Train(...) like its replacement — confirm against the merged file.
intentClassifier.InitClassifer(trainingParams.Inference);
intentClassifier.Train(trainingParams);
return Ok(intentClassifier.Labels);
}

[HttpPost("/routing-speeder/classifier/inference")]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this API for training or inference?
The function name doesn't match the endpoint.

// Controller action: builds a text embedding from message.Text, asks the classifier
// for a prediction on that embedding, and returns the predicted text in an OK result.
public IActionResult InferenceIntentClassifier([FromBody] DialoguePredictionModel message)
{
var intentClassifier = _service.GetRequiredService<IntentClassifier>();
var vector = intentClassifier.GetTextEmbedding(message.Text);
var predText = intentClassifier.Predict(vector);
return Ok(predText);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,33 @@
using System.Linq;
using Tensorflow.Keras;
using BotSharp.Abstraction.Agents;
using Microsoft.Extensions.Logging;

namespace BotSharp.Plugin.RoutingSpeeder.Providers;

public class IntentClassifier
{
private readonly IServiceProvider _services;
private readonly ILogger _logger;
private KnowledgeBaseSettings _knowledgeBaseSettings;
Model _model;
public Model model => _model;
private bool _isModelReady;
public bool isModelReady => _isModelReady;
private ClassifierSetting _settings;
private bool _inferenceMode = true;
private string[] _labels;
public string[] Labels => GetLabels();
private int _numLabels
{
get
{
return Labels.Length;
}
}
public string[] Labels => _labels == null ? GetLabels() : _labels;

// Constructor: stores the injected service provider, classifier settings,
// knowledge-base settings and logger in the corresponding private fields.
// NOTE(review): two signatures appear because this is a diff view without +/- markers;
// the single-line one looks like the removed version and the multi-line one
// (with ILogger) like the added version — confirm against the merged file.
public IntentClassifier(IServiceProvider services, ClassifierSetting settings, KnowledgeBaseSettings knowledgeBaseSettings)
public IntentClassifier(IServiceProvider services,
ClassifierSetting settings,
KnowledgeBaseSettings knowledgeBaseSettings,
ILogger logger)
{
_services = services;
_settings = settings;
_knowledgeBaseSettings = knowledgeBaseSettings;
_logger = logger;
}

private void Reset()
Expand All @@ -65,15 +65,14 @@ private void Build()
keras.layers.InputLayer((vector.Dimension), name: "Input"),
keras.layers.Dense(256, activation:"relu"),
keras.layers.Dense(256, activation:"relu"),
keras.layers.Dense(_numLabels, activation: keras.activations.Softmax)
keras.layers.Dense(GetLabels().Length, activation: keras.activations.Softmax)
};
_model = keras.Sequential(layers);

#if DEBUG
Console.WriteLine();
_model.summary();
#endif
_isModelReady = true;
}

private void Fit(NDArray x, NDArray y, TrainingParams trainingParams)
Expand All @@ -97,7 +96,7 @@ private void Fit(NDArray x, NDArray y, TrainingParams trainingParams)
earlyStop
};

var weights = LoadWeights(trainingParams.Inference);
var weights = LoadWeights();

_model.fit(x, y,
batch_size: trainingParams.BatchSize,
Expand All @@ -110,24 +109,25 @@ private void Fit(NDArray x, NDArray y, TrainingParams trainingParams)
_isModelReady = true;
}

public string LoadWeights(bool inference = true)
public string LoadWeights()
{
var agentService = _services.CreateScope()
.ServiceProvider
.GetRequiredService<IAgentService>();

var weightsFile = Path.Combine(agentService.GetDataDir(), _settings.MODEL_DIR, $"intent-classifier.h5");
var weightsFile = Path.Combine(agentService.GetDataDir(), _settings.MODEL_DIR, _settings.WEIGHT_FILE_NAME);

if (File.Exists(weightsFile) && inference)
if (File.Exists(weightsFile) && _inferenceMode)
{
_model.load_weights(weightsFile);
_isModelReady = true;
Console.WriteLine($"Successfully load the weights!");
}
else
{
var logInfo = inference ? "No available weights." : "Will implement model training process and write trained weights into local";
Console.WriteLine(logInfo);
var logInfo = _inferenceMode ? "No available weights." : "Will implement model training process and write trained weights into local";
_isModelReady = false;
_logger.LogInformation(logInfo);
}

return weightsFile;
Expand Down Expand Up @@ -159,7 +159,14 @@ public NDArray GetTextEmbedding(string text)

if (!Directory.Exists(rootDirectory))
{
throw new Exception($"No training data found! Please put training data in this path: {rootDirectory}");
Directory.CreateDirectory(rootDirectory);
}

int numFiles = Directory.GetFiles(rootDirectory).Length;

if (numFiles == 0)
{
throw new Exception($"No dialogue data found in {rootDirectory} folder! Please put dialogue data in this path: {rootDirectory}");
}

// Do embedding and store results
Expand Down Expand Up @@ -214,25 +221,32 @@ public string[] GetFiles(string prefix = "")

// Returns the label list, using a label file located at
// <agent data dir>/<MODEL_DIR>/<LABEL_FILE_NAME>.
// NOTE(review): this span is a diff view without +/- markers, so removed and added lines
// are interleaved; the comments below describe the branches that read as the newer
// version — confirm against the merged file before relying on them.
public string[] GetLabels()
{
if (_labels == null)
var agentService = _services.CreateScope()
.ServiceProvider
.GetRequiredService<IAgentService>();
string labelPath = Path.Combine(
agentService.GetDataDir(),
_settings.MODEL_DIR,
_settings.LABEL_FILE_NAME);

// Inference mode: labels are read from the label file only when not already cached.
if (_inferenceMode)
{
var agentService = _services.CreateScope()
.ServiceProvider
.GetRequiredService<IAgentService>();

string[] labels = GetFiles()
if (_labels == null)
{
// A missing label file is reported as an error rather than silently returning nothing.
if (!File.Exists(labelPath))
{
throw new Exception($"Label file doesn't exist. Please training model first or move label.txt to {labelPath}");
}
_labels = File.ReadAllLines(labelPath);
}
}
else
{
// Non-inference mode: each label is the second-to-last dot-separated segment
// of a file name returned by GetFiles(), and the labels are sorted.
_labels = GetFiles()
.Select(x => Path.GetFileName(x).Split(".")[^2])
.OrderBy(x => x)
.ToArray();

string writePath = Path.Combine(
agentService.GetDataDir(),
_settings.MODEL_DIR,
_settings.LABEL_FILE_NAME);

_labels = labels.OrderBy(x => x).ToArray();

// Write labels into the local txt file
File.WriteAllLines(writePath, _labels);
File.WriteAllLines(labelPath, _labels);
}
return _labels;
}
Expand All @@ -248,24 +262,25 @@ public string Predict(NDArray vector, float confidenceScore = 0.9f)
var prob = _model.predict(vector).numpy();
var probLabel = tf.arg_max(prob, -1).numpy().ToArray<long>();
prob = np.squeeze(prob, axis: 0);
var labelIndex = probLabel[0];

if (prob[probLabel[0]] < confidenceScore)
{
return string.Empty;
}

var labelIndex = probLabel[0];
return _labels[labelIndex];
}
// Prepares the classifier by calling Reset(), Build() and then LoadWeights(), in that order.
// NOTE(review): both signatures and both LoadWeights calls appear because this is a diff view
// without +/- markers; the parameterless forms look like the added version — confirm.
public void InitClassifer(bool inference = true)
public void InitClassifer()
{
Reset();
Build();
LoadWeights(inference);
LoadWeights();
}

public void Train(TrainingParams trainingParams)
{
_inferenceMode = false;
Reset();
(var x, var y) = PrepareLoadData();
Build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace BotSharp.Plugin.RoutingSpeeder.Providers.Models;
// Model carrying a dialogue text together with an optional label and an optional prediction.
// NOTE(review): the lower-case and PascalCase property lines are presumably the removed and
// added sides of a rename shown without +/- markers — confirm against the merged file.
public class DialoguePredictionModel
{
public int Id { get; set; }
public string text { get; set; }
public string? label { get; set; }
public string? prediction { get; set; }
// Text is the value the inference controller action passes to GetTextEmbedding.
public string Text { get; set; }
public string? Label { get; set; }
public string? Prediction { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ namespace BotSharp.Plugin.RoutingSpeeder.Settings;
// Settings for the intent classifier: a label-name-to-float mapping plus directory and file names.
// NOTE(review): the dictionary initializer appears twice, presumably the removed and added sides
// of a whitespace-only change shown without +/- markers — confirm against the merged file.
public class ClassifierSetting
{
public Dictionary<string, float> LabelMappingDict { get; set; } = new Dictionary<string, float>()
{
{"goodbye", 0f},
{"greeting", 1f},
{"other", 2f}
};
{
{"goodbye", 0f},
{"greeting", 1f},
{"other", 2f}
};

public string RAW_DATA_DIR { get; set; } = "raw_data";
// MODEL_DIR is combined with the agent data dir and one of the file names below
// to locate the label file and the weights file.
public string MODEL_DIR { get; set; } = "models";
public string LABEL_FILE_NAME { get; set; } = "label.txt";
public string WEIGHT_FILE_NAME { get; set; } = "intent-classifier.h5";
}