|
| 1 | +// Copyright 2025 The Go Authors. All rights reserved. |
| 2 | +// Use of this source code is governed by a BSD-style |
| 3 | +// license that can be found in the LICENSE file. |
| 4 | + |
| 5 | +package mcp |
| 6 | + |
| 7 | +import ( |
| 8 | + "context" |
| 9 | + "encoding/json" |
| 10 | + "errors" |
| 11 | + "fmt" |
| 12 | + "iter" |
| 13 | + "slices" |
| 14 | + "sync" |
| 15 | + |
| 16 | + jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" |
| 17 | + "golang.org/x/tools/internal/mcp/internal/protocol" |
| 18 | +) |
| 19 | + |
| 20 | +// A Client is an MCP client, which may be connected to one or more MCP servers |
| 21 | +// using the [Client.Connect] method. |
| 22 | +// |
| 23 | +// TODO(rfindley): revisit the many-to-one relationship of clients and servers. |
| 24 | +// It is a bit odd. |
| 25 | +type Client struct { |
| 26 | + name string |
| 27 | + version string |
| 28 | + |
| 29 | + mu sync.Mutex |
| 30 | + servers []*ServerConnection |
| 31 | +} |
| 32 | + |
| 33 | +// NewClient creates a new Client. |
| 34 | +// |
| 35 | +// Use [Client.Connect] to connect it to an MCP server. |
| 36 | +// |
| 37 | +// If non-nil, the provided options configure the Client. |
| 38 | +func NewClient(name, version string, opts *ClientOptions) *Client { |
| 39 | + return &Client{ |
| 40 | + name: name, |
| 41 | + version: version, |
| 42 | + } |
| 43 | +} |
| 44 | + |
| 45 | +// Servers returns an iterator that yields the current set of server |
| 46 | +// connections. |
| 47 | +func (c *Client) Servers() iter.Seq[*ServerConnection] { |
| 48 | + c.mu.Lock() |
| 49 | + clients := slices.Clone(c.servers) |
| 50 | + c.mu.Unlock() |
| 51 | + return slices.Values(clients) |
| 52 | +} |
| 53 | + |
| 54 | +// ClientOptions configures the behavior of the client, and apply to every |
| 55 | +// client-server connection created using [Client.Connect]. |
| 56 | +type ClientOptions struct{} |
| 57 | + |
| 58 | +// bind implements the binder[*ServerConnection] interface, so that Clients can |
| 59 | +// be connected using [connect]. |
| 60 | +func (c *Client) bind(conn *jsonrpc2.Connection) *ServerConnection { |
| 61 | + sc := &ServerConnection{ |
| 62 | + conn: conn, |
| 63 | + client: c, |
| 64 | + } |
| 65 | + c.mu.Lock() |
| 66 | + c.servers = append(c.servers, sc) |
| 67 | + c.mu.Unlock() |
| 68 | + return sc |
| 69 | +} |
| 70 | + |
| 71 | +// disconnect implements the binder[*ServerConnection] interface, so that |
| 72 | +// Clients can be connected using [connect]. |
| 73 | +func (c *Client) disconnect(sc *ServerConnection) { |
| 74 | + c.mu.Lock() |
| 75 | + defer c.mu.Unlock() |
| 76 | + c.servers = slices.DeleteFunc(c.servers, func(sc2 *ServerConnection) bool { |
| 77 | + return sc2 == sc |
| 78 | + }) |
| 79 | +} |
| 80 | + |
| 81 | +// Connect connects the MCP client over the given transport and initializes an |
| 82 | +// MCP session. |
| 83 | +// |
| 84 | +// It returns a connection object that may be used to query the MCP server, |
| 85 | +// terminate the connection (with [Connection.Close]), or await server |
| 86 | +// termination (with [Connection.Wait]). |
| 87 | +// |
| 88 | +// Typically, it is the responsibility of the client to close the connection |
| 89 | +// when it is no longer needed. However, if the connection is closed by the |
| 90 | +// server, calls or notifications will return an error wrapping |
| 91 | +// [ErrConnectionClosed]. |
| 92 | +func (c *Client) Connect(ctx context.Context, t *Transport, opts *ConnectionOptions) (sc *ServerConnection, err error) { |
| 93 | + defer func() { |
| 94 | + if sc != nil && err != nil { |
| 95 | + _ = sc.Close() |
| 96 | + } |
| 97 | + }() |
| 98 | + sc, err = connect(ctx, t, opts, c) |
| 99 | + if err != nil { |
| 100 | + return nil, err |
| 101 | + } |
| 102 | + params := &protocol.InitializeParams{ |
| 103 | + ClientInfo: protocol.Implementation{Name: c.name, Version: c.version}, |
| 104 | + } |
| 105 | + if err := call(ctx, sc.conn, "initialize", params, &sc.initializeResult); err != nil { |
| 106 | + return nil, err |
| 107 | + } |
| 108 | + if err := sc.conn.Notify(ctx, "initialized", &protocol.InitializedParams{}); err != nil { |
| 109 | + return nil, err |
| 110 | + } |
| 111 | + return sc, nil |
| 112 | +} |
| 113 | + |
| 114 | +// A ServerConnection is a connection with an MCP server. |
| 115 | +// |
| 116 | +// It handles messages from the client, and can be used to send messages to the |
| 117 | +// client. Create a connection by calling [Server.Connect]. |
| 118 | +type ServerConnection struct { |
| 119 | + conn *jsonrpc2.Connection |
| 120 | + client *Client |
| 121 | + initializeResult *protocol.InitializeResult |
| 122 | +} |
| 123 | + |
| 124 | +// Close performs a graceful close of the connection, preventing new requests |
| 125 | +// from being handled, and waiting for ongoing requests to return. Close then |
| 126 | +// terminates the connection. |
| 127 | +func (cc *ServerConnection) Close() error { |
| 128 | + return cc.conn.Close() |
| 129 | +} |
| 130 | + |
| 131 | +// Wait waits for the connection to be closed by the server. |
| 132 | +// Generally, clients should be responsible for closing the connection. |
| 133 | +func (cc *ServerConnection) Wait() error { |
| 134 | + return cc.conn.Wait() |
| 135 | +} |
| 136 | + |
| 137 | +func (sc *ServerConnection) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) { |
| 138 | + switch req.Method { |
| 139 | + } |
| 140 | + return nil, jsonrpc2.ErrNotHandled |
| 141 | +} |
| 142 | + |
| 143 | +// ListTools lists tools that are currently available on the server. |
| 144 | +func (sc *ServerConnection) ListTools(ctx context.Context) ([]protocol.Tool, error) { |
| 145 | + var ( |
| 146 | + params = &protocol.ListToolsParams{} |
| 147 | + result protocol.ListToolsResult |
| 148 | + ) |
| 149 | + if err := call(ctx, sc.conn, "tools/list", params, &result); err != nil { |
| 150 | + return nil, err |
| 151 | + } |
| 152 | + return result.Tools, nil |
| 153 | +} |
| 154 | + |
| 155 | +// CallTool calls the tool with the given name and arguments. |
| 156 | +// |
| 157 | +// TODO: make the following true: |
| 158 | +// If the provided arguments do not conform to the schema for the given tool, |
| 159 | +// the call fails. |
| 160 | +func (sc *ServerConnection) CallTool(ctx context.Context, name string, args any) (_ []Content, err error) { |
| 161 | + defer func() { |
| 162 | + if err != nil { |
| 163 | + err = fmt.Errorf("calling tool %q: %w", name, err) |
| 164 | + } |
| 165 | + }() |
| 166 | + argJSON, err := json.Marshal(args) |
| 167 | + if err != nil { |
| 168 | + return nil, fmt.Errorf("marshaling args: %v", err) |
| 169 | + } |
| 170 | + var ( |
| 171 | + params = &protocol.CallToolParams{ |
| 172 | + Name: name, |
| 173 | + Arguments: argJSON, |
| 174 | + } |
| 175 | + result protocol.CallToolResult |
| 176 | + ) |
| 177 | + if err := call(ctx, sc.conn, "tools/call", params, &result); err != nil { |
| 178 | + return nil, err |
| 179 | + } |
| 180 | + content, err := unmarshalContent(result.Content) |
| 181 | + if err != nil { |
| 182 | + return nil, fmt.Errorf("unmarshaling tool content: %v", err) |
| 183 | + } |
| 184 | + if result.IsError { |
| 185 | + if len(content) != 1 || !is[TextContent](content[0]) { |
| 186 | + return nil, errors.New("malformed error content") |
| 187 | + } |
| 188 | + return nil, errors.New(content[0].(TextContent).Text) |
| 189 | + } |
| 190 | + return content, nil |
| 191 | +} |
0 commit comments