Skip to content

Commit 349ebcc

Browse files
added websocket support
1 parent 88a42ca commit 349ebcc

File tree

2 files changed

+74
-4
lines changed

2 files changed

+74
-4
lines changed

python/semantic_kernel/connectors/mcp.py

+42
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from mcp.client.session import ClientSession
1212
from mcp.client.sse import sse_client
1313
from mcp.client.stdio import StdioServerParameters, stdio_client
14+
from mcp.client.websocket import websocket_client
1415
from mcp.types import CallToolResult, EmbeddedResource, Prompt, PromptMessage, TextResourceContents, Tool
1516
from mcp.types import (
1617
ImageContent as MCPImageContent,
@@ -348,3 +349,44 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
348349
if self._client_kwargs:
349350
args.update(self._client_kwargs)
350351
return sse_client(**args)
352+
353+
354+
class MCPWebsocketPlugin(MCPPluginBase):
355+
"""MCP websocket server configuration."""
356+
357+
def __init__(
358+
self,
359+
name: str,
360+
url: str,
361+
session: ClientSession | None = None,
362+
description: str | None = None,
363+
**kwargs: Any,
364+
) -> None:
365+
"""Initialize the MCP websocket plugin.
366+
367+
The arguments are used to create a websocket client.
368+
see mcp.client.websocket.websocket_client for more details.
369+
370+
Any extra arguments passed to the constructor will be passed to the
371+
websocket client constructor.
372+
373+
Args:
374+
name: The name of the plugin.
375+
url: The URL of the MCP server.
376+
session: The session to use for the MCP connection.
377+
description: The description of the plugin.
378+
kwargs: Any extra arguments to pass to the websocket client.
379+
380+
"""
381+
super().__init__(name, description, session)
382+
self.url = url
383+
self._client_kwargs = kwargs
384+
385+
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
386+
"""Get an MCP websocket client."""
387+
args: dict[str, Any] = {
388+
"url": self.url,
389+
}
390+
if self._client_kwargs:
391+
args.update(self._client_kwargs)
392+
return websocket_client(**args)

python/tests/unit/connectors/mcp/test_mcp.py

+32-4
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55
import pytest
66
from mcp import ClientSession, ListToolsResult, StdioServerParameters, Tool
77

8-
from semantic_kernel.connectors.mcp import (
9-
MCPSsePlugin,
10-
MCPStdioPlugin,
11-
)
8+
from semantic_kernel.connectors.mcp import MCPSsePlugin, MCPStdioPlugin, MCPWebsocketPlugin
129
from semantic_kernel.exceptions import KernelPluginInvalidConfigurationError
1310

1411
if TYPE_CHECKING:
@@ -116,6 +113,37 @@ async def test_with_kwargs_stdio(mock_session, mock_client, list_tool_calls, ker
116113
assert len(loaded_plugin.functions["func2"].parameters) == 0
117114

118115

116+
@patch("semantic_kernel.connectors.mcp.websocket_client")
117+
@patch("semantic_kernel.connectors.mcp.ClientSession")
118+
async def test_with_kwargs_websocket(mock_session, mock_client, list_tool_calls, kernel: "Kernel"):
119+
mock_read = MagicMock()
120+
mock_write = MagicMock()
121+
122+
mock_generator = MagicMock()
123+
# Make the mock_stdio_client return an AsyncMock for the context manager
124+
mock_generator.__aenter__.return_value = (mock_read, mock_write)
125+
mock_generator.__aexit__.return_value = (mock_read, mock_write)
126+
127+
# Make the mock_stdio_client return an AsyncMock for the context manager
128+
mock_client.return_value = mock_generator
129+
mock_session.return_value.__aenter__.return_value.list_tools.return_value = list_tool_calls
130+
async with MCPWebsocketPlugin(
131+
name="TestMCPPlugin",
132+
description="Test MCP Plugin",
133+
url="http://localhost:8080/websocket",
134+
) as plugin:
135+
mock_client.assert_called_once_with(url="http://localhost:8080/websocket")
136+
loaded_plugin = kernel.add_plugin(plugin)
137+
assert loaded_plugin is not None
138+
assert loaded_plugin.name == "TestMCPPlugin"
139+
assert loaded_plugin.description == "Test MCP Plugin"
140+
assert loaded_plugin.functions.get("func1") is not None
141+
assert loaded_plugin.functions["func1"].parameters[0].name == "name"
142+
assert loaded_plugin.functions["func1"].parameters[0].is_required
143+
assert loaded_plugin.functions.get("func2") is not None
144+
assert len(loaded_plugin.functions["func2"].parameters) == 0
145+
146+
119147
@patch("semantic_kernel.connectors.mcp.sse_client")
120148
@patch("semantic_kernel.connectors.mcp.ClientSession")
121149
async def test_with_kwargs_sse(mock_session, mock_client, list_tool_calls, kernel: "Kernel"):

0 commit comments

Comments
 (0)