Skip to content

Add filter when saving dialogue #127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.Mvc.Core" Version="2.2.5" />
<PackageReference Include="TensorFlow.Keras" Version="0.11.2" />
</ItemGroup>

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using System.Threading.Tasks;
using BotSharp.Plugin.RoutingSpeeder.Providers;
using BotSharp.Plugin.RoutingSpeeder.Providers.Models;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.DependencyInjection;

namespace BotSharp.Plugin.RoutingSpeeder.Controllers;

/// <summary>
/// HTTP endpoints of the routing-speeder plugin that operate on the intent classifier.
/// </summary>
// NOTE(review): [AllowAnonymous] leaves the training endpoint reachable without
// authentication — confirm this is intended.
[AllowAnonymous]
public class RoutingSpeederController : ControllerBase
{
    private readonly IServiceProvider _service;

    /// <summary>
    /// Keeps the service provider so the classifier can be resolved when an action runs.
    /// </summary>
    /// <param name="service">Provider used to resolve <see cref="IntentClassifier"/>.</param>
    public RoutingSpeederController(IServiceProvider service) => _service = service;

    /// <summary>
    /// Initializes the intent classifier, trains it with the supplied parameters
    /// and responds with the classifier's labels.
    /// </summary>
    /// <param name="trainingParams">
    /// Training options; its <c>Inference</c> flag is forwarded to the classifier initialization.
    /// </param>
    /// <returns>An OK result carrying <c>Labels</c> of the trained classifier.</returns>
    [HttpPost("/routing-speeder/classifier/train")]
    public IActionResult TrainIntentClassifier(TrainingParams trainingParams)
    {
        var classifier = _service.GetRequiredService<IntentClassifier>();

        // Initialization runs first, then training on the same classifier instance.
        classifier.InitClassifer(trainingParams.Inference);
        classifier.Train(trainingParams);

        return Ok(classifier.Labels);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,46 @@
using System.Text.RegularExpressions;
using BotSharp.Plugin.RoutingSpeeder.Settings;
using BotSharp.Abstraction.MLTasks;
using BotSharp.Abstraction.Knowledges.Settings;
using BotSharp.Plugin.RoutingSpeeder.Providers.Models;
using Microsoft.Extensions.DependencyInjection;
using System.Linq;
using Tensorflow.Keras;
using BotSharp.Abstraction.Knowledges.Settings;
using System.Numerics;
using Newtonsoft.Json;
using Tensorflow.Keras.Layers;
using BotSharp.Abstraction.Agents;
using BotSharp.Abstraction.Knowledges;

namespace BotSharp.Plugin.RoutingSpeeder.Providers;

public class IntentClassifier
{
private readonly IServiceProvider _services;
private KnowledgeBaseSettings _knowledgeBaseSettings;
Model _model;
public Model model => _model;
private bool _isModelReady;
public bool isModelReady => _isModelReady;
private ClassifierSetting _settings;

public IntentClassifier(IServiceProvider services, ClassifierSetting settings)
private string[] _labels;

public string[] Labels => GetLabels();

private int _numLabels
{
get
{
return Labels.Length;
}
}

public IntentClassifier(IServiceProvider services, ClassifierSetting settings, KnowledgeBaseSettings knowledgeBaseSettings)
{
_services = services;
_settings = settings;
_knowledgeBaseSettings = knowledgeBaseSettings;
}

private void Reset()
Expand All @@ -50,17 +65,16 @@ private void Build()
{
return;
}

var vector = _services.GetRequiredService<ITextEmbedding>();

var labels = GetLabels();
var vector = _services.GetServices<ITextEmbedding>()
.FirstOrDefault(x => x.GetType().FullName.EndsWith(_knowledgeBaseSettings.TextEmbedding));

var layers = new List<ILayer>
{
keras.layers.InputLayer((vector.Dimension), name: "Input"),
keras.layers.Dense(256, activation:"relu"),
keras.layers.Dense(256, activation:"relu"),
keras.layers.Dense(labels.Length, activation: keras.activations.Softmax)
keras.layers.Dense(_numLabels, activation: keras.activations.Softmax)
};
_model = keras.Sequential(layers);

Expand Down Expand Up @@ -90,7 +104,7 @@ private void Fit(NDArray x, NDArray y, TrainingParams trainingParams)

var callbacks = new List<ICallback>() { earlyStop };

var weights = LoadWeights();
var weights = LoadWeights(trainingParams.Inference);

_model.fit(x, y,
batch_size: trainingParams.BatchSize,
Expand All @@ -104,42 +118,27 @@ private void Fit(NDArray x, NDArray y, TrainingParams trainingParams)
_isModelReady = true;
}

public string LoadWeights()
public string LoadWeights(bool inference = true)
{
var agentService = _services.CreateScope().ServiceProvider.GetRequiredService<IAgentService>();

var weightsFile = Path.Combine(agentService.GetDataDir(), _settings.MODEL_DIR, $"intent-classifier.h5");
if (File.Exists(weightsFile))

if (File.Exists(weightsFile) && inference)
{
_model.load_weights(weightsFile);
_isModelReady = true;
Console.WriteLine($"Successfully load the weights!");

}
else
{
Console.WriteLine("No available weights.");
var logInfo = inference ? "No available weights." : "Will implement model training process and write trained weights into local";
Console.WriteLine(logInfo);
}
return weightsFile;
}

/// <summary>
/// Turns labelled dialogue items into a feature matrix and a label column.
/// </summary>
/// <param name="items">Items whose <c>text</c> is embedded and whose <c>label</c> is mapped.</param>
/// <returns>
/// <c>x</c>: one row per item, filled with the embedding of the cleaned text;
/// <c>y</c>: one row per item, holding the value that <c>LabelMappingDict</c> maps the label to.
/// </returns>
public (NDArray x, NDArray y) Vectorize(List<DialoguePredictionModel> items)
{
    var embedding = _services.GetRequiredService<ITextEmbedding>();

    var x = np.zeros((items.Count, embedding.Dimension), dtype: np.float32);
    var y = np.zeros((items.Count, 1), dtype: np.float32);

    for (int i = 0; i < items.Count; i++)
    {
        var item = items[i];
        x[i] = embedding.GetVector(TextClean(item.text));

        // One lookup instead of ContainsKey + indexer. Items whose label has no
        // mapping keep the zero that y was initialized with.
        // NOTE(review): such rows are indistinguishable from a label mapped to 0 — confirm this is acceptable.
        if (_settings.LabelMappingDict.TryGetValue(item.label, out var labelIndex))
        {
            y[i] = labelIndex;
        }
    }

    return (x, y);
}

public NDArray GetTextEmbedding(string text)
{
var knowledgeSettings = _services.GetRequiredService<KnowledgeBaseSettings>();
Expand All @@ -164,10 +163,10 @@ public NDArray GetTextEmbedding(string text)

var vector = _services.GetRequiredService<ITextEmbedding>();


var vectorList = new List<float[]>();

var labelList = new List<string>();

foreach (var filePath in GetFiles())
{
var texts = File.ReadAllLines(filePath, Encoding.UTF8).Select(x => TextClean(x)).ToList();
Expand All @@ -192,19 +191,24 @@ public NDArray GetTextEmbedding(string text)
return (x, y);
}

public string[] GetFiles()
public string[] GetFiles(string prefix = "intent")
{
var agentService = _services.CreateScope().ServiceProvider.GetRequiredService<IAgentService>();
string rootDirectory = Path.Combine(agentService.GetDataDir(), _settings.RAW_DATA_DIR);
return Directory.GetFiles(rootDirectory).OrderBy(x => x).ToArray();
return Directory.GetFiles(rootDirectory).Where(x => Path.GetFileNameWithoutExtension(x).StartsWith(prefix)).OrderBy(x => x).ToArray();
}

public string[] GetLabels()
{
var agentService = _services.CreateScope().ServiceProvider.GetRequiredService<IAgentService>();
string rootDirectory = Path.Combine(agentService.GetDataDir(), _settings.MODEL_DIR, _settings.LABEL_FILE_NAME);
var labelText = File.ReadAllLines(rootDirectory);
return labelText.OrderBy(x => x).ToArray();
if (_labels == null)
{
var agentService = _services.CreateScope().ServiceProvider.GetRequiredService<IAgentService>();
string rootDirectory = Path.Combine(agentService.GetDataDir(), _settings.MODEL_DIR, _settings.LABEL_FILE_NAME);
var labelText = File.ReadAllLines(rootDirectory);
_labels = labelText.OrderBy(x => x).ToArray();
}

return _labels;
}

public string TextClean(string text)
Expand Down Expand Up @@ -235,24 +239,22 @@ public string Predict(NDArray vector, float confidenceScore = 0.9f)
return string.Empty;
}

var prediction = GetLabels()[probLabel[0]];
var prediction = _labels[probLabel[0]];

return prediction;
}
public void InitClassifer()
public void InitClassifer(bool inference = true)
{
Reset();
Build();
LoadWeights();
LoadWeights(inference);
}

public void Train()
public void Train(TrainingParams trainingParams)
{
var trainingParams = new TrainingParams();
Reset();
(var x, var y) = PrepareLoadData();
Build();
Fit(x, y, trainingParams);

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ public class TrainingParams
public int Epochs { get; set; } = 10;
public int BatchSize { get; set; } = 16;
public float LearningRate { get; set; } = 1.0e-4f;
public bool Inference { get; set; } = false;
}
Binary file modified src/WebStarter/data/models/intent-classifier.h5
Binary file not shown.