Skip to content

Commit b5b0279

Browse files
authored
Merge pull request #360 from hchen2020/master
Support model_id in InstructMode
2 parents 4d99bf7 + 7fb5410 commit b5b0279

File tree

4 files changed

+49
-37
lines changed

4 files changed

+49
-37
lines changed

src/Infrastructure/BotSharp.Abstraction/Models/MessageConfig.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ public class MessageConfig : TruncateMessageRequest
1414
[JsonPropertyName("model")]
1515
public virtual string? Model { get; set; } = null;
1616

17+
/// <summary>
18+
/// Model name
19+
/// </summary>
20+
[JsonPropertyName("model_id")]
21+
public virtual string? ModelId { get; set; } = null;
22+
1723
/// <summary>
1824
/// The sampling temperature to use that controls the apparent creativity of generated completions.
1925
/// </summary>

src/Infrastructure/BotSharp.Core/Infrastructures/CompletionProvider.cs

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
using BotSharp.Abstraction.Agents.Models;
21
using BotSharp.Abstraction.MLTasks;
32
using BotSharp.Abstraction.MLTasks.Settings;
43

@@ -11,31 +10,25 @@ public static object GetCompletion(IServiceProvider services,
1110
string? model = null,
1211
AgentLlmConfig? agentConfig = null)
1312
{
14-
var state = services.GetRequiredService<IConversationStateService>();
15-
var agentSetting = services.GetRequiredService<AgentSettings>();
16-
17-
if (string.IsNullOrEmpty(provider))
18-
{
19-
provider = agentConfig?.Provider ?? agentSetting.LlmConfig?.Provider;
20-
provider = state.GetState("provider", provider ?? "azure-openai");
21-
}
13+
var settingsService = services.GetRequiredService<ILlmProviderService>();
2214

23-
if (string.IsNullOrEmpty(model))
24-
{
25-
model = agentConfig?.Model ?? agentSetting.LlmConfig?.Model;
26-
model = state.GetState("model", model ?? "gpt-35-turbo-4k");
27-
}
15+
(provider, model) = GetProviderAndModel(services, provider: provider, model: model, agentConfig: agentConfig);
2816

29-
var settingsService = services.GetRequiredService<ILlmProviderService>();
3017
var settings = settingsService.GetSetting(provider, model);
3118

3219
if (settings.Type == LlmModelType.Text)
3320
{
34-
return GetTextCompletion(services, provider: provider, model: model);
21+
return GetTextCompletion(services,
22+
provider: provider,
23+
model: model,
24+
agentConfig: agentConfig);
3525
}
3626
else
3727
{
38-
return GetChatCompletion(services, provider: provider, model: model);
28+
return GetChatCompletion(services,
29+
provider: provider,
30+
model: model,
31+
agentConfig: agentConfig);
3932
}
4033
}
4134

@@ -45,20 +38,7 @@ public static IChatCompletion GetChatCompletion(IServiceProvider services,
4538
AgentLlmConfig? agentConfig = null)
4639
{
4740
var completions = services.GetServices<IChatCompletion>();
48-
var agentSetting = services.GetRequiredService<AgentSettings>();
49-
var state = services.GetRequiredService<IConversationStateService>();
50-
51-
if (string.IsNullOrEmpty(provider))
52-
{
53-
provider = agentConfig?.Provider ?? agentSetting.LlmConfig?.Provider;
54-
provider = state.GetState("provider", provider ?? "azure-openai");
55-
}
56-
57-
if (string.IsNullOrEmpty(model))
58-
{
59-
model = agentConfig?.Model ?? agentSetting.LlmConfig?.Model;
60-
model = state.GetState("model", model ?? "gpt-35-turbo-4k");
61-
}
41+
(provider, model) = GetProviderAndModel(services, provider: provider, model: model, agentConfig: agentConfig);
6242

6343
var completer = completions.FirstOrDefault(x => x.Provider == provider);
6444
if (completer == null)
@@ -72,12 +52,11 @@ public static IChatCompletion GetChatCompletion(IServiceProvider services,
7252
return completer;
7353
}
7454

75-
public static ITextCompletion GetTextCompletion(IServiceProvider services,
76-
string? provider = null,
55+
private static (string, string) GetProviderAndModel(IServiceProvider services,
56+
string? provider = null,
7757
string? model = null,
7858
AgentLlmConfig? agentConfig = null)
7959
{
80-
var completions = services.GetServices<ITextCompletion>();
8160
var agentSetting = services.GetRequiredService<AgentSettings>();
8261
var state = services.GetRequiredService<IConversationStateService>();
8362

@@ -90,9 +69,32 @@ public static ITextCompletion GetTextCompletion(IServiceProvider services,
9069
if (string.IsNullOrEmpty(model))
9170
{
9271
model = agentConfig?.Model ?? agentSetting.LlmConfig?.Model;
93-
model = state.GetState("model", model ?? "gpt-35-turbo-instruct");
72+
if (state.ContainsState("model"))
73+
{
74+
model = state.GetState("model", model ?? "gpt-35-turbo-4k");
75+
}
76+
else if (state.ContainsState("model_id"))
77+
{
78+
var modelId = state.GetState("model_id");
79+
var llmProviderService = services.GetRequiredService<ILlmProviderService>();
80+
model = llmProviderService.GetProviderModel(provider, modelId)?.Name;
81+
}
9482
}
9583

84+
state.SetState("provider", provider);
85+
state.SetState("model", model);
86+
87+
return (provider, model);
88+
}
89+
public static ITextCompletion GetTextCompletion(IServiceProvider services,
90+
string? provider = null,
91+
string? model = null,
92+
AgentLlmConfig? agentConfig = null)
93+
{
94+
var completions = services.GetServices<ITextCompletion>();
95+
96+
(provider, model) = GetProviderAndModel(services, provider: provider, model: model, agentConfig: agentConfig);
97+
9698
var completer = completions.FirstOrDefault(x => x.Provider == provider);
9799
if (completer == null)
98100
{

src/Infrastructure/BotSharp.Core/Instructs/InstructService.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ public async Task<InstructResult> Execute(string agentId, RoleDialogModel messag
6868

6969
var completer = CompletionProvider.GetCompletion(_services,
7070
agentConfig: agent.LlmConfig);
71+
7172
var response = new InstructResult
7273
{
7374
MessageId = message.MessageId

src/Infrastructure/BotSharp.OpenAPI/Controllers/InstructModeController.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ public async Task<InstructResult> InstructCompletion([FromRoute] string agentId,
2525
input.States.ForEach(x => state.SetState(x.Split('=')[0], x.Split('=')[1]));
2626
state.SetState("provider", input.Provider)
2727
.SetState("model", input.Model)
28+
.SetState("model_id", input.ModelId)
2829
.SetState("instruction", input.Instruction)
2930
.SetState("input_text", input.Text);
3031

@@ -45,7 +46,8 @@ public async Task<string> TextCompletion([FromBody] IncomingMessageModel input)
4546
var state = _services.GetRequiredService<IConversationStateService>();
4647
input.States.ForEach(x => state.SetState(x.Split('=')[0], x.Split('=')[1]));
4748
state.SetState("provider", input.Provider)
48-
.SetState("model", input.Model);
49+
.SetState("model", input.Model)
50+
.SetState("model_id", input.ModelId);
4951

5052
var textCompletion = CompletionProvider.GetTextCompletion(_services);
5153
return await textCompletion.GetCompletion(input.Text, Guid.Empty.ToString(), Guid.NewGuid().ToString());
@@ -57,7 +59,8 @@ public async Task<string> ChatCompletion([FromBody] IncomingMessageModel input)
5759
var state = _services.GetRequiredService<IConversationStateService>();
5860
input.States.ForEach(x => state.SetState(x.Split('=')[0], x.Split('=')[1]));
5961
state.SetState("provider", input.Provider)
60-
.SetState("model", input.Model);
62+
.SetState("model", input.Model)
63+
.SetState("model_id", input.ModelId);
6164

6265
var textCompletion = CompletionProvider.GetChatCompletion(_services);
6366
var message = await textCompletion.GetChatCompletions(new Agent()

0 commit comments

Comments
 (0)