|
5 | 5 | import pytest
|
6 | 6 | from mcp import ClientSession, ListToolsResult, StdioServerParameters, Tool
|
7 | 7 |
|
8 |
| -from semantic_kernel.connectors.mcp import ( |
9 |
| - MCPSsePlugin, |
10 |
| - MCPStdioPlugin, |
11 |
| -) |
| 8 | +from semantic_kernel.connectors.mcp import MCPSsePlugin, MCPStdioPlugin, MCPWebsocketPlugin |
12 | 9 | from semantic_kernel.exceptions import KernelPluginInvalidConfigurationError
|
13 | 10 |
|
14 | 11 | if TYPE_CHECKING:
|
@@ -116,6 +113,37 @@ async def test_with_kwargs_stdio(mock_session, mock_client, list_tool_calls, ker
|
116 | 113 | assert len(loaded_plugin.functions["func2"].parameters) == 0
|
117 | 114 |
|
118 | 115 |
|
| 116 | +@patch("semantic_kernel.connectors.mcp.websocket_client") |
| 117 | +@patch("semantic_kernel.connectors.mcp.ClientSession") |
| 118 | +async def test_with_kwargs_websocket(mock_session, mock_client, list_tool_calls, kernel: "Kernel"): |
| 119 | + mock_read = MagicMock() |
| 120 | + mock_write = MagicMock() |
| 121 | + |
| 122 | + mock_generator = MagicMock() |
| 123 | + # Make the mock_stdio_client return an AsyncMock for the context manager |
| 124 | + mock_generator.__aenter__.return_value = (mock_read, mock_write) |
| 125 | + mock_generator.__aexit__.return_value = (mock_read, mock_write) |
| 126 | + |
| 127 | + # Make the mock_stdio_client return an AsyncMock for the context manager |
| 128 | + mock_client.return_value = mock_generator |
| 129 | + mock_session.return_value.__aenter__.return_value.list_tools.return_value = list_tool_calls |
| 130 | + async with MCPWebsocketPlugin( |
| 131 | + name="TestMCPPlugin", |
| 132 | + description="Test MCP Plugin", |
| 133 | + url="http://localhost:8080/websocket", |
| 134 | + ) as plugin: |
| 135 | + mock_client.assert_called_once_with(url="http://localhost:8080/websocket") |
| 136 | + loaded_plugin = kernel.add_plugin(plugin) |
| 137 | + assert loaded_plugin is not None |
| 138 | + assert loaded_plugin.name == "TestMCPPlugin" |
| 139 | + assert loaded_plugin.description == "Test MCP Plugin" |
| 140 | + assert loaded_plugin.functions.get("func1") is not None |
| 141 | + assert loaded_plugin.functions["func1"].parameters[0].name == "name" |
| 142 | + assert loaded_plugin.functions["func1"].parameters[0].is_required |
| 143 | + assert loaded_plugin.functions.get("func2") is not None |
| 144 | + assert len(loaded_plugin.functions["func2"].parameters) == 0 |
| 145 | + |
| 146 | + |
119 | 147 | @patch("semantic_kernel.connectors.mcp.sse_client")
|
120 | 148 | @patch("semantic_kernel.connectors.mcp.ClientSession")
|
121 | 149 | async def test_with_kwargs_sse(mock_session, mock_client, list_tool_calls, kernel: "Kernel"):
|
|
0 commit comments