Skip to content

Commit 90bd5ed

Browse files
am831 authored and awharrison-28 committed
Python: Add Google PaLM connector with text completion and example file (microsoft#2076)
### Motivation and Context <!-- Thank you for your contribution to the semantic-kernel repo! Please help reviewers and future users, providing the following information: 1. Why is this change required? 2. What problem does it solve? 3. What scenario does it contribute to? 4. If it fixes an open issue, please link to the issue here. --> Implementation of Google PaLM connector with text completion and an example file to demonstrate its functionality. Closes microsoft#1979 ### Description <!-- Describe your changes, the overall approach, the underlying design. These notes will help understanding how your code works. Thanks! --> 1. Implemented Google Palm connector with text completion 2. Added example file to ```python/samples/kernel-syntax-examples``` 3. Added integration tests with different inputs to kernel.run_async 4. Added unit tests to ensure successful initialization of the class and successful API calls 5. 3 optional arguments (top_k, safety_settings, client) for google.generativeai.generate_text were not included. See more information about the function and its arguments: https://developers.generativeai.google/api/python/google/generativeai/generate_text I also opened a PR for text embedding and chat completion microsoft#2258 ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#dev-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone 😄 Currently no warnings, there was 1 warning when first installing genai with `poetry add google.generativeai==v0.1.0rc2` from within poetry shell: "The locked version 0.1.0rc2 for google-generativeai is a yanked version. 
Reason for being yanked: Release is marked as supporting Py3.8, but in practice it requires 3.9". We would need to require later versions of python to fix it. --------- Co-authored-by: Abby Harrison <[email protected]> Co-authored-by: Abby Harrison <[email protected]>
1 parent 04a719a commit 90bd5ed

File tree

12 files changed

+698
-80
lines changed

12 files changed

+698
-80
lines changed

python/.env.example

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@ AZURE_COGNITIVE_SEARCH_ADMIN_KEY=""
88
PINECONE_API_KEY=""
99
PINECONE_ENVIRONMENT=""
1010
POSTGRES_CONNECTION_STRING=""
11-
GOOGLE_API_KEY=""
11+
GOOGLE_PALM_API_KEY=""
1212
GOOGLE_SEARCH_ENGINE_ID=""

python/poetry.lock

+275-79
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/pyproject.toml

+4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ pytest = "7.4.0"
2525
ruff = "0.0.283"
2626
pytest-asyncio = "0.21.1"
2727

28+
[tool.poetry.group.google_palm.dependencies]
29+
google-generativeai = { version = "^0.1.0", markers = "python_version >= '3.9'" }
30+
grpcio-status = { version = "^1.53.0", markers = "python_version >= '3.9'" }
31+
2832
[tool.poetry.group.hugging_face.dependencies]
2933
transformers = "^4.28.1"
3034
sentence-transformers = "^2.2.2"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) Microsoft. All rights reserved.

import asyncio

import semantic_kernel as sk
import semantic_kernel.connectors.ai.google_palm as sk_gp
from semantic_kernel.connectors.ai.complete_request_settings import (
    CompleteRequestSettings,
)


async def text_completion_example_complete_async(kernel, api_key, user_mssg, settings):
    """
    Complete a text prompt using the Google PaLM model and print the results.
    """
    service = sk_gp.GooglePalmTextCompletion("models/text-bison-001", api_key)
    kernel.add_text_completion_service("models/text-bison-001", service)
    return await service.complete_async(user_mssg, settings)


async def main() -> None:
    """Run two sample completions against the PaLM text-bison model."""
    kernel = sk.Kernel()
    apikey = sk.google_palm_settings_from_dot_env()
    settings = CompleteRequestSettings()

    user_mssg1 = (
        "Sam has three boxes, each containing a certain number of coins. "
        "The first box has twice as many coins as the second box, and the second "
        "box has three times as many coins as the third box. Together, the three "
        "boxes have 98 coins in total. How many coins are there in each box? "
        "Think about it step by step, and show your work."
    )
    response = await text_completion_example_complete_async(
        kernel, apikey, user_mssg1, settings
    )
    print(f"User:> {user_mssg1}\n\nChatBot:> {response}\n")

    # Use temperature to influence the variance of the responses
    settings.number_of_responses = 3
    settings.temperature = 1
    user_mssg2 = (
        "I need a concise answer. A common method for traversing a binary tree is"
    )
    response = await text_completion_example_complete_async(
        kernel, apikey, user_mssg2, settings
    )
    print(f"User:> {user_mssg2}\n\nChatBot:> {response}")


if __name__ == "__main__":
    asyncio.run(main())

python/semantic_kernel/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from semantic_kernel.utils.null_logger import NullLogger
1717
from semantic_kernel.utils.settings import (
1818
azure_openai_settings_from_dot_env,
19+
google_palm_settings_from_dot_env,
1920
openai_settings_from_dot_env,
2021
pinecone_settings_from_dot_env,
2122
postgres_settings_from_dot_env,
@@ -28,6 +29,7 @@
2829
"azure_openai_settings_from_dot_env",
2930
"postgres_settings_from_dot_env",
3031
"pinecone_settings_from_dot_env",
32+
"google_palm_settings_from_dot_env",
3133
"PromptTemplateConfig",
3234
"PromptTemplate",
3335
"ChatPromptTemplate",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Microsoft. All rights reserved.

"""Google PaLM AI connector package: re-exports the text completion service."""

from semantic_kernel.connectors.ai.google_palm.services.gp_text_completion import (
    GooglePalmTextCompletion,
)

__all__ = ["GooglePalmTextCompletion"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
from typing import List, Union
4+
5+
import google.generativeai as palm
6+
7+
from semantic_kernel.connectors.ai.ai_exception import AIException
8+
from semantic_kernel.connectors.ai.complete_request_settings import (
9+
CompleteRequestSettings,
10+
)
11+
from semantic_kernel.connectors.ai.text_completion_client_base import (
12+
TextCompletionClientBase,
13+
)
14+
15+
16+
class GooglePalmTextCompletion(TextCompletionClientBase):
    """
    Text completion service backed by the Google PaLM ``generate_text`` API,
    accessed through the ``google.generativeai`` package.
    """

    # Model name passed to the PaLM API, e.g. "models/text-bison-001".
    _model_id: str
    # API key; google.generativeai is (re)configured with it on every request.
    _api_key: str

    def __init__(self, model_id: str, api_key: str) -> None:
        """
        Initializes a new instance of the GooglePalmTextCompletion class.

        Arguments:
            model_id {str} -- GooglePalm model name, see
                https://developers.generativeai.google/models/language
            api_key {str} -- GooglePalm API key, see
                https://developers.generativeai.google/products/palm

        Raises:
            ValueError -- If the API key is `None` or empty.
        """
        if not api_key:
            # Bug fix: the original message ended with a stray backtick.
            raise ValueError("The Google PaLM API key cannot be `None` or empty")

        self._model_id = model_id
        self._api_key = api_key

    async def complete_async(
        self, prompt: str, request_settings: CompleteRequestSettings
    ) -> Union[str, List[str]]:
        """
        Completes the given prompt.

        Arguments:
            prompt {str} -- The prompt to complete.
            request_settings {CompleteRequestSettings} -- The request settings.

        Returns:
            Union[str, List[str]] -- The single best completion, or a list of
            candidate outputs when request_settings.number_of_responses > 1.
        """
        response = await self._send_completion_request(prompt, request_settings)

        if request_settings.number_of_responses > 1:
            return [candidate["output"] for candidate in response.candidates]
        return response.result

    async def complete_stream_async(
        self, prompt: str, request_settings: CompleteRequestSettings
    ):
        """Not supported: the PaLM text API has no streaming endpoint."""
        raise NotImplementedError(
            "Google Palm API does not currently support streaming"
        )

    async def _send_completion_request(
        self, prompt: str, request_settings: CompleteRequestSettings
    ):
        """
        Completes the given prompt. Returns a single string completion.
        Cannot return multiple completions. Cannot return logprobs.

        Arguments:
            prompt {str} -- The prompt to complete.
            request_settings {CompleteRequestSettings} -- The request settings.

        Returns:
            The raw ``palm.generate_text`` response object.

        Raises:
            ValueError -- If the prompt is empty or the settings are None.
            AIException -- If max_tokens < 1 or the service call fails.
            PermissionError -- If configuring the client raises.
        """
        if not prompt:
            raise ValueError("Prompt cannot be `None` or empty")
        if request_settings is None:
            raise ValueError("Request settings cannot be `None`")
        if request_settings.max_tokens < 1:
            raise AIException(
                AIException.ErrorCodes.InvalidRequest,
                "The max tokens must be greater than 0, "
                f"but was {request_settings.max_tokens}",
            )
        try:
            # NOTE(review): palm.configure does not contact the service, so an
            # invalid key usually surfaces in generate_text below — confirm.
            palm.configure(api_key=self._api_key)
        except Exception as ex:
            raise PermissionError(
                "Google PaLM service failed to configure. Invalid API key provided.",
                ex,
            )
        try:
            response = palm.generate_text(
                model=self._model_id,
                prompt=prompt,
                temperature=request_settings.temperature,
                max_output_tokens=request_settings.max_tokens,
                # Only forward stop sequences when a non-empty list was given.
                stop_sequences=(
                    request_settings.stop_sequences
                    if request_settings.stop_sequences is not None
                    and len(request_settings.stop_sequences) > 0
                    else None
                ),
                candidate_count=request_settings.number_of_responses,
                top_p=request_settings.top_p,
            )
        except Exception as ex:
            raise AIException(
                AIException.ErrorCodes.ServiceError,
                "Google PaLM service failed to complete the prompt",
                ex,
            )
        return response

python/semantic_kernel/utils/settings.py

+16
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,19 @@ def pinecone_settings_from_dot_env() -> Tuple[str, str]:
7979
assert environment, "Pinecone environment not found in .env file"
8080

8181
return api_key, environment
82+
83+
84+
def google_palm_settings_from_dot_env() -> str:
    """
    Reads the Google PaLM API key from the .env file.

    Returns:
        str: The Google PaLM API key

    Raises:
        AssertionError: If the key is missing from the .env file or empty.
    """

    config = dotenv_values(".env")
    api_key = config.get("GOOGLE_PALM_API_KEY", None)

    # Truthiness check (matching the other *_settings_from_dot_env helpers):
    # unlike `is not None`, this also rejects an empty-string key.
    assert api_key, "Google PaLM API key not found in .env file"

    return api_key

python/tests/conftest.py

+11
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,14 @@ def get_oai_config():
3737
api_key, org_id = sk.openai_settings_from_dot_env()
3838

3939
return api_key, org_id
40+
41+
42+
@pytest.fixture(scope="session")
43+
def get_gp_config():
44+
if "Python_Integration_Tests" in os.environ:
45+
api_key = os.environ["GOOGLE_PALM_API_KEY"]
46+
else:
47+
# Load credentials from .env file
48+
api_key = sk.google_palm_settings_from_dot_env()
49+
50+
return api_key

python/tests/integration/completions/conftest.py

+25
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44

5+
import semantic_kernel.connectors.ai.google_palm as sk_gp
56
import semantic_kernel.connectors.ai.hugging_face as sk_hf
67

78

@@ -149,3 +150,27 @@ def setup_summarize_conversation_using_skill(create_kernel):
149150
John: Yeah, that's a good idea."""
150151

151152
yield kernel, ChatTranscript
153+
154+
155+
@pytest.fixture(scope="module")
156+
def setup_gp_text_completion_function(create_kernel, get_gp_config):
157+
kernel = create_kernel
158+
api_key = get_gp_config
159+
# Configure LLM service
160+
palm_text_completion = sk_gp.GooglePalmTextCompletion(
161+
"models/text-bison-001", api_key
162+
)
163+
kernel.add_text_completion_service("models/text-bison-001", palm_text_completion)
164+
165+
# Define semantic function using SK prompt template language
166+
sk_prompt = "Hello, I like {{$input}}{{$input2}}"
167+
168+
# Create the semantic function
169+
text2text_function = kernel.create_semantic_function(
170+
sk_prompt, max_tokens=25, temperature=0.7, top_p=0.5
171+
)
172+
173+
# User input
174+
simple_input = "sleeping and "
175+
176+
yield kernel, text2text_function, simple_input
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
import os
4+
import sys
5+
6+
import pytest
7+
8+
import semantic_kernel as sk
9+
10+
# Bug fix: `pytestmark` was assigned twice, so the second assignment silently
# discarded the Python >= 3.9 skip. Collect both marks in a single list.
pytestmark = [
    pytest.mark.skipif(
        sys.version_info < (3, 9),
        reason="Google Palm requires Python 3.9 or greater",
    ),
    pytest.mark.skipif(
        "Python_Integration_Tests" in os.environ,
        reason="Google Palm integration tests are only set up to run locally",
    ),
]
18+
19+
20+
@pytest.mark.asyncio
async def test_text2text_generation_input_str(setup_gp_text_completion_function):
    """Completion via a plain input string produces non-empty output."""
    kernel, func, simple_input = setup_gp_text_completion_function

    # Complete input string and print
    result = await kernel.run_async(func, input_str=simple_input)
    text = str(result).strip()

    print(f"Completion using input string: '{text}'")
    assert len(text) > 0
30+
31+
32+
@pytest.mark.asyncio
async def test_text2text_generation_input_vars(setup_gp_text_completion_function):
    """Completion via ContextVariables produces non-empty output."""
    kernel, func, simple_input = setup_gp_text_completion_function

    # Complete input as context variable and print
    result = await kernel.run_async(
        func, input_vars=sk.ContextVariables(simple_input)
    )
    text = str(result).strip()

    print(f"Completion using context variables: '{text}'")
    assert len(text) > 0
43+
44+
45+
@pytest.mark.asyncio
async def test_text2text_generation_input_context(setup_gp_text_completion_function):
    """Completion via a pre-populated context produces non-empty output."""
    kernel, func, simple_input = setup_gp_text_completion_function

    # Complete input context and print
    ctx = kernel.create_new_context()
    ctx["input"] = simple_input
    result = await kernel.run_async(func, input_context=ctx)
    text = str(result).strip()

    print(f"Completion using input context: '{text}'")
    assert len(text) > 0
57+
58+
59+
@pytest.mark.asyncio
async def test_text2text_generation_input_context_with_vars(
    setup_gp_text_completion_function,
):
    """Completion via context plus extra variables produces non-empty output."""
    kernel, func, simple_input = setup_gp_text_completion_function

    # Complete input context with additional variables and print
    ctx = kernel.create_new_context()
    ctx["input"] = simple_input
    extra_vars = sk.ContextVariables("running and")
    result = await kernel.run_async(func, input_context=ctx, input_vars=extra_vars)
    text = str(result).strip()

    print(f"Completion using context and additional variables: '{text}'")
    assert len(text) > 0
76+
77+
78+
@pytest.mark.asyncio
async def test_text2text_generation_input_context_with_str(
    setup_gp_text_completion_function,
):
    """Completion via context plus an extra input string produces output."""
    kernel, func, simple_input = setup_gp_text_completion_function

    # Complete input context with additional input string and print
    ctx = kernel.create_new_context()
    ctx["input"] = simple_input
    result = await kernel.run_async(func, input_context=ctx, input_str="running and")
    text = str(result).strip()

    print(f"Completion using context and additional string: '{text}'")
    assert len(text) > 0
94+
95+
96+
@pytest.mark.asyncio
async def test_text2text_generation_input_context_with_vars_and_str(
    setup_gp_text_completion_function,
):
    """Completion via context, extra variables, and an extra string works."""
    kernel, func, simple_input = setup_gp_text_completion_function

    # Complete input context with additional variables and string and print
    ctx = kernel.create_new_context()
    ctx["input"] = simple_input
    extra_vars = sk.ContextVariables(variables={"input2": "running and"})
    result = await kernel.run_async(
        func,
        input_context=ctx,
        input_vars=extra_vars,
        input_str="new text",
    )
    text = str(result).strip()

    print(
        f"Completion using context, additional variables, and additional string: '{text}'"
    )
    assert len(text) > 0

0 commit comments

Comments
 (0)