Skip to content

refine state #365

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
namespace BotSharp.Abstraction.Conversations.Models;

public class ConversationState : Dictionary<string, List<StateValue>>
public class ConversationState : Dictionary<string, StateKeyValue>
{
public ConversationState()
{
Expand All @@ -11,7 +11,7 @@ public ConversationState(List<StateKeyValue> pairs)
{
foreach (var pair in pairs)
{
this[pair.Key] = pair.Values;
this[pair.Key] = pair;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ namespace BotSharp.Abstraction.Conversations.Models;
public class StateKeyValue
{
public string Key { get; set; }
public bool Versioning { get; set; }
public List<StateValue> Values { get; set; } = new List<StateValue>();

public StateKeyValue()
Expand All @@ -20,6 +21,13 @@ public StateKeyValue(string key, List<StateValue> values)
public class StateValue
{
public string Data { get; set; }

[JsonPropertyName("message_id")]
public string MessageId { get; set; }

public bool Active { get; set; }

[JsonPropertyName("update_time")]
public DateTime UpdateTime { get; set; }

public StateValue()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public async Task<bool> SendMessage(string agentId,
#endif

message.CurrentAgentId = agent.Id;
message.CreatedAt = DateTime.UtcNow;
if (string.IsNullOrEmpty(message.SenderId))
{
message.SenderId = _user.Id;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using BotSharp.Abstraction.Conversations.Models;

namespace BotSharp.Core.Conversations.Services;

/// <summary>
Expand Down Expand Up @@ -42,32 +44,42 @@ public IConversationStateService SetState<T>(string name, T value, bool isNeedVe
var currentValue = value.ToString();
var hooks = _services.GetServices<IConversationHook>();

if (_states.TryGetValue(name, out var values))
if (ContainsState(name) && _states.TryGetValue(name, out var pair))
{
preValue = values?.LastOrDefault()?.Data ?? string.Empty;
preValue = pair?.Values.LastOrDefault()?.Data ?? string.Empty;
}

if (!_states.ContainsKey(name) || preValue != currentValue)
if (!ContainsState(name) || preValue != currentValue)
{
_logger.LogInformation($"[STATE] {name} = {value}");
foreach (var hook in hooks)
{
hook.OnStateChanged(name, preValue, currentValue).Wait();
}

var stateValue = new StateValue
var routingCtx = _services.GetRequiredService<IRoutingContext>();
var newPair = new StateKeyValue
{
Key = name,
Versioning = isNeedVersion
};

var newValue = new StateValue
{
Data = currentValue,
UpdateTime = DateTime.UtcNow
MessageId = routingCtx.MessageId,
Active = true,
UpdateTime = DateTime.UtcNow,
};

if (!_states.ContainsKey(name) || !isNeedVersion)
if (!isNeedVersion || !_states.ContainsKey(name))
{
_states[name] = new List<StateValue> { stateValue };
newPair.Values = new List<StateValue> { newValue };
_states[name] = newPair;
}
else
{
_states[name].Add(stateValue);
_states[name].Values.Add(newValue);
}
}

Expand All @@ -85,9 +97,12 @@ public Dictionary<string, string> Load(string conversationId)
{
foreach (var state in _states)
{
var value = state.Value?.LastOrDefault()?.Data ?? string.Empty;
curStates[state.Key] = value;
_logger.LogInformation($"[STATE] {state.Key} : {value}");
var value = state.Value?.Values?.LastOrDefault();
if (value == null || !value.Active) continue;

var data = value.Data ?? string.Empty;
curStates[state.Key] = data;
_logger.LogInformation($"[STATE] {state.Key} : {data}");
}
}

Expand All @@ -112,7 +127,7 @@ public void Save()

foreach (var dic in _states)
{
states.Add(new StateKeyValue(dic.Key, dic.Value));
states.Add(dic.Value);
}

_db.UpdateConversationStates(_conversationId, states);
Expand All @@ -121,27 +136,46 @@ public void Save()

public void CleanStates()
{
_states.Clear();
var utcNow = DateTime.UtcNow;
foreach (var key in _states.Keys)
{
var value = _states[key];
if (value == null || !value.Versioning || value.Values.IsNullOrEmpty()) continue;

var lastValue = value.Values.LastOrDefault();
if (lastValue == null || !lastValue.Active) continue;

value.Values.Add(new StateValue
{
Data = lastValue.Data,
MessageId = lastValue.MessageId,
Active = false,
UpdateTime = utcNow
});
}
}

public Dictionary<string, string> GetStates()
{
var curStates = new Dictionary<string, string>();
foreach (var state in _states)
{
curStates[state.Key] = state.Value?.LastOrDefault()?.Data ?? string.Empty;
var value = state.Value?.Values?.LastOrDefault();
if (value == null || !value.Active) continue;

curStates[state.Key] = value.Data ?? string.Empty;
}
return curStates;
}

public string GetState(string name, string defaultValue = "")
{
if (!_states.ContainsKey(name) || _states[name].IsNullOrEmpty())
if (!_states.ContainsKey(name) || _states[name].Values.IsNullOrEmpty() || !_states[name].Values.Last().Active)
{
return defaultValue;
}

return _states[name].Last().Data;
return _states[name].Values.Last().Data;
}

public void Dispose()
Expand All @@ -152,8 +186,9 @@ public void Dispose()
public bool ContainsState(string name)
{
return _states.ContainsKey(name)
&& !_states[name].IsNullOrEmpty()
&& !string.IsNullOrEmpty(_states[name].Last().Data);
&& !_states[name].Values.IsNullOrEmpty()
&& _states[name].Values.LastOrDefault()?.Active == true
&& !string.IsNullOrEmpty(_states[name].Values.Last().Data);
}

public void SaveStateByArgs(JsonDocument args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public void Append(string conversationId, RoleDialogModel dialog)
AgentId = agentId,
MessageId = dialog.MessageId,
FunctionName = dialog.FunctionName,
CreateTime = DateTime.UtcNow
CreateTime = dialog.CreatedAt
};

var content = dialog.Content.RemoveNewLine();
Expand All @@ -65,7 +65,7 @@ public void Append(string conversationId, RoleDialogModel dialog)
MessageId = dialog.MessageId,
SenderId = dialog.SenderId,
FunctionName = dialog.FunctionName,
CreateTime = DateTime.UtcNow
CreateTime = dialog.CreatedAt
};

var content = dialog.Content.RemoveNewLine();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ public bool TruncateConversation(string conversationId, string messageId, bool c
var refTime = dialogs.ElementAt(foundIdx).MetaData.CreateTime;
var stateDir = Path.Combine(convDir, STATE_FILE);
var states = CollectConversationStates(stateDir);
isSaved = HandleTruncatedStates(stateDir, states, refTime);
isSaved = HandleTruncatedStates(stateDir, states, messageId, refTime);

// Handle truncated breakpoints
var breakpointDir = Path.Combine(convDir, BREAKPOINT_FILE);
Expand Down Expand Up @@ -597,12 +597,20 @@ private bool HandleTruncatedDialogs(string convDir, string dialogDir, List<Dialo
return isSaved;
}

private bool HandleTruncatedStates(string stateDir, List<StateKeyValue> states, DateTime refTime)
private bool HandleTruncatedStates(string stateDir, List<StateKeyValue> states, string refMsgId, DateTime refTime)
{
var truncatedStates = new List<StateKeyValue>();
foreach (var state in states)
{
var values = state.Values.Where(x => x.UpdateTime < refTime).ToList();
if (!state.Versioning)
{
truncatedStates.Add(state);
continue;
}

var values = state.Values.Where(x => x.MessageId != refMsgId)
.Where(x => x.UpdateTime < refTime)
.ToList();
if (values.Count == 0) continue;

state.Values = values;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,7 @@ public void SaveConversationContentLog(ContentLogOutputModel log)
log.MessageId = log.MessageId.IfNullOrEmptyAs(Guid.NewGuid().ToString());

var convDir = FindConversationDirectory(log.ConversationId);
if (string.IsNullOrEmpty(convDir))
{
convDir = Path.Combine(_dbSettings.FileRepository, _conversationSettings.DataDir, log.ConversationId);
Directory.CreateDirectory(convDir);
}
if (string.IsNullOrEmpty(convDir)) return;

var logDir = Path.Combine(convDir, "content_log");
if (!Directory.Exists(logDir))
Expand Down Expand Up @@ -120,11 +116,7 @@ public void SaveConversationStateLog(ConversationStateLogModel log)
log.MessageId = log.MessageId.IfNullOrEmptyAs(Guid.NewGuid().ToString());

var convDir = FindConversationDirectory(log.ConversationId);
if (string.IsNullOrEmpty(convDir))
{
convDir = Path.Combine(_dbSettings.FileRepository, _conversationSettings.DataDir, log.ConversationId);
Directory.CreateDirectory(convDir);
}
if (string.IsNullOrEmpty(convDir)) return;

var logDir = Path.Combine(convDir, "state_log");
if (!Directory.Exists(logDir))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ private async Task DoWork(CancellationToken stoppingToken)
try
{
await CleanIdleConversationsAsync();
await CloseIdleConversationsAsync(TimeSpan.FromMinutes(10));
}
catch (Exception ex)
{
Expand All @@ -54,37 +53,6 @@ public override async Task StopAsync(CancellationToken stoppingToken)
await base.StopAsync(stoppingToken);
}

private async Task CloseIdleConversationsAsync(TimeSpan conversationIdleTimeout)
{
using var scope = _services.CreateScope();
var conversationService = scope.ServiceProvider.GetRequiredService<IConversationService>();
var hooks = scope.ServiceProvider.GetServices<IConversationHook>()
.OrderBy(x => x.Priority)
.ToList();
var moment = DateTime.UtcNow.Add(-conversationIdleTimeout);
var conversations = (await conversationService.GetLastConversations()).Where(c => c.CreatedTime <= moment);
foreach (var conversation in conversations)
{
try
{
var response = new RoleDialogModel(AgentRole.Assistant, "End the conversation due to timeout.")
{
StopCompletion = true,
FunctionName = "conversation_end"
};

foreach (var hook in hooks)
{
await hook.OnConversationEnding(response);
}
}
catch (Exception ex)
{
_logger.LogError(ex, $"Error occurred closing conversation #{conversation.Id}.");
}
}
}

private async Task CleanIdleConversationsAsync()
{
using var scope = _services.CreateScope();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using BotSharp.Abstraction.Routing;
using BotSharp.Abstraction.Users.Models;

namespace BotSharp.OpenAPI.Controllers;
Expand Down Expand Up @@ -161,6 +162,9 @@ public async Task<ChatResponseModel> SendMessage([FromRoute] string agentId,
}

var inputMsg = new RoleDialogModel(AgentRole.User, input.Text);
var routing = _services.GetRequiredService<IRoutingService>();
routing.Context.SetMessageId(conversationId, inputMsg.MessageId);

conv.SetConversationId(conversationId, input.States);
conv.States.SetState("channel", input.Channel)
.SetState("provider", input.Provider)
Expand Down
4 changes: 2 additions & 2 deletions src/Plugins/BotSharp.Plugin.ChatHub/Hooks/StreamingLogHook.cs
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ private string BuildContentLog(ContentLogInputModel input)
Role = input.Message.Role,
Content = input.Log,
Source = input.Source,
CreateTime = DateTime.UtcNow
CreateTime = input.Message.CreatedAt
};

var json = JsonSerializer.Serialize(output, _options.JsonSerializerOptions);
Expand All @@ -367,7 +367,7 @@ private string BuildStateLog(string conversationId, Dictionary<string, string> s
ConversationId = conversationId,
MessageId = message.MessageId,
States = states,
CreateTime = DateTime.UtcNow
CreateTime = message.CreatedAt
};

var convSettings = _services.GetRequiredService<ConversationSetting>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ namespace BotSharp.Plugin.MongoStorage.Models;
public class StateMongoElement
{
public string Key { get; set; }
public bool Versioning { get; set; }
public List<StateValueMongoElement> Values { get; set; }

public static StateMongoElement ToMongoElement(StateKeyValue state)
{
return new StateMongoElement
{
Key = state.Key,
Versioning = state.Versioning,
Values = state.Values?.Select(x => StateValueMongoElement.ToMongoElement(x))?.ToList() ?? new List<StateValueMongoElement>()
};
}
Expand All @@ -21,6 +23,7 @@ public static StateKeyValue ToDomainElement(StateMongoElement state)
return new StateKeyValue
{
Key = state.Key,
Versioning = state.Versioning,
Values = state.Values?.Select(x => StateValueMongoElement.ToDomainElement(x))?.ToList() ?? new List<StateValue>()
};
}
Expand All @@ -29,13 +32,17 @@ public static StateKeyValue ToDomainElement(StateMongoElement state)
public class StateValueMongoElement
{
public string Data { get; set; }
public string MessageId { get; set; }
public bool Active { get; set; }
public DateTime UpdateTime { get; set; }

public static StateValueMongoElement ToMongoElement(StateValue element)
{
return new StateValueMongoElement
{
Data = element.Data,
MessageId = element.MessageId,
Active = element.Active,
UpdateTime = element.UpdateTime
};
}
Expand All @@ -45,6 +52,8 @@ public static StateValue ToDomainElement(StateValueMongoElement element)
return new StateValue
{
Data = element.Data,
MessageId = element.MessageId,
Active = element.Active,
UpdateTime = element.UpdateTime
};
}
Expand Down
Loading