Skip to content

Commit c23e655

Browse files
authored
feat(agent): shared state, allow to track conversations globally (#148)
* feat(agent): shared state, allow to track conversations globally Signed-off-by: Ettore Di Giacinto <[email protected]> * Cleanup Signed-off-by: Ettore Di Giacinto <[email protected]> * track conversations initiated by the bot Signed-off-by: Ettore Di Giacinto <[email protected]> --------- Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 2b07dd7 commit c23e655

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+290
-316
lines changed

core/action/custom.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func (a *CustomAction) Plannable() bool {
8181
return true
8282
}
8383

84-
func (a *CustomAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
84+
func (a *CustomAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
8585
v, err := a.i.Eval(fmt.Sprintf("%s.Run", a.config["name"]))
8686
if err != nil {
8787
return types.ActionResult{}, err

core/action/custom_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ return []string{"foo"}
7676
Description: "A test action",
7777
}))
7878

79-
runResult, err := customAction.Run(context.Background(), types.ActionParams{
79+
runResult, err := customAction.Run(context.Background(), nil, types.ActionParams{
8080
"Foo": "bar",
8181
})
8282
Expect(err).ToNot(HaveOccurred())

core/action/goal.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ type GoalResponse struct {
2121
Achieved bool `json:"achieved"`
2222
}
2323

24-
func (a *GoalAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
24+
func (a *GoalAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
2525
return types.ActionResult{}, nil
2626
}
2727

core/action/intention.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ type IntentResponse struct {
2222
Reasoning string `json:"reasoning"`
2323
}
2424

25-
func (a *IntentAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
25+
func (a *IntentAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
2626
return types.ActionResult{}, nil
2727
}
2828

core/action/newconversation.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ type ConversationActionResponse struct {
1919
Message string `json:"message"`
2020
}
2121

22-
func (a *ConversationAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
22+
func (a *ConversationAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
2323
return types.ActionResult{}, nil
2424
}
2525

core/action/noreply.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ func NewStop() *StopAction {
1616

1717
type StopAction struct{}
1818

19-
func (a *StopAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
19+
func (a *StopAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
2020
return types.ActionResult{}, nil
2121
}
2222

core/action/plan.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ type PlanSubtask struct {
3030
Reasoning string `json:"reasoning"`
3131
}
3232

33-
func (a *PlanAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
33+
func (a *PlanAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
3434
return types.ActionResult{}, nil
3535
}
3636

core/action/reasoning.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ type ReasoningResponse struct {
2020
Reasoning string `json:"reasoning"`
2121
}
2222

23-
func (a *ReasoningAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
23+
func (a *ReasoningAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
2424
return types.ActionResult{}, nil
2525
}
2626

core/action/reply.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ type ReplyResponse struct {
2222
Message string `json:"message"`
2323
}
2424

25-
func (a *ReplyAction) Run(context.Context, types.ActionParams) (string, error) {
25+
func (a *ReplyAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (string, error) {
2626
return "no-op", nil
2727
}
2828

core/action/state.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ func NewState() *StateAction {
1515

1616
type StateAction struct{}
1717

18-
func (a *StateAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
18+
func (a *StateAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
1919
return types.ActionResult{Result: "internal state has been updated"}, nil
2020
}
2121

core/agent/agent.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ type Agent struct {
4646
newMessagesSubscribers []func(openai.ChatCompletionMessage)
4747

4848
observer Observer
49+
50+
sharedState *types.AgentSharedState
4951
}
5052

5153
type RAGDB interface {
@@ -78,6 +80,7 @@ func New(opts ...Option) (*Agent, error) {
7880
context: types.NewActionContext(ctx, cancel),
7981
newConversations: make(chan openai.ChatCompletionMessage),
8082
newMessagesSubscribers: options.newConversationsSubscribers,
83+
sharedState: types.NewAgentSharedState(options.lastMessageDuration),
8184
}
8285

8386
// Initialize observer if provided
@@ -118,6 +121,10 @@ func New(opts ...Option) (*Agent, error) {
118121
return a, nil
119122
}
120123

124+
func (a *Agent) SharedState() *types.AgentSharedState {
125+
return a.sharedState
126+
}
127+
121128
func (a *Agent) startNewConversationsConsumer() {
122129
go func() {
123130
for {
@@ -294,7 +301,7 @@ func (a *Agent) runAction(job *types.Job, chosenAction types.Action, params type
294301

295302
for _, act := range a.availableActions() {
296303
if act.Definition().Name == chosenAction.Definition().Name {
297-
res, err := act.Run(job.GetContext(), params)
304+
res, err := act.Run(job.GetContext(), a.sharedState, params)
298305
if err != nil {
299306
if obs != nil {
300307
obs.Completion = &types.Completion{

core/agent/agent_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func (a *TestAction) Plannable() bool {
4444
return true
4545
}
4646

47-
func (a *TestAction) Run(c context.Context, p types.ActionParams) (types.ActionResult, error) {
47+
func (a *TestAction) Run(c context.Context, sharedState *types.AgentSharedState, p types.ActionParams) (types.ActionResult, error) {
4848
for k, r := range a.response {
4949
if strings.Contains(strings.ToLower(p.String()), strings.ToLower(k)) {
5050
return types.ActionResult{Result: r}, nil

core/agent/mcp.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func (a *mcpAction) Plannable() bool {
3838
return true
3939
}
4040

41-
func (m *mcpAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
41+
func (m *mcpAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
4242
resp, err := m.mcpClient.CallTool(ctx, m.toolName, params)
4343
if err != nil {
4444
xlog.Error("Failed to call tool", "error", err.Error())

core/agent/options.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ type options struct {
6464

6565
observer Observer
6666
parallelJobs int
67+
68+
lastMessageDuration time.Duration
6769
}
6870

6971
func (o *options) SeparatedMultimodalModel() bool {
@@ -151,6 +153,17 @@ func EnableKnowledgeBaseWithResults(results int) Option {
151153
}
152154
}
153155

156+
func WithLastMessageDuration(duration string) Option {
157+
return func(o *options) error {
158+
d, err := time.ParseDuration(duration)
159+
if err != nil {
160+
d = types.DefaultLastMessageDuration
161+
}
162+
o.lastMessageDuration = d
163+
return nil
164+
}
165+
}
166+
154167
func WithParallelJobs(jobs int) Option {
155168
return func(o *options) error {
156169
o.parallelJobs = jobs

core/agent/state.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ import (
1414
// all information that should be displayed to the LLM
1515
// in the prompts
1616
type PromptHUD struct {
17-
Character Character `json:"character"`
17+
Character Character `json:"character"`
1818
CurrentState types.AgentInternalState `json:"current_state"`
19-
PermanentGoal string `json:"permanent_goal"`
20-
ShowCharacter bool `json:"show_character"`
19+
PermanentGoal string `json:"permanent_goal"`
20+
ShowCharacter bool `json:"show_character"`
2121
}
2222

2323
type Character struct {
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package conversations_test
2+
3+
import (
4+
"testing"
5+
6+
. "github.com/onsi/ginkgo/v2"
7+
. "github.com/onsi/gomega"
8+
)
9+
10+
func TestConversations(t *testing.T) {
11+
RegisterFailHandler(Fail)
12+
RunSpecs(t, "Conversations test suite")
13+
}

services/connectors/conversationstracker.go renamed to core/conversations/conversationstracker.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package connectors
1+
package conversations
22

33
import (
44
"fmt"

services/connectors/conversationstracker_test.go renamed to core/conversations/conversationstracker_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
1-
package connectors_test
1+
package conversations_test
22

33
import (
44
"time"
55

6-
"github.com/mudler/LocalAGI/services/connectors"
6+
"github.com/mudler/LocalAGI/core/conversations"
77
. "github.com/onsi/ginkgo/v2"
88
. "github.com/onsi/gomega"
99
"github.com/sashabaranov/go-openai"
1010
)
1111

1212
var _ = Describe("ConversationTracker", func() {
1313
var (
14-
tracker *connectors.ConversationTracker[string]
14+
tracker *conversations.ConversationTracker[string]
1515
duration time.Duration
1616
)
1717

1818
BeforeEach(func() {
1919
duration = 1 * time.Second
20-
tracker = connectors.NewConversationTracker[string](duration)
20+
tracker = conversations.NewConversationTracker[string](duration)
2121
})
2222

2323
It("should initialize with empty conversations", func() {
@@ -81,8 +81,8 @@ var _ = Describe("ConversationTracker", func() {
8181
})
8282

8383
It("should handle different key types", func() {
84-
trackerInt := connectors.NewConversationTracker[int](duration)
85-
trackerInt64 := connectors.NewConversationTracker[int64](duration)
84+
trackerInt := conversations.NewConversationTracker[int](duration)
85+
trackerInt64 := conversations.NewConversationTracker[int64](duration)
8686

8787
message := openai.ChatCompletionMessage{
8888
Role: openai.ChatMessageRoleUser,

core/state/config.go

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,13 @@ type AgentConfig struct {
4848

4949
Description string `json:"description" form:"description"`
5050

51-
Model string `json:"model" form:"model"`
52-
MultimodalModel string `json:"multimodal_model" form:"multimodal_model"`
53-
APIURL string `json:"api_url" form:"api_url"`
54-
APIKey string `json:"api_key" form:"api_key"`
55-
LocalRAGURL string `json:"local_rag_url" form:"local_rag_url"`
56-
LocalRAGAPIKey string `json:"local_rag_api_key" form:"local_rag_api_key"`
51+
Model string `json:"model" form:"model"`
52+
MultimodalModel string `json:"multimodal_model" form:"multimodal_model"`
53+
APIURL string `json:"api_url" form:"api_url"`
54+
APIKey string `json:"api_key" form:"api_key"`
55+
LocalRAGURL string `json:"local_rag_url" form:"local_rag_url"`
56+
LocalRAGAPIKey string `json:"local_rag_api_key" form:"local_rag_api_key"`
57+
LastMessageDuration string `json:"last_message_duration" form:"last_message_duration"`
5758

5859
Name string `json:"name" form:"name"`
5960
HUD bool `json:"hud" form:"hud"`
@@ -329,6 +330,14 @@ func NewAgentConfigMeta(
329330
HelpText: "Maximum number of evaluation loops to perform when addressing gaps in responses",
330331
Tags: config.Tags{Section: "AdvancedSettings"},
331332
},
333+
{
334+
Name: "last_message_duration",
335+
Label: "Last Message Duration",
336+
Type: "text",
337+
DefaultValue: "5m",
338+
HelpText: "Duration for the last message to be considered in the conversation",
339+
Tags: config.Tags{Section: "AdvancedSettings"},
340+
},
332341
},
333342
MCPServers: []config.Field{
334343
{

core/state/pool.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig, obs O
462462
}),
463463
WithSystemPrompt(config.SystemPrompt),
464464
WithMultimodalModel(multimodalModel),
465+
WithLastMessageDuration(config.LastMessageDuration),
465466
WithAgentResultCallback(func(state types.ActionState) {
466467
a.Lock()
467468
if _, ok := a.agentStatus[name]; !ok {

core/types/actions.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ func (a ActionDefinition) ToFunctionDefinition() *openai.FunctionDefinition {
8888

8989
// Actions is something the agent can do
9090
type Action interface {
91-
Run(ctx context.Context, action ActionParams) (ActionResult, error)
91+
Run(ctx context.Context, sharedState *AgentSharedState, action ActionParams) (ActionResult, error)
9292
Definition() ActionDefinition
9393
Plannable() bool
9494
}

core/types/state.go

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
package types
22

3-
import "fmt"
3+
import (
4+
"fmt"
5+
"time"
6+
7+
"github.com/mudler/LocalAGI/core/conversations"
8+
)
49

510
// State is the structure
611
// that is used to keep track of the current state
@@ -20,6 +25,23 @@ type AgentInternalState struct {
2025
Goal string `json:"goal"`
2126
}
2227

28+
const (
29+
DefaultLastMessageDuration = 5 * time.Minute
30+
)
31+
32+
type AgentSharedState struct {
33+
ConversationTracker *conversations.ConversationTracker[string] `json:"conversation_tracker"`
34+
}
35+
36+
func NewAgentSharedState(lastMessageDuration time.Duration) *AgentSharedState {
37+
if lastMessageDuration == 0 {
38+
lastMessageDuration = DefaultLastMessageDuration
39+
}
40+
return &AgentSharedState{
41+
ConversationTracker: conversations.NewConversationTracker[string](lastMessageDuration),
42+
}
43+
}
44+
2345
const fmtT = `=====================
2446
NowDoing: %s
2547
DoingNext: %s

services/actions/browse.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func NewBrowse(config map[string]string) *BrowseAction {
1818

1919
type BrowseAction struct{}
2020

21-
func (a *BrowseAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
21+
func (a *BrowseAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
2222
result := struct {
2323
URL string `json:"url"`
2424
}{}

services/actions/browseragentrunner.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func NewBrowserAgentRunner(config map[string]string, defaultURL string) *Browser
4545
}
4646
}
4747

48-
func (b *BrowserAgentRunner) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
48+
func (b *BrowserAgentRunner) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
4949
result := api.AgentRequest{}
5050
err := params.Unmarshal(&result)
5151
if err != nil {

services/actions/callagents.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ type CallAgentAction struct {
5252
blacklist []string
5353
}
5454

55-
func (a *CallAgentAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
55+
func (a *CallAgentAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
5656
result := struct {
5757
AgentName string `json:"agent_name"`
5858
Message string `json:"message"`

services/actions/counter.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func NewCounter(config map[string]string) *CounterAction {
2424
}
2525

2626
// Run executes the counter action
27-
func (a *CounterAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
27+
func (a *CounterAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
2828
// Parse parameters
2929
request := struct {
3030
Name string `json:"name"`

services/actions/deepresearchrunner.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func NewDeepResearchRunner(config map[string]string, defaultURL string) *DeepRes
4545
}
4646
}
4747

48-
func (d *DeepResearchRunner) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
48+
func (d *DeepResearchRunner) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
4949
result := api.DeepResearchRequest{}
5050
err := params.Unmarshal(&result)
5151
if err != nil {

services/actions/genimage.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ type GenImageAction struct {
2929
imageModel string
3030
}
3131

32-
func (a *GenImageAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
32+
func (a *GenImageAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
3333
result := struct {
3434
Prompt string `json:"prompt"`
3535
Size string `json:"size"`

0 commit comments

Comments
 (0)