Skip to content

Commit 7a84777

Browse files
authored
sync: minja (#12739)
* sync: minja google/minja#57 * fix json include
1 parent 3e1d293 commit 7a84777

File tree

2 files changed

+124
-94
lines changed

2 files changed

+124
-94
lines changed

common/minja/chat-template.hpp

+15-7
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,19 @@
99
#pragma once
1010

1111
#include "minja.hpp"
12-
#include <json.hpp>
12+
13+
#include <chrono>
14+
#include <cstddef>
15+
#include <cstdio>
16+
#include <exception>
17+
#include <iomanip>
18+
#include <memory>
19+
#include <sstream>
1320
#include <string>
1421
#include <vector>
1522

23+
#include <json.hpp>
24+
1625
using json = nlohmann::ordered_json;
1726

1827
namespace minja {
@@ -425,7 +434,7 @@ class chat_template {
425434
auto obj = json {
426435
{"tool_calls", tool_calls},
427436
};
428-
if (!content.is_null() && content != "") {
437+
if (!content.is_null() && !content.empty()) {
429438
obj["content"] = content;
430439
}
431440
message["content"] = obj.dump(2);
@@ -435,13 +444,12 @@ class chat_template {
435444
if (polyfill_tool_responses && role == "tool") {
436445
message["role"] = "user";
437446
auto obj = json {
438-
{"tool_response", {
439-
{"content", message.at("content")},
440-
}},
447+
{"tool_response", json::object()},
441448
};
442449
if (message.contains("name")) {
443-
obj["tool_response"]["name"] = message.at("name");
450+
obj["tool_response"]["tool"] = message.at("name");
444451
}
452+
obj["tool_response"]["content"] = message.at("content");
445453
if (message.contains("tool_call_id")) {
446454
obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
447455
}
@@ -510,7 +518,7 @@ class chat_template {
510518
static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
511519
json messages_with_system = messages;
512520

513-
if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
521+
if (!messages_with_system.empty() && messages_with_system[0].at("role") == "system") {
514522
std::string existing_system = messages_with_system.at(0).at("content");
515523
messages_with_system[0] = json {
516524
{"role", "system"},

0 commit comments

Comments
 (0)