Skip to content

Commit 88de2ea

Browse files
authored
feat(llama.cpp): add support for audio input (#5466)
* feat(llama.cpp): add support for audio input Signed-off-by: Ettore Di Giacinto <[email protected]> * Adapt tests Signed-off-by: Ettore Di Giacinto <[email protected]> --------- Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 9650d49 commit 88de2ea

File tree

4 files changed

+40
-10
lines changed

4 files changed

+40
-10
lines changed

backend/cpp/llama/grpc-server.cpp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,15 @@ json parse_options(bool streaming, const backend::PredictOptions* predict)
133133
});
134134
}
135135

136+
// for each audio in the request, add the audio data
137+
for (int i = 0; i < predict->audios_size(); i++) {
138+
data["audio_data"].push_back(json
139+
{
140+
{"id", i},
141+
{"data", predict->audios(i)},
142+
});
143+
}
144+
136145
data["stop"] = predict->stopprompts();
137146
// data["n_probs"] = predict->nprobs();
138147
//TODO: images,
@@ -406,6 +415,16 @@ class BackendServiceImpl final : public backend::Backend::Service {
406415
}
407416
}
408417

418+
const auto &audio_data = data.find("audio_data");
419+
if (audio_data != data.end() && audio_data->is_array())
420+
{
421+
for (const auto &audio : *audio_data)
422+
{
423+
auto decoded_data = base64_decode(audio["data"].get<std::string>());
424+
files.push_back(decoded_data);
425+
}
426+
}
427+
409428
// process files
410429
mtmd::bitmaps bitmaps;
411430
const bool has_mtmd = ctx_server.mctx != nullptr;
@@ -416,10 +435,10 @@ class BackendServiceImpl final : public backend::Backend::Service {
416435
for (auto & file : files) {
417436
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size()));
418437
if (!bmp.ptr) {
419-
throw std::runtime_error("Failed to load image");
438+
throw std::runtime_error("Failed to load image/audio");
420439
}
421440
// calculate bitmap hash (for KV caching)
422-
std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3);
441+
std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
423442
bmp.set_id(hash.c_str());
424443
bitmaps.entries.push_back(std::move(bmp));
425444
}
@@ -588,6 +607,16 @@ class BackendServiceImpl final : public backend::Backend::Service {
588607
}
589608
}
590609

610+
const auto &audio_data = data.find("audio_data");
611+
if (audio_data != data.end() && audio_data->is_array())
612+
{
613+
for (const auto &audio : *audio_data)
614+
{
615+
auto decoded_data = base64_decode(audio["data"].get<std::string>());
616+
files.push_back(decoded_data);
617+
}
618+
}
619+
591620
// process files
592621
mtmd::bitmaps bitmaps;
593622
const bool has_mtmd = ctx_server.mctx != nullptr;
@@ -598,10 +627,10 @@ class BackendServiceImpl final : public backend::Backend::Service {
598627
for (auto & file : files) {
599628
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size()));
600629
if (!bmp.ptr) {
601-
throw std::runtime_error("Failed to load image");
630+
throw std::runtime_error("Failed to load image/audio");
602631
}
603632
// calculate bitmap hash (for KV caching)
604-
std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3);
633+
std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
605634
bmp.set_id(hash.c_str());
606635
bitmaps.entries.push_back(std::move(bmp));
607636
}

core/http/middleware/request.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ func mergeOpenAIRequestAndBackendConfig(config *config.BackendConfig, input *sch
308308
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
309309
vidIndex++
310310
nrOfVideosInMessage++
311-
case "audio_url", "audio":
311+
case "audio_url", "audio", "input_audio":
312312
// Decode content as base64 either if it's an URL or base64 text
313313
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
314314
if err != nil {

pkg/templates/multimodal.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ type MultimodalContent struct {
2222
}
2323

2424
// https://github.com/ggml-org/llama.cpp/blob/be1d4a13db26750fac702ceb3af88ae4f39dc9f4/tools/mtmd/mtmd.h#L42
25-
const DefaultMultiModalTemplate = "{{ range .Audio }}[audio-{{.ID}}]{{end}}{{ range .Images }}<__image__>{{end}}{{ range .Video }}[vid-{{.ID}}]{{end}}{{.Text}}"
25+
// from <__image__> to <__media__> https://github.com/ggml-org/llama.cpp/blob/79c137f77677b3c8ee3c60a7da033721b938399a/tools/mtmd/mtmd.cpp#L83
26+
const DefaultMultiModalTemplate = "{{ range .Audio }}<__media__>{{end}}{{ range .Images }}<__media__>{{end}}{{ range .Video }}[vid-{{.ID}}]{{end}}{{.Text}}"
2627

2728
func TemplateMultiModal(templateString string, opts MultiModalOptions, text string) (string, error) {
2829
if templateString == "" {

pkg/templates/multimodal_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ var _ = Describe("EvaluateTemplate", func() {
2020
VideosInMessage: 0,
2121
}, "bar")
2222
Expect(err).NotTo(HaveOccurred())
23-
Expect(result).To(Equal("<__image__>bar"))
23+
Expect(result).To(Equal("<__media__>bar"))
2424
})
2525

2626
It("should handle messages with more images correctly", func() {
@@ -33,7 +33,7 @@ var _ = Describe("EvaluateTemplate", func() {
3333
VideosInMessage: 0,
3434
}, "bar")
3535
Expect(err).NotTo(HaveOccurred())
36-
Expect(result).To(Equal("<__image__><__image__>bar"))
36+
Expect(result).To(Equal("<__media__><__media__>bar"))
3737
})
3838
It("should handle messages with more images correctly", func() {
3939
result, err := TemplateMultiModal("", MultiModalOptions{
@@ -45,7 +45,7 @@ var _ = Describe("EvaluateTemplate", func() {
4545
VideosInMessage: 0,
4646
}, "bar")
4747
Expect(err).NotTo(HaveOccurred())
48-
Expect(result).To(Equal("[audio-0]<__image__><__image__>bar"))
48+
Expect(result).To(Equal("<__media__><__media__><__media__>bar"))
4949
})
5050
It("should handle messages with more images correctly", func() {
5151
result, err := TemplateMultiModal("", MultiModalOptions{
@@ -57,7 +57,7 @@ var _ = Describe("EvaluateTemplate", func() {
5757
VideosInMessage: 0,
5858
}, "bar")
5959
Expect(err).NotTo(HaveOccurred())
60-
Expect(result).To(Equal("[audio-0]<__image__>bar"))
60+
Expect(result).To(Equal("<__media__><__media__>bar"))
6161
})
6262
It("should handle messages with more images correctly", func() {
6363
result, err := TemplateMultiModal("", MultiModalOptions{

0 commit comments

Comments
 (0)