Skip to content

Commit 3c03727

Browse files
TaoChenOSUeavanvalkenburg
authored andcommitted
Python: Fix agent samples (microsoft#11278)
### Motivation and Context <!-- Thank you for your contribution to the semantic-kernel repo! Please help reviewers and future users, providing the following information: 1. Why is this change required? 2. What problem does it solve? 3. What scenario does it contribute to? 4. If it fixes an open issue, please link to the issue here. --> This [sample](https://github.com/microsoft/semantic-kernel/blob/main/python/samples/concepts/agents/chat_completion_agent/chat_completion_agent_function_termination.py) was throwing an error due to duplicating tool messages in the chat history. ### Description <!-- Describe your changes, the overall approach, the underlying design. These notes will help understanding how your code works. Thanks! --> Fix a few agent samples. ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone 😄
1 parent 6c532a2 commit 3c03727

File tree

6 files changed

+39
-54
lines changed

6 files changed

+39
-54
lines changed

python/samples/concepts/agents/chat_completion_agent/chat_completion_agent_function_termination.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,7 @@ async def main():
9999
print("================================")
100100

101101
# 4. Print out the chat history to view the different types of messages
102-
chat_history = await thread.get_messages()
103-
for message in chat_history.messages:
102+
async for message in thread.get_messages():
104103
_write_content(message)
105104

106105
"""

python/samples/concepts/agents/chat_completion_agent/chat_completion_agent_summary_history_reducer_single_agent.py

+5-11
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,9 @@
33
import asyncio
44
import logging
55

6-
from semantic_kernel.agents import (
7-
ChatCompletionAgent,
8-
ChatHistoryAgentThread,
9-
)
6+
from semantic_kernel.agents import ChatCompletionAgent, ChatHistoryAgentThread
107
from semantic_kernel.connectors.ai.open_ai import AzureChatCompletion
11-
from semantic_kernel.contents import (
12-
ChatHistorySummarizationReducer,
13-
)
8+
from semantic_kernel.contents import ChatHistorySummarizationReducer
149

1510
"""
1611
The following sample demonstrates how to implement a truncation chat
@@ -55,14 +50,13 @@ async def main():
5550
# Attempt reduction
5651
is_reduced = await thread.reduce()
5752
if is_reduced:
58-
print(f"@ History reduced to {len(thread.messages)} messages.")
53+
print(f"@ History reduced to {len(thread)} messages.")
5954

60-
print(f"@ Message Count: {len(thread.messages)}\n")
55+
print(f"@ Message Count: {len(thread)}\n")
6156

6257
# If reduced, print summary if present
6358
if is_reduced:
64-
chat_history = await thread.get_messages()
65-
for msg in chat_history.messages:
59+
async for msg in thread.get_messages():
6660
if msg.metadata and msg.metadata.get("__summary__"):
6761
print(f"\tSummary: {msg.content}")
6862
break

python/semantic_kernel/agents/autogen/autogen_conversable_agent.py

+4-11
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,8 @@ async def get_response(
157157
)
158158
assert thread.id is not None # nosec
159159

160-
chat_history = ChatHistory()
161-
async for message in thread.get_messages():
162-
chat_history.add_message(message)
163-
164160
reply = await self.conversable_agent.a_generate_reply(
165-
messages=[message.to_dict() for message in chat_history.messages],
161+
messages=[message.to_dict() async for message in thread.get_messages()],
166162
**kwargs,
167163
)
168164

@@ -214,17 +210,14 @@ async def invoke(
214210
)
215211
assert thread.id is not None # nosec
216212

217-
chat_history = ChatHistory()
218-
async for message in thread.get_messages():
219-
chat_history.add_message(message)
220-
221213
if recipient is not None:
222214
if not isinstance(recipient, AutoGenConversableAgent):
223215
raise AgentInvokeException(
224216
f"Invalid recipient type: {type(recipient)}. "
225217
"Recipient must be an instance of AutoGenConversableAgent."
226218
)
227219

220+
messages = [message async for message in thread.get_messages()]
228221
chat_result = await self.conversable_agent.a_initiate_chat(
229222
recipient=recipient.conversable_agent,
230223
clear_history=clear_history,
@@ -233,7 +226,7 @@ async def invoke(
233226
max_turns=max_turns,
234227
summary_method=summary_method,
235228
summary_args=summary_args,
236-
message=chat_history.messages[-1].content, # type: ignore
229+
message=messages[-1].content, # type: ignore
237230
**kwargs,
238231
)
239232

@@ -248,7 +241,7 @@ async def invoke(
248241
)
249242
else:
250243
reply = await self.conversable_agent.a_generate_reply(
251-
messages=[message.to_dict() for message in chat_history.messages],
244+
messages=[message.to_dict() async for message in thread.get_messages()],
252245
)
253246

254247
logger.info("Called AutoGenConversableAgent.a_generate_reply.")

python/semantic_kernel/agents/chat_completion/chat_completion_agent.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -231,11 +231,9 @@ async def create_channel(
231231
if thread.id is None:
232232
await thread.create()
233233

234-
chat_history = ChatHistory()
235-
async for message in thread.get_messages():
236-
chat_history.add_message(message)
234+
messages = [message async for message in thread.get_messages()]
237235

238-
return ChatHistoryChannel(messages=chat_history.messages, thread=thread)
236+
return ChatHistoryChannel(messages=messages, thread=thread)
239237

240238
@trace_agent_get_response
241239
@override
@@ -280,9 +278,7 @@ async def get_response(
280278
if not responses:
281279
raise AgentInvokeException("No response from agent.")
282280

283-
response = responses[-1]
284-
await thread.on_new_message(response)
285-
return AgentResponseItem(message=response, thread=thread)
281+
return AgentResponseItem(message=responses[-1], thread=thread)
286282

287283
@trace_agent_invocation
288284
@override
@@ -321,7 +317,6 @@ async def invoke(
321317
chat_history.add_message(message)
322318

323319
async for response in self._inner_invoke(thread, chat_history, arguments, kernel, **kwargs):
324-
await thread.on_new_message(response)
325320
yield AgentResponseItem(message=response, thread=thread)
326321

327322
@trace_agent_invocation
@@ -411,6 +406,9 @@ async def invoke_stream(
411406

412407
await self._capture_mutated_messages(agent_chat_history, message_count_before_completion, thread)
413408
if role != AuthorRole.TOOL:
409+
# Tool messages will be automatically added to the chat history by the auto function invocation loop
410+
# if it's the response (i.e. terminated by a filter), thus we need to avoid notifying the thread about
411+
# them multiple times.
414412
await thread.on_new_message(
415413
ChatMessageContent(
416414
role=role if role else AuthorRole.ASSISTANT, content="".join(response_builder), name=self.name
@@ -468,6 +466,11 @@ async def _inner_invoke(
468466

469467
for response in responses:
470468
response.name = self.name
469+
if response.role != AuthorRole.TOOL:
470+
# Tool messages will be automatically added to the chat history by the auto function invocation loop
471+
# if it's the response (i.e. terminated by a filter),, thus we need to avoid notifying the thread about
472+
# them multiple times.
473+
await thread.on_new_message(response)
471474
yield response
472475

473476
async def _prepare_agent_chat_history(

python/semantic_kernel/core_plugins/math_plugin.py

+9-12
Original file line numberDiff line numberDiff line change
@@ -19,30 +19,27 @@ class MathPlugin:
1919
@kernel_function(name="Add")
2020
def add(
2121
self,
22-
input: Annotated[int, "the first number to add"],
23-
amount: Annotated[int, "the second number to add"],
24-
) -> Annotated[int, "the output is a number"]:
22+
input: Annotated[int | str, "The first number to add"],
23+
amount: Annotated[int | str, "The second number to add"],
24+
) -> Annotated[int, "The result"]:
2525
"""Returns the Addition result of the values provided."""
2626
if isinstance(input, str):
2727
input = int(input)
2828
if isinstance(amount, str):
2929
amount = int(amount)
30-
return MathPlugin.add_or_subtract(input, amount, add=True)
30+
31+
return input + amount
3132

3233
@kernel_function(name="Subtract")
3334
def subtract(
3435
self,
35-
input: Annotated[int, "the first number"],
36-
amount: Annotated[int, "the number to subtract"],
37-
) -> int:
36+
input: Annotated[int | str, "The number to subtract from"],
37+
amount: Annotated[int | str, "The number to subtract"],
38+
) -> Annotated[int, "The result"]:
3839
"""Returns the difference of numbers provided."""
3940
if isinstance(input, str):
4041
input = int(input)
4142
if isinstance(amount, str):
4243
amount = int(amount)
43-
return MathPlugin.add_or_subtract(input, amount, add=False)
4444

45-
@staticmethod
46-
def add_or_subtract(input: int, amount: int, add: bool) -> int:
47-
"""Helper function to perform addition or subtraction based on the add flag."""
48-
return input + amount if add else input - amount
45+
return input - amount

python/tests/unit/agents/chat_completion/test_chat_completion_agent.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ async def test_invoke(kernel_with_ai_service: tuple[Kernel, ChatCompletionClient
180180
assert messages[0].message.content == "Processed Message"
181181

182182

183-
async def test_invoke_tool_call_added(kernel_with_ai_service: tuple[Kernel, ChatCompletionClientBase]):
183+
async def test_invoke_tool_call_not_added(kernel_with_ai_service: tuple[Kernel, ChatCompletionClientBase]):
184184
kernel, mock_ai_service_client = kernel_with_ai_service
185185
agent = ChatCompletionAgent(
186186
kernel=kernel,
@@ -211,15 +211,14 @@ async def mock_get_chat_message_contents(
211211
assert messages[1].message.content == "Processed Message 2"
212212

213213
thread: ChatHistoryAgentThread = messages[-1].thread
214-
history = ChatHistory()
215-
async for message in thread.get_messages():
216-
history.add_message(message)
217-
218-
assert len(history.messages) == 5
219-
assert history.messages[1].content == "Processed Message 1"
220-
assert history.messages[2].content == "Processed Message 2"
221-
assert history.messages[1].name == "TestAgent"
222-
assert history.messages[2].name == "TestAgent"
214+
thread_messages = [message async for message in thread.get_messages()]
215+
216+
assert len(thread_messages) == 4
217+
assert thread_messages[0].content == "test"
218+
assert thread_messages[1].content == "Processed Message 1"
219+
assert thread_messages[2].content == "Processed Message 2"
220+
assert thread_messages[1].name == "TestAgent"
221+
assert thread_messages[2].name == "TestAgent"
223222

224223

225224
async def test_invoke_no_service_throws(kernel: Kernel):

0 commit comments

Comments
 (0)