Skip to content

Add intent classifier in routing speeder. #121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
namespace BotSharp.Core.Plugins.Knowledges;
namespace BotSharp.Abstraction.Knowledges.Settings;

public class KnowledgeBaseSettings
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using BotSharp.Abstraction.Knowledges.Settings;
using BotSharp.Core.Plugins.Knowledges.Services;
using Microsoft.Extensions.Configuration;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using BotSharp.Abstraction.Knowledges.Models;
using BotSharp.Abstraction.Knowledges.Settings;
using BotSharp.Abstraction.MLTasks;
using BotSharp.Abstraction.VectorStorage;
using System.Text.Json;

namespace BotSharp.Core.Plugins.Knowledges.Services;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ public async Task<string> RenderIntentResponse(string agentId, RoleDialogModel m
// Find response template
var agentService = _services.GetRequiredService<IAgentService>();
var dir = Path.Combine(agentService.GetAgentDataDir(agentId), "responses");
if (!Directory.Exists(dir))
{
return string.Empty;
}
var responses = Directory.GetFiles(dir)
.Where(f => f.Split(Path.DirectorySeparatorChar).Last().Split('.')[1] == message.IntentName)
.ToList();
Expand All @@ -62,8 +66,15 @@ public async Task<string> RenderIntentResponse(string agentId, RoleDialogModel m

// Convert args and execute data to dictionary
var dict = new Dictionary<string, object>();
ExtractArgs(JsonSerializer.Deserialize<JsonDocument>(message.FunctionArgs), dict);
ExtractExecuteData(message.ExecutionData, dict);
if (!string.IsNullOrEmpty(message.FunctionArgs))
{
ExtractArgs(JsonSerializer.Deserialize<JsonDocument>(message.FunctionArgs), dict);
}

if (message.ExecutionData != null)
{
ExtractExecuteData(message.ExecutionData, dict);
}

var text = render.Render(template, dict);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
using UglyToad.PdfPig.Content;
using UglyToad.PdfPig;
using BotSharp.Core.Plugins.Knowledges;

using BotSharp.Abstraction.Knowledges.Settings;

namespace BotSharp.OpenAPI.Controllers;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
<VersionPrefix>0.11.0</VersionPrefix>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="TensorFlow.Keras" Version="0.11.2" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\Infrastructure\BotSharp.Abstraction\BotSharp.Abstraction.csproj" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
using System;
using System.IO;
using System.Text;
using System.Collections.Generic;
using Tensorflow;
using static Tensorflow.KerasApi;
using Tensorflow.Keras.Engine;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using Tensorflow.Keras.Callbacks;
using System.Text.RegularExpressions;
using BotSharp.Plugin.RoutingSpeeder.Settings;
using BotSharp.Abstraction.MLTasks;
using BotSharp.Plugin.RoutingSpeeder.Providers.Models;
using Microsoft.Extensions.DependencyInjection;
using System.Linq;
using Tensorflow.Keras;
using BotSharp.Abstraction.Knowledges.Settings;
using System.Numerics;
using Newtonsoft.Json;
using Tensorflow.Keras.Layers;
using BotSharp.Abstraction.Agents;

namespace BotSharp.Plugin.RoutingSpeeder.Providers;

public class IntentClassifier
{
private readonly IServiceProvider _services;
Model _model;
public Model model => _model;
private bool _isModelReady;
public bool isModelReady => _isModelReady;
private ClassifierSetting _settings;

public IntentClassifier(IServiceProvider services, ClassifierSetting settings)
{
_services = services;
_settings = settings;
}

private void Reset()
{
keras.backend.clear_session();
_isModelReady = false;
}

private void Build()
{
if (_isModelReady)
{
return;
}

var vector = _services.GetRequiredService<ITextEmbedding>();

var layers = new List<ILayer>
{
keras.layers.InputLayer((vector.Dimension), name: "Input"),
keras.layers.Dense(256, activation:"relu"),
keras.layers.Dense(256, activation:"relu"),
keras.layers.Dense(GetLabels().Length, activation: keras.activations.Softmax)
};
_model = keras.Sequential(layers);

#if DEBUG
Console.WriteLine();
_model.summary();
#endif
_isModelReady = true;
}

private void Fit(NDArray x, NDArray y, TrainingParams trainingParams)
{
_model.compile(optimizer: keras.optimizers.Adam(trainingParams.LearningRate),
loss: keras.losses.SparseCategoricalCrossentropy(),
metrics: new[] { "accuracy" }
);

CallbackParams callback_parameters = new CallbackParams
{
Model = _model,
Epochs = trainingParams.Epochs,
Verbose = 1,
Steps = 10
};

ICallback earlyStop = new EarlyStopping(callback_parameters, "accuracy");

var callbacks = new List<ICallback>() { earlyStop };

var weights = LoadWeights();

_model.fit(x, y,
batch_size: trainingParams.BatchSize,
epochs: trainingParams.Epochs,
callbacks: callbacks,
// validation_split: 0.1f,
shuffle: true);

_model.save_weights(weights);

_isModelReady = true;
}

public string LoadWeights()
{
var agentService = _services.CreateScope().ServiceProvider.GetRequiredService<IAgentService>();

var weightsFile = Path.Combine(agentService.GetDataDir(), _settings.MODEL_DIR, $"intent-classifier.h5");
if (File.Exists(weightsFile))
{
_model.load_weights(weightsFile);
_isModelReady = true;
Console.WriteLine($"Successfully load the weights!");
}
else
{
Console.WriteLine("No available weights.");
}
return weightsFile;
}

public (NDArray x, NDArray y) Vectorize(List<DialoguePredictionModel> items)
{
var vector = _services.GetRequiredService<ITextEmbedding>();

var x = np.zeros((items.Count, vector.Dimension), dtype: np.float32);
var y = np.zeros((items.Count, 1), dtype: np.float32);

for (int i = 0; i < items.Count; i++)
{
x[i] = vector.GetVector(TextClean(items[i].text));
if (_settings.LabelMappingDict.ContainsKey(items[i].label))
{
y[i] = _settings.LabelMappingDict[items[i].label];
}
}
return (x, y);
}

public NDArray GetTextEmbedding(string text)
{
var knowledgeSettings = _services.GetRequiredService<KnowledgeBaseSettings>();
var embedding = _services.GetServices<ITextEmbedding>()
.FirstOrDefault(x => x.GetType().FullName.EndsWith(knowledgeSettings.TextEmbedding));

var x = np.zeros((1, embedding.Dimension), dtype: np.float32);
x[0] = embedding.GetVector(text);
return x;
}

public (NDArray, NDArray) PrepareLoadData()
{
var agentService = _services.CreateScope().ServiceProvider.GetRequiredService<IAgentService>();
string rootDirectory = Path.Combine(agentService.GetDataDir(), _settings.RAW_DATA_DIR);


if (!Directory.Exists(rootDirectory))
{
throw new Exception($"No training data found! Please put training data in this path: {rootDirectory}");
}

var vector = _services.GetRequiredService<ITextEmbedding>();


var vectorList = new List<float[]>();

var labelList = new List<string>();
foreach (var filePath in GetFiles())
{
var texts = File.ReadAllLines(filePath, Encoding.UTF8).Select(x => TextClean(x)).ToList();
vectorList.AddRange(vector.GetVectors(texts));
string fileName = Path.GetFileNameWithoutExtension(filePath);
labelList.AddRange(Enumerable.Repeat(fileName, texts.Count).ToList());
}

var uniqueLabelList = labelList.Distinct().ToList();

var x = np.zeros((vectorList.Count, vector.Dimension), dtype: np.float32);
var y = np.zeros((vectorList.Count, 1), dtype: np.float32);

for (int i = 0; i < vectorList.Count; i++)
{
x[i] = vectorList[i];
y[i] = (float)uniqueLabelList.IndexOf(labelList[i]);
}
return (x, y);
}

public string[] GetFiles()
{
var agentService = _services.CreateScope().ServiceProvider.GetRequiredService<IAgentService>();
string rootDirectory = Path.Combine(agentService.GetDataDir(), _settings.RAW_DATA_DIR);
return Directory.GetFiles(rootDirectory).OrderBy(x => x).ToArray();
}

public string[] GetLabels()
{
return GetFiles().Select(x => Path.GetFileNameWithoutExtension(x)).ToArray();
}

public string TextClean(string text)
{
// Remove punctuation
// Remove digits
// To lowercase
var processedText = Regex.Replace(text, "[AB0-9]", " ");
processedText = string.Join("", processedText.Select(c => char.IsPunctuation(c) ? ' ' : c).ToList());
processedText = processedText.Replace(" ", " ").ToLower();
return processedText;
}

public string Predict(NDArray vector)
{
if (!_isModelReady)
{
InitClassifer();
}

var prob = _model.predict(vector);
var probLabel = tf.arg_max(prob, -1).numpy();
// var prediction = _settings.LabelMappingDict.First(x => x.Value == probLabel[0]).Key;

var prediction = GetLabels()[probLabel[0]];
// var prediction = GetLabels().Where((x, i) => i == probLabel[0]).First();

return prediction;
}
public void InitClassifer()
{
Reset();
Build();
LoadWeights();
}

public void Train()
{
var trainingParams = new TrainingParams();
Reset();
Build();
(var x, var y) = PrepareLoadData();
Fit(x, y, trainingParams);

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace BotSharp.Plugin.RoutingSpeeder.Providers.Models;

public class DialoguePredictionModel
{
public int Id { get; set; }
public string text { get; set; }
public string? label { get; set; }
public string? prediction { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace BotSharp.Plugin.RoutingSpeeder.Providers.Models;

public class TrainingParams
{
public int ClientId { get; set; }
public int Epochs { get; set; } = 10;
public int BatchSize { get; set; } = 16;
public float LearningRate { get; set; } = 1.0e-4f;
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,37 @@
using BotSharp.Abstraction.Agents.Models;
using BotSharp.Abstraction.Conversations;
using BotSharp.Abstraction.Conversations.Models;
using BotSharp.Abstraction.Templating;
using BotSharp.Abstraction.MLTasks;
using Microsoft.Extensions.DependencyInjection;
using System;
using System.Linq;
using System.Threading.Tasks;
using BotSharp.Plugin.RoutingSpeeder.Settings;
using BotSharp.Abstraction.Templating;
using BotSharp.Plugin.RoutingSpeeder.Providers;
using System.Runtime.InteropServices;

namespace BotSharp.Plugin.RoutingSpeeder;

public class RoutingConversationHook: ConversationHookBase
{
private readonly IServiceProvider _services;
public RoutingConversationHook(IServiceProvider services)
private RouterSpeederSettings _settings;
public RoutingConversationHook(IServiceProvider service, RouterSpeederSettings settings)
{
_services = services;
_services = service;
_settings = settings;
}

public override async Task BeforeCompletion(RoleDialogModel message)
{
var intentClassifier = _services.GetRequiredService<IntentClassifier>();
var vector = intentClassifier.GetTextEmbedding(message.Content);

// intentClassifier.Train();
// Utilize local discriminative model to predict intent
message.IntentName = "greeting";
var predText = intentClassifier.Predict(vector);

message.IntentName = predText;

// Render by template
var templateService = _services.GetRequiredService<IResponseTemplateService>();
Expand Down
Loading