-
Notifications
You must be signed in to change notification settings - Fork 1.9k
add Tool.outputSchema and CallToolResult.structuredContent #685
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
addce22
c3d4d4f
994ebad
7131de6
ad2ec44
7f36822
06681d6
9db284b
3b28fba
6244899
e7c6727
ecc7146
6d2882c
982f6b0
43ebe80
ad83eea
3261cbc
b16e716
4d327cd
9109577
1746ea1
d3986f2
c2168f2
76d1a7f
b738f1b
6bf73cd
fa20ea6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,11 @@ | ||
import logging | ||
from collections.abc import Awaitable, Callable | ||
from datetime import timedelta | ||
from typing import Any, Protocol | ||
from typing import Any, Protocol, TypeAlias | ||
|
||
import anyio.lowlevel | ||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream | ||
from jsonschema import ValidationError, validate | ||
from pydantic import AnyUrl, TypeAdapter | ||
|
||
import mcp.types as types | ||
|
@@ -11,6 +14,8 @@ | |
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder | ||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") | ||
|
||
|
||
|
@@ -44,6 +49,12 @@ async def __call__( | |
) -> None: ... | ||
|
||
|
||
class ToolOutputValidationFnT(Protocol): | ||
async def __call__( | ||
self, request: types.CallToolRequest, result: types.CallToolResult | ||
) -> bool: ... | ||
|
||
|
||
async def _default_message_handler( | ||
message: RequestResponder[types.ServerRequest, types.ClientResult] | ||
| types.ServerNotification | ||
|
@@ -77,6 +88,25 @@ async def _default_logging_callback( | |
pass | ||
|
||
|
||
ToolOutputValidatorProvider: TypeAlias = Callable[ | ||
..., | ||
Awaitable[ToolOutputValidationFnT], | ||
] | ||
|
||
|
||
# this bag of spanners is required in order to | ||
# enable the client session to be parsed to the validator | ||
async def _python_circularity_hell(arg: Any) -> ToolOutputValidationFnT: | ||
# in any sane version of the universe this should never happen | ||
# of course in any sane programming language class circularity | ||
# dependencies shouldn't be this hard to manage | ||
raise RuntimeError( | ||
"Help I'm stuck in python circularity hell, please send biscuits" | ||
) | ||
|
||
|
||
_default_tool_output_validator: ToolOutputValidatorProvider = _python_circularity_hell | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This little dance (while hilariously rendered, kudos) should be avoidable using
|
||
|
||
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter( | ||
types.ClientResult | types.ErrorData | ||
) | ||
|
@@ -101,6 +131,7 @@ def __init__( | |
logging_callback: LoggingFnT | None = None, | ||
message_handler: MessageHandlerFnT | None = None, | ||
client_info: types.Implementation | None = None, | ||
tool_output_validator_provider: ToolOutputValidatorProvider | None = None, | ||
) -> None: | ||
super().__init__( | ||
read_stream, | ||
|
@@ -114,6 +145,9 @@ def __init__( | |
self._list_roots_callback = list_roots_callback or _default_list_roots_callback | ||
self._logging_callback = logging_callback or _default_logging_callback | ||
self._message_handler = message_handler or _default_message_handler | ||
self._tool_output_validator_provider = ( | ||
tool_output_validator_provider or _default_tool_output_validator | ||
) | ||
|
||
async def initialize(self) -> types.InitializeResult: | ||
sampling = ( | ||
|
@@ -160,6 +194,8 @@ async def initialize(self) -> types.InitializeResult: | |
) | ||
) | ||
|
||
self._tool_output_validator = await self._tool_output_validator_provider(self) | ||
|
||
return result | ||
|
||
async def send_ping(self) -> types.EmptyResult: | ||
|
@@ -281,24 +317,33 @@ async def call_tool( | |
arguments: dict[str, Any] | None = None, | ||
read_timeout_seconds: timedelta | None = None, | ||
progress_callback: ProgressFnT | None = None, | ||
validate_result: bool = True, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. validation should be done or not based on whether the tool has an |
||
) -> types.CallToolResult: | ||
"""Send a tools/call request with optional progress callback support.""" | ||
|
||
return await self.send_request( | ||
types.ClientRequest( | ||
types.CallToolRequest( | ||
method="tools/call", | ||
params=types.CallToolRequestParams( | ||
name=name, | ||
arguments=arguments, | ||
), | ||
) | ||
request = types.CallToolRequest( | ||
method="tools/call", | ||
params=types.CallToolRequestParams( | ||
name=name, | ||
arguments=arguments, | ||
), | ||
) | ||
|
||
result = await self.send_request( | ||
types.ClientRequest(request), | ||
types.CallToolResult, | ||
request_read_timeout_seconds=read_timeout_seconds, | ||
progress_callback=progress_callback, | ||
) | ||
|
||
if validate_result: | ||
valid = await self._tool_output_validator(request, result) | ||
|
||
if not valid: | ||
raise RuntimeError("Server responded with invalid result: " f"{result}") | ||
# not validating or is valid | ||
return result | ||
|
||
async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult: | ||
"""Send a prompts/list request.""" | ||
return await self.send_request( | ||
|
@@ -418,3 +463,75 @@ async def _received_notification( | |
await self._logging_callback(params) | ||
case _: | ||
pass | ||
|
||
|
||
class NoOpToolOutputValidator(ToolOutputValidationFnT): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure when we'd need this, if we're not running validation in the absence of an output schema |
||
async def __call__( | ||
self, request: types.CallToolRequest, result: types.CallToolResult | ||
) -> bool: | ||
return True | ||
|
||
|
||
class SimpleCachingToolOutputValidator(ToolOutputValidationFnT): | ||
_schema_cache: dict[str, dict[str, Any] | bool] | ||
|
||
def __init__(self, session: ClientSession): | ||
self._session = session | ||
self._schema_cache = {} | ||
self._refresh_cache = True | ||
|
||
async def __call__( | ||
self, request: types.CallToolRequest, result: types.CallToolResult | ||
) -> bool: | ||
if result.isError: | ||
# allow errors to be propagated | ||
return True | ||
else: | ||
if self._refresh_cache: | ||
await self._refresh_schema_cache() | ||
|
||
schema = self._schema_cache.get(request.params.name) | ||
|
||
if schema is None: | ||
raise RuntimeError(f"Unknown tool {request.params.name}") | ||
elif schema is False: | ||
# no schema | ||
logging.debug("No schema found checking structuredContent is empty") | ||
return result.structuredContent is None | ||
else: | ||
try: | ||
# TODO opportunity to build jsonschema.protocol.Validator | ||
# and reuse rather than build every time | ||
validate(result.structuredContent, schema) | ||
return True | ||
except ValidationError as e: | ||
logging.exception(e) | ||
return False | ||
|
||
async def _refresh_schema_cache(self): | ||
cursor = None | ||
first = True | ||
self._schema_cache = {} | ||
while first or cursor is not None: | ||
first = False | ||
tools_result = await self._session.list_tools(cursor) | ||
for tool in tools_result.tools: | ||
# store a flag to be able to later distinguish between | ||
# no schema for tool and unknown tool which can't be verified | ||
schema_or_flag = ( | ||
False if tool.outputSchema is None else tool.outputSchema | ||
) | ||
self._schema_cache[tool.name] = schema_or_flag | ||
cursor = tools_result.nextCursor | ||
continue | ||
|
||
self._refresh_cache = False | ||
|
||
|
||
async def _escape_from_circular_python_hell( | ||
session: ClientSession, | ||
) -> ToolOutputValidationFnT: | ||
return SimpleCachingToolOutputValidator(session) | ||
|
||
|
||
_default_tool_output_validator = _escape_from_circular_python_hell |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI internal-package-compatibility reasons we'll need to pin the
jsonschema
version to4.20.0