Skip to content

Commit 357c3da

Browse files
committed
internal/mcp: add batching support
Update the jsonrpc2 MCP framer to support both incoming and outgoing batches. In order to achieve this, we must correlate the framed Reader and Writer, which is not explicitly supported by the jsonrpc2 API, but does work. A note is left to revisit the Framer interface. Change-Id: If060cb5822da067833db20d58f5f112a0528da91 Reviewed-on: https://go-review.googlesource.com/c/tools/+/667295 Reviewed-by: Jonathan Amsterdam <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]>
1 parent cc6bc88 commit 357c3da

File tree

5 files changed

+324
-12
lines changed

5 files changed

+324
-12
lines changed

internal/jsonrpc2_v2/frame.go

+9
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,15 @@ type Writer interface {
3939
// Framer wraps low level byte readers and writers into jsonrpc2 message
4040
// readers and writers.
4141
// It is responsible for the framing and encoding of messages into wire form.
42+
//
43+
// TODO(rfindley): rethink the framer interface, as with JSONRPC2 batching
44+
// there is a need for Reader and Writer to be correlated, and while the
45+
// implementation of framing here allows that, it is not made explicit by the
46+
// interface.
47+
//
48+
// Perhaps a better interface would be
49+
//
50+
// Frame(io.ReadWriteCloser) (Reader, Writer).
4251
type Framer interface {
4352
// Reader wraps a byte reader into a message reader.
4453
Reader(io.Reader) Reader

internal/mcp/mcp.go

-1
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,4 @@
1919
// - Support multiple versions of the spec.
2020
// - Implement proper JSON schema support, with both client-side and
2121
// server-side validation..
22-
// - Support batched JSON messages.
2322
package mcp

internal/mcp/mcp_test.go

+31-1
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ func TestServerClosing(t *testing.T) {
142142
if err != nil {
143143
t.Fatal(err)
144144
}
145+
145146
var wg sync.WaitGroup
146147
wg.Add(1)
147148
go func() {
@@ -155,7 +156,36 @@ func TestServerClosing(t *testing.T) {
155156
}
156157
cc.Close()
157158
wg.Wait()
158-
if _, err = sc.CallTool(ctx, "greet", hiParams{"user"}); !errors.Is(err, mcp.ErrConnectionClosed) {
159+
if _, err := sc.CallTool(ctx, "greet", hiParams{"user"}); !errors.Is(err, mcp.ErrConnectionClosed) {
159160
t.Errorf("after disconnection, got error %v, want EOF", err)
160161
}
161162
}
163+
164+
func TestBatching(t *testing.T) {
165+
ctx := context.Background()
166+
ct, st := mcp.NewLocalTransport()
167+
168+
s := mcp.NewServer("testServer", "v1.0.0", nil)
169+
_, err := s.Connect(ctx, st, nil)
170+
if err != nil {
171+
t.Fatal(err)
172+
}
173+
174+
c := mcp.NewClient("testClient", "v1.0.0", nil)
175+
opts := new(mcp.ConnectionOptions)
176+
mcp.BatchSize(opts, 2)
177+
sc, err := c.Connect(ctx, ct, opts)
178+
if err != nil {
179+
t.Fatal(err)
180+
}
181+
defer sc.Close()
182+
183+
errs := make(chan error, 2)
184+
for range 2 {
185+
go func() {
186+
_, err := sc.ListTools(ctx)
187+
errs <- err
188+
}()
189+
}
190+
191+
}

internal/mcp/transport.go

+223-10
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"log"
1414
"net"
1515
"os"
16+
"sync"
1617

1718
jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2"
1819
)
@@ -33,7 +34,9 @@ type Transport struct {
3334
// ConnectionOptions configures the behavior of an individual client<->server
3435
// connection.
3536
type ConnectionOptions struct {
36-
Logger io.Writer
37+
Logger io.Writer // if set, write RPC logs
38+
39+
batchSize int // outgoing batch size for requests/notifications, for testing
3740
}
3841

3942
// NewStdIOTransport constructs a transport that communicates over
@@ -215,40 +218,238 @@ func (r rwc) Close() error {
215218
}
216219

217220
// A ndjsonFramer is a jsonrpc2.Framer that delimits messages with newlines.
221+
// It also supports jsonrpc2 batching.
222+
//
223+
// See https://github.com/ndjson/ndjson-spec for discussion of newline
224+
// delimited JSON.
218225
//
219-
// See also https://github.com/ndjson/ndjson-spec.
220-
type ndjsonFramer struct{}
221-
type rawReader struct{ in *json.Decoder } // relies on json.Decoder message boundaries
222-
type ndjsonWriter struct{ out io.Writer } // writes newline message boundaries
226+
// See [msgBatch] for more discussion of message batching.
227+
type ndjsonFramer struct {
228+
// batchSize allows customizing batching behavior for testing.
229+
//
230+
// If set to a positive number, requests and notifications will be buffered
231+
// into groups of this size before being sent as a batch.
232+
batchSize int
233+
234+
// batches correlate incoming requests to the batch in which they arrived.
235+
batchMu sync.Mutex
236+
batches map[jsonrpc2.ID]*msgBatch // lazily allocated
237+
}
223238

224-
func (ndjsonFramer) Reader(rw io.Reader) jsonrpc2.Reader {
225-
return &rawReader{in: json.NewDecoder(rw)}
239+
// addBatch records a msgBatch for an incoming batch payload.
240+
// It returns an error if batch is malformed, containing previously seen IDs.
241+
//
242+
// See [msgBatch] for more.
243+
func (f *ndjsonFramer) addBatch(batch *msgBatch) error {
244+
f.batchMu.Lock()
245+
defer f.batchMu.Unlock()
246+
for id := range batch.unresolved {
247+
if _, ok := f.batches[id]; ok {
248+
return fmt.Errorf("%w: batch contains previously seen request %v", jsonrpc2.ErrInvalidRequest, id.Raw())
249+
}
250+
}
251+
for id := range batch.unresolved {
252+
if f.batches == nil {
253+
f.batches = make(map[jsonrpc2.ID]*msgBatch)
254+
}
255+
f.batches[id] = batch
256+
}
257+
return nil
226258
}
227259

228-
func (ndjsonFramer) Writer(rw io.Writer) jsonrpc2.Writer {
229-
return &ndjsonWriter{out: rw}
260+
// updateBatch records a response in the message batch tracking the
261+
// corresponding incoming call, if any.
262+
//
263+
// The second result reports whether resp was part of a batch. If this is true,
264+
// the first result is nil if the batch is still incomplete, or the full set of
265+
// batch responses if resp completed the batch.
266+
func (f *ndjsonFramer) updateBatch(resp *jsonrpc2.Response) ([]*jsonrpc2.Response, bool) {
267+
f.batchMu.Lock()
268+
defer f.batchMu.Unlock()
269+
270+
if batch, ok := f.batches[resp.ID]; ok {
271+
idx, ok := batch.unresolved[resp.ID]
272+
if !ok {
273+
panic("internal error: inconsistent batches")
274+
}
275+
batch.responses[idx] = resp
276+
delete(batch.unresolved, resp.ID)
277+
delete(f.batches, resp.ID)
278+
if len(batch.unresolved) == 0 {
279+
return batch.responses, true
280+
}
281+
return nil, true
282+
}
283+
return nil, false
230284
}
231285

232-
func (r *rawReader) Read(ctx context.Context) (jsonrpc2.Message, int64, error) {
286+
// A msgBatch records information about an incoming batch of JSONRPC2 calls.
287+
//
288+
// The JSONRPC2 spec (https://www.jsonrpc.org/specification#batch) says:
289+
//
290+
// "The Server should respond with an Array containing the corresponding
291+
// Response objects, after all of the batch Request objects have been
292+
// processed. A Response object SHOULD exist for each Request object, except
293+
// that there SHOULD NOT be any Response objects for notifications. The Server
294+
// MAY process a batch rpc call as a set of concurrent tasks, processing them
295+
// in any order and with any width of parallelism."
296+
//
297+
// Therefore, a msgBatch keeps track of outstanding calls and their responses.
298+
// When there are no unresolved calls, the response payload is sent.
299+
type msgBatch struct {
300+
unresolved map[jsonrpc2.ID]int
301+
responses []*jsonrpc2.Response
302+
}
303+
304+
// An ndjsonReader reads newline-delimited messages or message batches.
305+
type ndjsonReader struct {
306+
queue []jsonrpc2.Message
307+
framer *ndjsonFramer
308+
in *json.Decoder
309+
}
310+
311+
// A ndjsonWriter writes newline-delimited messages to the wrapped io.Writer.
312+
//
313+
// If batch is set, messages are wrapped in a JSONRPC2 batch.
314+
type ndjsonWriter struct {
315+
// Testing support: if outgoingBatch has capacity, it is used to buffer
316+
// outgoing messages before sending a JSONRPC2 message batch.
317+
outgoingBatch []jsonrpc2.Message
318+
319+
framer *ndjsonFramer // to track batch responses
320+
out io.Writer // to write to the wire
321+
}
322+
323+
func (f *ndjsonFramer) Reader(r io.Reader) jsonrpc2.Reader {
324+
return &ndjsonReader{framer: f, in: json.NewDecoder(r)}
325+
}
326+
327+
func (f *ndjsonFramer) Writer(w io.Writer) jsonrpc2.Writer {
328+
writer := &ndjsonWriter{framer: f, out: w}
329+
if f.batchSize > 0 {
330+
writer.outgoingBatch = make([]jsonrpc2.Message, 0, f.batchSize)
331+
}
332+
return writer
333+
}
334+
335+
func (r *ndjsonReader) Read(ctx context.Context) (jsonrpc2.Message, int64, error) {
233336
select {
234337
case <-ctx.Done():
235338
return nil, 0, ctx.Err()
236339
default:
237340
}
341+
if len(r.queue) > 0 {
342+
next := r.queue[0]
343+
r.queue = r.queue[1:]
344+
return next, 0, nil
345+
}
238346
var raw json.RawMessage
239347
if err := r.in.Decode(&raw); err != nil {
240348
return nil, 0, err
241349
}
350+
var rawBatch []json.RawMessage
351+
if err := json.Unmarshal(raw, &rawBatch); err == nil {
352+
msg, err := r.readBatch(rawBatch)
353+
if err != nil {
354+
return nil, 0, err
355+
}
356+
return msg, int64(len(raw)), nil
357+
}
242358
msg, err := jsonrpc2.DecodeMessage(raw)
243359
return msg, int64(len(raw)), err
244360
}
245361

362+
// readBatch reads a batch of jsonrpc2 messages, and records the batch
363+
// in the framer so that responses can be collected and send back together.
364+
func (r *ndjsonReader) readBatch(rawBatch []json.RawMessage) (jsonrpc2.Message, error) {
365+
if len(rawBatch) == 0 {
366+
return nil, fmt.Errorf("empty batch")
367+
}
368+
369+
// From the spec:
370+
// "If the batch rpc call itself fails to be recognized as an valid JSON or
371+
// as an Array with at least one value, the response from the Server MUST be
372+
// a single Response object. If there are no Response objects contained
373+
// within the Response array as it is to be sent to the client, the server
374+
// MUST NOT return an empty Array and should return nothing at all."
375+
//
376+
// In our case, an error actually breaks the jsonrpc2 connection entirely,
377+
// but defensively we collect batch information before recording it, so that
378+
// we don't leave the framer in an inconsistent state.
379+
var (
380+
first jsonrpc2.Message // first message, to return
381+
queue []jsonrpc2.Message // remaining messages
382+
respBatch *msgBatch // tracks incoming requests in the batch
383+
)
384+
for i, raw := range rawBatch {
385+
msg, err := jsonrpc2.DecodeMessage(raw)
386+
if err != nil {
387+
return nil, err
388+
}
389+
if i == 0 {
390+
first = msg
391+
} else {
392+
queue = append(queue, msg)
393+
}
394+
if req, ok := msg.(*jsonrpc2.Request); ok {
395+
if respBatch == nil {
396+
respBatch = &msgBatch{
397+
unresolved: make(map[jsonrpc2.ID]int),
398+
}
399+
}
400+
respBatch.unresolved[req.ID] = len(respBatch.responses)
401+
respBatch.responses = append(respBatch.responses, nil)
402+
}
403+
}
404+
if respBatch != nil {
405+
// The batch contains one or more incoming requests to track.
406+
if err := r.framer.addBatch(respBatch); err != nil {
407+
return nil, err
408+
}
409+
}
410+
411+
r.queue = append(r.queue, queue...)
412+
return first, nil
413+
}
414+
246415
func (w *ndjsonWriter) Write(ctx context.Context, msg jsonrpc2.Message) (int64, error) {
247416
select {
248417
case <-ctx.Done():
249418
return 0, ctx.Err()
250419
default:
251420
}
421+
422+
// Batching support: if msg is a Response, it may have completed a batch, so
423+
// check that first. Otherwise, it is a request or notification, and we may
424+
// want to collect it into a batch before sending, if we're configured to use
425+
// outgoing batches.
426+
if resp, ok := msg.(*jsonrpc2.Response); ok {
427+
if batch, ok := w.framer.updateBatch(resp); ok {
428+
if len(batch) > 0 {
429+
data, err := marshalMessages(batch)
430+
if err != nil {
431+
return 0, err
432+
}
433+
data = append(data, '\n')
434+
n, err := w.out.Write(data)
435+
return int64(n), err
436+
}
437+
return 0, nil
438+
}
439+
} else if len(w.outgoingBatch) < cap(w.outgoingBatch) {
440+
w.outgoingBatch = append(w.outgoingBatch, msg)
441+
if len(w.outgoingBatch) == cap(w.outgoingBatch) {
442+
data, err := marshalMessages(w.outgoingBatch)
443+
w.outgoingBatch = w.outgoingBatch[:0]
444+
if err != nil {
445+
return 0, err
446+
}
447+
data = append(data, '\n')
448+
n, err := w.out.Write(data)
449+
return int64(n), err
450+
}
451+
return 0, nil
452+
}
252453
data, err := jsonrpc2.EncodeMessage(msg)
253454
if err != nil {
254455
return 0, fmt.Errorf("marshaling message: %v", err)
@@ -257,3 +458,15 @@ func (w *ndjsonWriter) Write(ctx context.Context, msg jsonrpc2.Message) (int64,
257458
n, err := w.out.Write(data)
258459
return int64(n), err
259460
}
461+
462+
func marshalMessages[T jsonrpc2.Message](msgs []T) ([]byte, error) {
463+
var rawMsgs []json.RawMessage
464+
for _, msg := range msgs {
465+
raw, err := jsonrpc2.EncodeMessage(msg)
466+
if err != nil {
467+
return nil, fmt.Errorf("encoding batch message: %w", err)
468+
}
469+
rawMsgs = append(rawMsgs, raw)
470+
}
471+
return json.Marshal(rawMsgs)
472+
}

0 commit comments

Comments
 (0)