Skip to content

feat(Gemini): Added Text Embedding Model #969

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
</PropertyGroup>
<ItemGroup>
<PackageVersion Include="EntityFramework" Version="6.4.4" />
<PackageVersion Include="Google_GenerativeAI" Version="2.4.6" />
<PackageVersion Include="LLMSharp.Google.Palm" Version="1.0.2" />
<PackageVersion Include="Microsoft.AspNetCore.Http.Abstractions" Version="$(AspNetCoreVersion)" />
<PackageVersion Include="Microsoft.AspNetCore.StaticFiles" Version="$(AspNetCoreVersion)" />
<PackageVersion Include="Microsoft.Extensions.Configuration" Version="8.0.0" />
Expand Down Expand Up @@ -42,8 +44,6 @@
<PackageVersion Include="Microsoft.Data.Sqlite" Version="8.0.8" />
<PackageVersion Include="MySql.Data" Version="9.0.0" />
<PackageVersion Include="NPOI" Version="2.7.1" />
<PackageVersion Include="LLMSharp.Google.Palm" Version="1.0.2" />
<PackageVersion Include="Mscc.GenerativeAI" Version="2.2.11" />
<PackageVersion Include="Microsoft.AspNetCore.Mvc.Core" Version="2.2.5" />
<PackageVersion Include="Refit" Version="8.0.0" />
<PackageVersion Include="Refit.HttpClientFactory" Version="8.0.0" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
<GenerateDocumentationFile>$(GenerateDocumentationFile)</GenerateDocumentationFile>
<OutputPath>$(SolutionDir)packages</OutputPath>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="LLMSharp.Google.Palm" />
<PackageReference Include="Mscc.GenerativeAI" />
<ProjectReference Include="..\..\Infrastructure\BotSharp.Abstraction\BotSharp.Abstraction.csproj" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\Infrastructure\BotSharp.Abstraction\BotSharp.Abstraction.csproj" />
<PackageReference Include="Google_GenerativeAI" />
<PackageReference Include="LLMSharp.Google.Palm" />
</ItemGroup>

</Project>
2 changes: 2 additions & 0 deletions src/Plugins/BotSharp.Plugin.GoogleAI/GoogleAiPlugin.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using BotSharp.Abstraction.Plugins;
using BotSharp.Abstraction.Settings;
using BotSharp.Plugin.GoogleAi.Providers.Chat;
using BotSharp.Plugin.GoogleAI.Providers.Embedding;
using BotSharp.Plugin.GoogleAi.Providers.Text;

namespace BotSharp.Plugin.GoogleAi;
Expand All @@ -23,5 +24,6 @@ public void RegisterDI(IServiceCollection services, IConfiguration config)
services.AddScoped<ITextCompletion, GeminiTextCompletionProvider>();
services.AddScoped<IChatCompletion, PalmChatCompletionProvider>();
services.AddScoped<IChatCompletion, GeminiChatCompletionProvider>();
services.AddScoped<ITextEmbedding, TextEmbeddingProvider>();
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
using System.Text.Json.Nodes;
using BotSharp.Abstraction.Agents;
using BotSharp.Abstraction.Agents.Enums;
using BotSharp.Abstraction.Conversations;
using BotSharp.Abstraction.Loggers;
using GenerativeAI;
using GenerativeAI.Core;
using GenerativeAI.Types;
using Microsoft.Extensions.Logging;
using Mscc.GenerativeAI;

namespace BotSharp.Plugin.GoogleAi.Providers.Chat;

Expand Down Expand Up @@ -37,24 +40,24 @@ public async Task<RoleDialogModel> GetChatCompletions(Agent agent, List<RoleDial
}

var client = ProviderHelper.GetGeminiClient(Provider, _model, _services);
var aiModel = client.GenerativeModel(_model);
var aiModel = client.CreateGenerativeModel(_model);
var (prompt, request) = PrepareOptions(aiModel, agent, conversations);

var response = await aiModel.GenerateContent(request);
var candidate = response.Candidates.First();
var part = candidate.Content?.Parts?.FirstOrDefault();
var response = await aiModel.GenerateContentAsync(request);
var candidate = response.Candidates?.First();
var part = candidate?.Content?.Parts?.FirstOrDefault();
var text = part?.Text ?? string.Empty;

RoleDialogModel responseMessage;
if (part?.FunctionCall != null)
if (response.GetFunction()!=null)
{
responseMessage = new RoleDialogModel(AgentRole.Function, text)
{
CurrentAgentId = agent.Id,
MessageId = conversations.LastOrDefault()?.MessageId ?? string.Empty,
ToolCallId = part.FunctionCall.Name,
FunctionName = part.FunctionCall.Name,
FunctionArgs = part.FunctionCall.Args?.ToString(),
FunctionArgs = part.FunctionCall.Args?.ToJsonString(),
RenderedInstruction = string.Join("\r\n", renderedInstructions)
};
}
Expand Down Expand Up @@ -82,14 +85,112 @@ public async Task<RoleDialogModel> GetChatCompletions(Agent agent, List<RoleDial
return responseMessage;
}

/// <summary>
/// Runs a non-streaming Gemini chat completion and dispatches the result to the
/// appropriate callback: <paramref name="onFunctionExecuting"/> when the model
/// requested a tool call, otherwise <paramref name="onMessageReceived"/>.
/// Always returns <c>true</c>; failures surface as exceptions from the SDK.
/// </summary>
public async Task<bool> GetChatCompletionsAsync(Agent agent, List<RoleDialogModel> conversations, Func<RoleDialogModel, Task> onMessageReceived, Func<RoleDialogModel, Task> onFunctionExecuting)
{
    var hooks = _services.GetServices<IContentGeneratingHook>().ToList();

    // Before chat completion hook
    foreach (var hook in hooks)
    {
        await hook.BeforeGenerating(agent, conversations);
    }

    var client = ProviderHelper.GetGeminiClient(Provider, _model, _services);
    var chatClient = client.CreateGeminiModel(_model);
    var (prompt, messages) = PrepareOptions(chatClient, agent, conversations);

    var response = await chatClient.GenerateContentAsync(messages);

    // Only the first candidate is consumed; additional candidates are ignored.
    var candidate = response.Candidates?.First();
    var part = candidate?.Content?.Parts?.FirstOrDefault();
    var text = part?.Text ?? string.Empty;

    var msg = new RoleDialogModel(AgentRole.Assistant, text)
    {
        CurrentAgentId = agent.Id,
        RenderedInstruction = string.Join("\r\n", renderedInstructions)
    };

    // After chat completion hook — report token usage even for tool-call turns.
    foreach (var hook in hooks)
    {
        await hook.AfterGenerated(msg, new TokenStatsModel
        {
            Prompt = prompt,
            Provider = Provider,
            Model = _model,
            PromptCount = response?.UsageMetadata?.PromptTokenCount ?? 0,
            CompletionCount = response?.UsageMetadata?.CandidatesTokenCount ?? 0
        });
    }

    // Extract the function call once instead of re-parsing the response twice.
    var toolCall = response.GetFunction();
    if (toolCall != null)
    {
        _logger.LogInformation($"[{agent.Name}]: {toolCall.Name}({toolCall.Args?.ToJsonString()})");

        var funcContextIn = new RoleDialogModel(AgentRole.Function, text)
        {
            CurrentAgentId = agent.Id,
            MessageId = conversations.LastOrDefault()?.MessageId ?? string.Empty,
            ToolCallId = toolCall.Id,
            FunctionName = toolCall.Name,
            FunctionArgs = toolCall.Args?.ToJsonString(),
            RenderedInstruction = string.Join("\r\n", renderedInstructions)
        };

        // Sometimes the LLM generates a function name prefixed with the agent
        // name (e.g. "agent.func"); keep only the last dotted segment.
        if (!string.IsNullOrEmpty(funcContextIn.FunctionName))
        {
            funcContextIn.FunctionName = funcContextIn.FunctionName.Split('.').Last();
        }

        // Execute functions
        await onFunctionExecuting(funcContextIn);
    }
    else
    {
        // Text response received
        await onMessageReceived(msg);
    }

    return true;
}

/// <summary>
/// Streams a Gemini chat completion, invoking <paramref name="onMessageReceived"/>
/// once per non-empty chunk. Function-call chunks are forwarded as assistant
/// messages carrying the serialized arguments; empty text chunks are skipped.
/// NOTE(review): this path uses CreateGenerativeModel while the non-streaming
/// path uses CreateGeminiModel — confirm whether that difference is intentional.
/// </summary>
public async Task<bool> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations, Func<RoleDialogModel, Task> onMessageReceived)
{
    var client = ProviderHelper.GetGeminiClient(Provider, _model, _services);
    var chatClient = client.CreateGenerativeModel(_model);
    var (prompt, messages) = PrepareOptions(chatClient, agent, conversations);

    await foreach (var response in chatClient.StreamContentAsync(messages))
    {
        // Extract the function call once per chunk instead of re-parsing twice.
        var func = response.GetFunction();
        if (func != null)
        {
            var update = func.Args?.ToJsonString() ?? string.Empty;
            _logger.LogInformation(update);

            await onMessageReceived(new RoleDialogModel(AgentRole.Assistant, update)
            {
                RenderedInstruction = string.Join("\r\n", renderedInstructions)
            });
            continue;
        }

        // Skip empty text deltas to avoid emitting blank messages downstream.
        var text = response.Text();
        if (text.IsNullOrEmpty()) continue;

        _logger.LogInformation(text);

        await onMessageReceived(new RoleDialogModel(response.Candidates?.LastOrDefault()?.Content?.Role?.ToString() ?? AgentRole.Assistant.ToString(), text ?? string.Empty)
        {
            RenderedInstruction = string.Join("\r\n", renderedInstructions)
        });
    }

    return true;
}

public void SetModelName(string model)
Expand All @@ -107,6 +208,10 @@ public void SetModelName(string model)
aiModel.UseGoogleSearch = googleSettings.Gemini.UseGoogleSearch;
aiModel.UseGrounding = googleSettings.Gemini.UseGrounding;

aiModel.FunctionCallingBehaviour = new FunctionCallingBehaviour()
{
AutoCallFunction = false
};
// Assembly messages
var contents = new List<Content>();
var tools = new List<Tool>();
Expand All @@ -116,11 +221,7 @@ public void SetModelName(string model)
if (!string.IsNullOrEmpty(agent.Instruction) || !agent.SecondaryInstructions.IsNullOrEmpty())
{
var instruction = agentService.RenderedInstruction(agent);
contents.Add(new Content(instruction)
{
Role = AgentRole.User
});

contents.Add(new Content(instruction, AgentRole.User));
renderedInstructions.Add(instruction);
systemPrompts.Add(instruction);
}
Expand All @@ -135,7 +236,7 @@ public void SetModelName(string model)
var props = JsonSerializer.Serialize(def?.Properties);
var parameters = !string.IsNullOrWhiteSpace(props) && props != "{}" ? new Schema()
{
Type = ParameterType.Object,
Type = "object",
Properties = JsonSerializer.Deserialize<dynamic>(props),
Required = def?.Required ?? []
} : null;
Expand All @@ -160,36 +261,33 @@ public void SetModelName(string model)
{
if (message.Role == AgentRole.Function)
{
contents.Add(new Content(message.Content)
contents.Add( new Content(message.Content,AgentRole.Function)
{
Role = AgentRole.Function,
Parts = new()
{
new FunctionCall
Parts =
[
new Part()
{
Name = message.FunctionName,
Args = JsonSerializer.Deserialize<object>(message.FunctionArgs ?? "{}")
FunctionCall = new FunctionCall
{
Name = message.FunctionName,
Args = JsonNode.Parse(message.FunctionArgs ?? "{}")
}
}
}
]
});

convPrompts.Add($"{AgentRole.Assistant}: Call function {message.FunctionName}({message.FunctionArgs})");
}
else if (message.Role == AgentRole.User)
{
var text = !string.IsNullOrWhiteSpace(message.Payload) ? message.Payload : message.Content;
contents.Add(new Content(text)
{
Role = AgentRole.User
});
contents.Add(new Content(text, AgentRole.User));
convPrompts.Add($"{AgentRole.User}: {text}");
}
else if (message.Role == AgentRole.Assistant)
{
contents.Add(new Content(message.Content)
{
Role = AgentRole.Model
});
contents.Add(new Content(message.Content, AgentRole.Model));
convPrompts.Add($"{AgentRole.Assistant}: {message.Content}");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public class PalmChatCompletionProvider : IChatCompletion

public string Provider => "google-palm";
public string Model => _model;

public PalmChatCompletionProvider(
IServiceProvider services,
ILogger<PalmChatCompletionProvider> logger)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using BotSharp.Plugin.GoogleAi.Providers;
using GenerativeAI;
using GenerativeAI.Types;
using Microsoft.Extensions.Logging;

namespace BotSharp.Plugin.GoogleAI.Providers.Embedding;

/// <summary>
/// Text embedding provider backed by Google's Gemini embedding API
/// (Google_GenerativeAI package). Registered under provider id "google-ai".
/// </summary>
public class TextEmbeddingProvider : ITextEmbedding
{
    protected readonly GoogleAiSettings _settings;
    protected readonly IServiceProvider _services;
    protected readonly ILogger<TextEmbeddingProvider> _logger;

    // Fallback dimension reported when callers pass a non-positive value.
    private const int DEFAULT_DIMENSION = 1536;
    protected string _model = GoogleAIModels.TextEmbedding;
    // NOTE(review): _dimension is tracked for GetDimension() but never forwarded
    // to the embedding request, so the API returns the model's native dimension —
    // confirm whether an output-dimensionality option should be set.
    protected int _dimension = DEFAULT_DIMENSION;

    public virtual string Provider => "google-ai";
    public string Model => _model;

    public TextEmbeddingProvider(
        GoogleAiSettings settings,
        ILogger<TextEmbeddingProvider> logger,
        IServiceProvider services)
    {
        _settings = settings;
        _logger = logger;
        _services = services;
    }

    /// <summary>
    /// Embeds a single text. Returns an empty vector when the API response
    /// carries no embedding values.
    /// </summary>
    public async Task<float[]> GetVectorAsync(string text)
    {
        var client = ProviderHelper.GetGeminiClient(Provider, _model, _services);
        var embeddingClient = client.CreateEmbeddingModel(_model);

        var response = await embeddingClient.EmbedContentAsync(text);
        // Guard the whole chain: the original `value.ToArray()` threw a
        // NullReferenceException when Embedding or Values was null, unlike the
        // guarded batch overload below.
        return response?.Embedding?.Values?.ToArray() ?? [];
    }

    /// <summary>
    /// Embeds a batch of texts in one API call. Missing embeddings map to
    /// empty vectors so the result count matches the input count where possible.
    /// </summary>
    public async Task<List<float[]>> GetVectorsAsync(List<string> texts)
    {
        var client = ProviderHelper.GetGeminiClient(Provider, _model, _services);
        var embeddingClient = client.CreateEmbeddingModel(_model);

        var response = await embeddingClient.BatchEmbedContentAsync(texts.Select(s => new Content(s, Roles.User)));
        var value = response?.Embeddings;
        if (value == null)
            return new List<float[]>();
        return value.Select(x => x.Values?.ToArray() ?? []).ToList();
    }

    public void SetModelName(string model)
    {
        _model = model;
    }

    // Non-positive dimensions fall back to the default rather than being stored.
    public void SetDimension(int dimension)
    {
        _dimension = dimension > 0 ? dimension : DEFAULT_DIMENSION;
    }

    public int GetDimension()
    {
        return _dimension;
    }
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
using LLMSharp.Google.Palm;
using Mscc.GenerativeAI;

namespace BotSharp.Plugin.GoogleAi.Providers;

public static class ProviderHelper
{
/// <summary>
/// Builds a Google_GenerativeAI client for the given provider/model pair,
/// resolving the API key from the central LLM provider settings service.
/// The type is fully qualified to avoid clashing with this plugin's
/// BotSharp.Plugin.GoogleAi namespace.
/// </summary>
public static GenerativeAI.GoogleAi GetGeminiClient(string provider, string model, IServiceProvider services)
{
    var settingsService = services.GetRequiredService<ILlmProviderService>();
    var settings = settingsService.GetSetting(provider, model);
    var client = new GenerativeAI.GoogleAi(settings.ApiKey);
    return client;
}

Expand Down
Loading
Loading