Skip to content

Commit bbd78d0

Browse files
Adopt new function calling that supports Groq and other models (#569)
Co-authored-by: Hejia Zhang <[email protected]>
1 parent 7678afe commit bbd78d0

File tree

5 files changed

+31
-14
lines changed

5 files changed

+31
-14
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ classifiers = [
2828
"Programming Language :: Python :: 3.11",
2929
]
3030
dependencies = [
31-
"openai >= 1.6.1, < 2.0.0",
31+
"openai >= 1.34.0, < 2.0.0",
3232
"typer >= 0.7.0, < 1.0.0",
3333
"click >= 7.1.1, < 9.0.0",
3434
"rich >= 13.1.0, < 14.0.0",

sgpt/function.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import sys
33
from abc import ABCMeta
44
from pathlib import Path
5-
from typing import Any, Callable, Dict, List
5+
from typing import Any, Callable, Dict, List, Union
66

77
from .config import cfg
88

@@ -59,4 +59,15 @@ def get_function(name: str) -> Callable[..., Any]:
5959

6060

6161
def get_openai_schemas() -> List[Dict[str, Any]]:
62-
return [function.openai_schema for function in functions]
62+
transformed_schemas = []
63+
for function in functions:
64+
schema = {
65+
"type": "function",
66+
"function": {
67+
"name": function.openai_schema["name"],
68+
"description": function.openai_schema.get("description", ""),
69+
"parameters": function.openai_schema.get("parameters", {}),
70+
},
71+
}
72+
transformed_schemas.append(schema)
73+
return transformed_schemas

sgpt/handlers/handler.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ..role import DefaultRoles, SystemRole
1010

1111
completion: Callable[..., Any] = lambda *args, **kwargs: Generator[Any, None, None]
12+
1213
base_url = cfg.get("API_BASE_URL")
1314
use_litellm = cfg.get("USE_LITELLM") == "true"
1415
additional_kwargs = {
@@ -89,36 +90,43 @@ def get_completion(
8990
messages: List[Dict[str, Any]],
9091
functions: Optional[List[Dict[str, str]]],
9192
) -> Generator[str, None, None]:
93+
9294
name = arguments = ""
9395
is_shell_role = self.role.name == DefaultRoles.SHELL.value
9496
is_code_role = self.role.name == DefaultRoles.CODE.value
9597
is_dsc_shell_role = self.role.name == DefaultRoles.DESCRIBE_SHELL.value
9698
if is_shell_role or is_code_role or is_dsc_shell_role:
9799
functions = None
98100

101+
if functions:
102+
additional_kwargs["tool_choice"] = "auto"
103+
additional_kwargs["tools"] = functions
104+
additional_kwargs["parallel_tool_calls"] = False
105+
99106
response = completion(
100107
model=model,
101108
temperature=temperature,
102109
top_p=top_p,
103110
messages=messages,
104-
functions=functions,
105111
stream=True,
106112
**additional_kwargs,
107113
)
108114

109115
try:
110116
for chunk in response:
111117
delta = chunk.choices[0].delta
118+
112119
# LiteLLM uses dict instead of Pydantic object like OpenAI does.
113-
function_call = (
114-
delta.get("function_call") if use_litellm else delta.function_call
120+
tool_calls = (
121+
delta.get("tool_calls") if use_litellm else delta.tool_calls
115122
)
116-
if function_call:
117-
if function_call.name:
118-
name = function_call.name
119-
if function_call.arguments:
120-
arguments += function_call.arguments
121-
if chunk.choices[0].finish_reason == "function_call":
123+
if tool_calls:
124+
for tool_call in tool_calls:
125+
if tool_call.function.name:
126+
name = tool_call.function.name
127+
if tool_call.function.arguments:
128+
arguments += tool_call.function.arguments
129+
if chunk.choices[0].finish_reason == "tool_calls":
122130
yield from self.handle_function_call(messages, name, arguments)
123131
yield from self.get_completion(
124132
model=model,

tests/test_default.py

-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,6 @@ def test_llm_options(completion):
209209
model=args["--model"],
210210
temperature=args["--temperature"],
211211
top_p=args["--top-p"],
212-
functions=None,
213212
)
214213
completion.assert_called_once_with(**expected_args)
215214
assert result.exit_code == 0

tests/utils.py

-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ def comp_args(role, prompt, **kwargs):
5454
"model": cfg.get("DEFAULT_MODEL"),
5555
"temperature": 0.0,
5656
"top_p": 1.0,
57-
"functions": None,
5857
"stream": True,
5958
**kwargs,
6059
}

0 commit comments

Comments (0)