Skip to content

Python: Improve agent integration tests #11475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import reduce
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar, cast

from openai import BadRequestError
from openai._streaming import AsyncStream
from openai.types.responses import ResponseFunctionToolCall
from openai.types.responses.response import Response
Expand All @@ -28,6 +29,7 @@
merge_function_results,
)
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.connectors.ai.open_ai.exceptions.content_filter_ai_exception import ContentFilterAIException
from semantic_kernel.contents.annotation_content import AnnotationContent
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import CMC_ITEM_TYPES, ChatMessageContent
Expand All @@ -41,6 +43,7 @@
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.contents.utils.status import Status
from semantic_kernel.exceptions.agent_exceptions import (
AgentExecutionException,
AgentInvokeException,
)
from semantic_kernel.functions.kernel_arguments import KernelArguments
Expand Down Expand Up @@ -481,15 +484,31 @@ async def _get_response(
response_options: dict | None = None,
stream: bool = False,
) -> Response | AsyncStream[ResponseStreamEvent]:
response: Response = await agent.client.responses.create(
input=cls._prepare_chat_history_for_request(chat_history),
instructions=merged_instructions or agent.instructions,
previous_response_id=previous_response_id,
store=store_output_enabled,
tools=tools, # type: ignore
stream=stream,
**response_options,
)
try:
response: Response = await agent.client.responses.create(
input=cls._prepare_chat_history_for_request(chat_history),
instructions=merged_instructions or agent.instructions,
previous_response_id=previous_response_id,
store=store_output_enabled,
tools=tools, # type: ignore
stream=stream,
**response_options,
)
except BadRequestError as ex:
if ex.code == "content_filter":
raise ContentFilterAIException(
f"{type(agent)} encountered a content error",
ex,
) from ex
raise AgentExecutionException(
f"{type(agent)} failed to complete the request",
ex,
) from ex
except Exception as ex:
raise AgentExecutionException(
f"{type(agent)} service failed to complete the request",
ex,
) from ex
if response is None:
raise AgentInvokeException("Response is None")
return response
Expand Down
Empty file.
152 changes: 152 additions & 0 deletions python/tests/integration/agents/agent_test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio
from collections.abc import AsyncIterator, Awaitable, Callable
from typing import Any, Generic, Protocol, TypeVar

from semantic_kernel.agents.agent import Agent, AgentResponseItem, AgentThread
from semantic_kernel.contents import ChatMessageContent

# Default number of times the retry helpers in this module attempt a call before giving up.
DEFAULT_MAX_ATTEMPTS = 3
# Default initial delay, in seconds, before the next attempt; run_with_retry doubles it after each failure.
DEFAULT_BACKOFF_SECONDS = 1


class ChatResponseProtocol(Protocol):
    """Structural type for one response item returned by an agent."""

    @property
    def message(self) -> ChatMessageContent:
        """The chat message content carried by this response item."""

    @property
    def thread(self) -> AgentThread | None:
        """The agent thread attached to this response item, or None."""


class ChatAgentProtocol(Protocol):
    """Structural type for the agent entry points shared by the integration tests."""

    async def get_response(
        self,
        messages: str | list[str] | None,
        thread: object | None = None,
    ) -> ChatResponseProtocol:
        """Await and return a single response item for ``messages``."""

    def invoke(
        self,
        messages: str | list[str] | None,
        thread: object | None = None,
    ) -> AsyncIterator[ChatResponseProtocol]:
        """Return an async iterator over the response items for ``messages``."""

    def invoke_stream(
        self,
        messages: str | list[str] | None,
        thread: object | None = None,
    ) -> AsyncIterator[ChatResponseProtocol]:
        """Return an async iterator over the response items of the streaming entry point."""


# Type variable for any agent type that structurally satisfies ChatAgentProtocol;
# it is the generic parameter of AgentTestBase.
TAgent = TypeVar("TAgent", bound=ChatAgentProtocol)


async def run_with_retry(
    coro: Callable[..., Awaitable[Any]],
    *args,
    attempts: int = DEFAULT_MAX_ATTEMPTS,
    backoff_seconds: float = DEFAULT_BACKOFF_SECONDS,
    **kwargs,
) -> Any:
    """Execute an async callable, retrying with exponential backoff on failure.

    Args:
        coro: The async function to call. It is invoked afresh on every attempt.
        args: Positional args to pass to the function.
        attempts: How many times to attempt before giving up. Must be at least 1.
        backoff_seconds: The initial backoff in seconds, doubled after each failure.
        kwargs: Keyword args to pass to the function.

    Returns:
        Whatever the async function returns.

    Raises:
        ValueError: If ``attempts`` is less than 1.
        Exception: The error raised by the final attempt, re-raised unchanged,
            if the function fails on every attempt.
    """
    # Reject a non-positive attempt count up front: otherwise the loop body never
    # runs, the function is never called, and the caller only sees a misleading
    # "unexpected" error.
    if attempts < 1:
        raise ValueError(f"attempts must be at least 1, got {attempts}.")

    delay = backoff_seconds
    for attempt in range(1, attempts + 1):
        try:
            return await coro(*args, **kwargs)
        except Exception:
            # Out of attempts: surface the original error with its traceback intact.
            if attempt == attempts:
                raise
            await asyncio.sleep(delay)
            delay *= 2
    # Unreachable once ``attempts`` is validated; kept as a defensive guard.
    raise RuntimeError("Unexpected error: run_with_retry exit.")


class AgentTestBase(Generic[TAgent]):
    """Shared test helper that runs every agent invocation pattern through retry logic.

    Integration tests may subclass it or call its methods on an instance directly.
    """

    async def get_response_with_retry(
        self,
        agent: Agent,
        messages: str | list[str] | None,
        thread: Any | None = None,
        attempts: int = DEFAULT_MAX_ATTEMPTS,
        backoff_seconds: float = DEFAULT_BACKOFF_SECONDS,
    ) -> AgentResponseItem[ChatMessageContent]:
        """Call ``agent.get_response(...)`` via ``run_with_retry`` and return its result."""
        response = await run_with_retry(
            agent.get_response,
            attempts=attempts,
            backoff_seconds=backoff_seconds,
            messages=messages,
            thread=thread,
        )
        return response

    async def get_invoke_with_retry(
        self,
        agent: Any,
        messages: str | list[str] | None,
        thread: Any | None = None,
        attempts: int = DEFAULT_MAX_ATTEMPTS,
        backoff_seconds: float = DEFAULT_BACKOFF_SECONDS,
    ) -> list[AgentResponseItem[ChatMessageContent]]:
        """Call ``agent.invoke(...)`` via ``run_with_retry``.

        The async iterator is drained into a list on each attempt, and that list is returned.
        """
        items = await run_with_retry(
            self._collect_from_invoke,
            agent,
            messages,
            attempts=attempts,
            backoff_seconds=backoff_seconds,
            thread=thread,
        )
        return items

    async def get_invoke_stream_with_retry(
        self,
        agent: Any,
        messages: str | list[str] | None,
        thread: Any | None = None,
        attempts: int = DEFAULT_MAX_ATTEMPTS,
        backoff_seconds: float = DEFAULT_BACKOFF_SECONDS,
    ) -> list[AgentResponseItem[ChatMessageContent]]:
        """Call ``agent.invoke_stream(...)`` via ``run_with_retry``.

        The stream is drained into a list on each attempt, and that list is returned.
        """
        items = await run_with_retry(
            self._collect_from_invoke_stream,
            agent,
            messages,
            attempts=attempts,
            backoff_seconds=backoff_seconds,
            thread=thread,
        )
        return items

    async def _collect_from_invoke(
        self, agent: Agent, messages: str | list[str] | None, thread: Any | None = None
    ) -> list[AgentResponseItem[ChatMessageContent]]:
        """Drain ``agent.invoke(...)`` and return every yielded item, in order."""
        return [item async for item in agent.invoke(messages=messages, thread=thread)]

    async def _collect_from_invoke_stream(
        self, agent: Agent, messages: str | list[str] | None, thread: Any | None = None
    ) -> list[AgentResponseItem[ChatMessageContent]]:
        """Drain ``agent.invoke_stream(...)`` and return every yielded item, in order."""
        return [item async for item in agent.invoke_stream(messages=messages, thread=thread)]
Loading
Loading