Skip to content

Commit 3f5fc1a

Browse files
toby200rstoyanchev
authored andcommitted
Support keepAlive in WebSocket client
See gh-608
1 parent d31b94e commit 3f5fc1a

File tree

6 files changed

+121
-10
lines changed

6 files changed

+121
-10
lines changed

spring-graphql-docs/modules/ROOT/pages/client.adoc

+7
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,13 @@ existing `WebSocketGraphQlClient` to create a new instance with customized setti
165165
166166
----
167167

168+
If you'd like the client to send regular graphql ping messages to the server, you can add these by adding `keepalive(long seconds)` to the builder
169+
[source,java,indent=0,subs="verbatim,quotes"]
170+
----
171+
WebSocketGraphQlClient graphQlClient = WebSocketGraphQlClient.builder(url, client)
172+
.keepalive(30)
173+
.build();
174+
----
168175

169176
[[client.websocketgraphqlclient.interceptor]]
170177
==== Interceptor

spring-graphql/src/main/java/org/springframework/graphql/client/DefaultWebSocketGraphQlClientBuilder.java

+24-1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ final class DefaultWebSocketGraphQlClientBuilder
4949

5050
private final CodecConfigurer codecConfigurer;
5151

52+
private long keepalive;
5253

5354
/**
5455
* Constructor to start via {@link WebSocketGraphQlClient#builder(String, WebSocketClient)}.
@@ -57,13 +58,28 @@ final class DefaultWebSocketGraphQlClientBuilder
5758
this(toURI(url), client);
5859
}
5960

61+
/**
62+
* Constructor to start via {@link WebSocketGraphQlClient#builder(String, WebSocketClient, long)}.
63+
*/
64+
DefaultWebSocketGraphQlClientBuilder(String url, WebSocketClient client, long keepalive) {
65+
this(toURI(url), client, keepalive);
66+
}
67+
6068
/**
6169
* Constructor to start via {@link WebSocketGraphQlClient#builder(URI, WebSocketClient)}.
6270
*/
6371
DefaultWebSocketGraphQlClientBuilder(URI url, WebSocketClient client) {
72+
this(url, client, 0);
73+
}
74+
75+
/**
76+
* Constructor to start via {@link WebSocketGraphQlClient#builder(URI, WebSocketClient, long)}.
77+
*/
78+
DefaultWebSocketGraphQlClientBuilder(URI url, WebSocketClient client, long keepalive) {
6479
this.url = url;
6580
this.webSocketClient = client;
6681
this.codecConfigurer = ClientCodecConfigurer.create();
82+
this.keepalive = keepalive;
6783
}
6884

6985
/**
@@ -75,6 +91,7 @@ final class DefaultWebSocketGraphQlClientBuilder
7591
this.headers.putAll(transport.getHeaders());
7692
this.webSocketClient = transport.getWebSocketClient();
7793
this.codecConfigurer = transport.getCodecConfigurer();
94+
this.keepalive = transport.getKeepAlive();
7895
}
7996

8097

@@ -119,12 +136,18 @@ public WebSocketGraphQlClient build() {
119136
CodecDelegate.findJsonDecoder(this.codecConfigurer));
120137

121138
WebSocketGraphQlTransport transport = new WebSocketGraphQlTransport(
122-
this.url, this.headers, this.webSocketClient, this.codecConfigurer, getInterceptor());
139+
this.url, this.headers, this.webSocketClient, this.codecConfigurer, getInterceptor(), this.keepalive);
123140

124141
GraphQlClient graphQlClient = super.buildGraphQlClient(transport);
125142
return new DefaultWebSocketGraphQlClient(graphQlClient, transport, getBuilderInitializer());
126143
}
127144

145+
@Override
146+
public WebSocketGraphQlClient.Builder<DefaultWebSocketGraphQlClientBuilder> keepalive(long keepalive) {
147+
this.keepalive = keepalive;
148+
return this;
149+
}
150+
128151
private WebSocketGraphQlClientInterceptor getInterceptor() {
129152

130153
List<WebSocketGraphQlClientInterceptor> interceptors = getInterceptors().stream()

spring-graphql/src/main/java/org/springframework/graphql/client/WebSocketGraphQlClient.java

+32
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,16 @@ static WebSocketGraphQlClient create(URI url, WebSocketClient webSocketClient) {
6464
return builder(url, webSocketClient).build();
6565
}
6666

67+
/**
68+
* Create a {@link WebSocketGraphQlClient}.
69+
* @param url the GraphQL endpoint URL
70+
* @param webSocketClient the underlying transport client to use
71+
* @param keepalive the delay in seconds between sending ping messages, or 0 to disable
72+
*/
73+
static WebSocketGraphQlClient create(URI url, WebSocketClient webSocketClient, long keepalive) {
74+
return builder(url, webSocketClient).keepalive(keepalive).build();
75+
}
76+
6777
/**
6878
* Return a builder for a {@link WebSocketGraphQlClient}.
6979
* @param url the GraphQL endpoint URL
@@ -73,6 +83,16 @@ static Builder<?> builder(String url, WebSocketClient webSocketClient) {
7383
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient);
7484
}
7585

86+
/**
87+
* Return a builder for a {@link WebSocketGraphQlClient}.
88+
* @param url the GraphQL endpoint URL
89+
* @param webSocketClient the underlying transport client to use
90+
* @param keepalive the delay in seconds between sending ping messages, or 0 to disable
91+
*/
92+
static Builder<?> builder(String url, WebSocketClient webSocketClient, long keepalive) {
93+
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient, keepalive);
94+
}
95+
7696
/**
7797
* Return a builder for a {@link WebSocketGraphQlClient}.
7898
* @param url the GraphQL endpoint URL
@@ -82,6 +102,16 @@ static Builder<?> builder(URI url, WebSocketClient webSocketClient) {
82102
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient);
83103
}
84104

105+
/**
106+
* Return a builder for a {@link WebSocketGraphQlClient}.
107+
* @param url the GraphQL endpoint URL
108+
* @param webSocketClient the underlying transport client to use
109+
* @param keepalive the delay in seconds between sending ping messages, or 0 to disable
110+
*/
111+
static Builder<?> builder(URI url, WebSocketClient webSocketClient, long keepalive) {
112+
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient, keepalive);
113+
}
114+
85115

86116
/**
87117
* Builder for a GraphQL over WebSocket client.
@@ -95,6 +125,8 @@ interface Builder<B extends Builder<B>> extends WebGraphQlClient.Builder<B> {
95125
@Override
96126
WebSocketGraphQlClient build();
97127

128+
Builder<B> keepalive(long keepalive);
129+
98130
}
99131

100132
}

spring-graphql/src/main/java/org/springframework/graphql/client/WebSocketGraphQlTransport.java

+36-7
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.graphql.client;
1818

1919
import java.net.URI;
20+
import java.time.Duration;
2021
import java.util.Collections;
2122
import java.util.List;
2223
import java.util.Map;
@@ -67,10 +68,12 @@ final class WebSocketGraphQlTransport implements GraphQlTransport {
6768

6869
private final Mono<GraphQlSession> graphQlSessionMono;
6970

71+
private final long keepalive;
72+
7073

7174
WebSocketGraphQlTransport(
7275
URI url, @Nullable HttpHeaders headers, WebSocketClient client, CodecConfigurer codecConfigurer,
73-
WebSocketGraphQlClientInterceptor interceptor) {
76+
WebSocketGraphQlClientInterceptor interceptor, long keepalive) {
7477

7578
Assert.notNull(url, "URI is required");
7679
Assert.notNull(client, "WebSocketClient is required");
@@ -80,8 +83,9 @@ final class WebSocketGraphQlTransport implements GraphQlTransport {
8083
this.url = url;
8184
this.headers.putAll((headers != null) ? headers : HttpHeaders.EMPTY);
8285
this.webSocketClient = client;
86+
this.keepalive = keepalive;
8387

84-
this.graphQlSessionHandler = new GraphQlSessionHandler(codecConfigurer, interceptor);
88+
this.graphQlSessionHandler = new GraphQlSessionHandler(codecConfigurer, interceptor, keepalive);
8589

8690
this.graphQlSessionMono = initGraphQlSession(this.url, this.headers, client, this.graphQlSessionHandler)
8791
.cacheInvalidateWhen(GraphQlSession::notifyWhenClosed);
@@ -162,6 +166,10 @@ public Flux<GraphQlResponse> executeSubscription(GraphQlRequest request) {
162166
return this.graphQlSessionMono.flatMapMany((session) -> session.executeSubscription(request));
163167
}
164168

169+
public long getKeepAlive() {
170+
return keepalive;
171+
}
172+
165173

166174
/**
167175
* Client {@code WebSocketHandler} for GraphQL that deals with WebSocket
@@ -183,11 +191,15 @@ private static class GraphQlSessionHandler implements WebSocketHandler {
183191

184192
private final AtomicBoolean stopped = new AtomicBoolean();
185193

194+
private final long keepalive;
186195

187-
GraphQlSessionHandler(CodecConfigurer codecConfigurer, WebSocketGraphQlClientInterceptor interceptor) {
196+
197+
GraphQlSessionHandler(CodecConfigurer codecConfigurer, WebSocketGraphQlClientInterceptor interceptor,
198+
long keepalive) {
188199
this.codecDelegate = new CodecDelegate(codecConfigurer);
189200
this.interceptor = interceptor;
190201
this.graphQlSessionSink = Sinks.unsafe().one();
202+
this.keepalive = keepalive;
191203
}
192204

193205

@@ -245,7 +257,7 @@ public Mono<Void> handle(WebSocketSession session) {
245257
session.send(connectionInitMono.concatWith(graphQlSession.getRequestFlux())
246258
.map((message) -> this.codecDelegate.encode(session, message)));
247259

248-
Mono<Void> receiveCompletion = session.receive()
260+
Flux<Void> receiveCompletion = session.receive()
249261
.flatMap((webSocketMessage) -> {
250262
if (sessionNotInitialized()) {
251263
try {
@@ -276,6 +288,7 @@ public Mono<Void> handle(WebSocketSession session) {
276288
switch (message.resolvedType()) {
277289
case NEXT -> graphQlSession.handleNext(message);
278290
case PING -> graphQlSession.sendPong(null);
291+
case PONG -> { }
279292
case ERROR -> graphQlSession.handleError(message);
280293
case COMPLETE -> graphQlSession.handleComplete(message);
281294
default -> throw new IllegalStateException(
@@ -290,10 +303,21 @@ public Mono<Void> handle(WebSocketSession session) {
290303
}
291304
}
292305
return Mono.empty();
293-
})
294-
.then();
306+
});
307+
308+
if (keepalive > 0) {
309+
Duration keepAliveDuration = Duration.ofSeconds(keepalive);
310+
receiveCompletion = receiveCompletion
311+
.mergeWith(Flux.interval(keepAliveDuration, keepAliveDuration)
312+
.flatMap(i -> {
313+
graphQlSession.sendPing(null);
314+
return Mono.empty();
315+
})
316+
);
317+
}
318+
295319

296-
return Mono.zip(sendCompletion, receiveCompletion).then();
320+
return Mono.zip(sendCompletion, receiveCompletion.then()).then();
297321
}
298322

299323
private boolean sessionNotInitialized() {
@@ -459,6 +483,11 @@ void sendPong(@Nullable Map<String, Object> payload) {
459483
this.requestSink.sendRequest(message);
460484
}
461485

486+
public void sendPing(@Nullable Map<String, Object> payload) {
487+
GraphQlWebSocketMessage message = GraphQlWebSocketMessage.ping(payload);
488+
this.requestSink.sendRequest(message);
489+
}
490+
462491

463492
// Inbound messages
464493

spring-graphql/src/test/java/org/springframework/graphql/client/MockGraphQlWebSocketServer.java

+3
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ private Publisher<GraphQlWebSocketMessage> handleMessage(GraphQlWebSocketMessage
110110
GraphQlWebSocketMessage.error(id, Collections.singletonList(request.getError())) :
111111
GraphQlWebSocketMessage.complete(id));
112112
}
113+
case PING -> {
114+
return Mono.just(GraphQlWebSocketMessage.pong(null));
115+
}
113116
case COMPLETE -> {
114117
return Flux.empty();
115118
}

spring-graphql/src/test/java/org/springframework/graphql/client/WebSocketGraphQlTransportTests.java

+19-2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ public class WebSocketGraphQlTransportTests {
6060
private static final Duration TIMEOUT = Duration.ofSeconds(5);
6161

6262
private static final CodecDelegate CODEC_DELEGATE = new CodecDelegate(ClientCodecConfigurer.create());
63+
public static final int KEEPALIVE = 1;
6364

6465

6566
private final MockGraphQlWebSocketServer mockServer = new MockGraphQlWebSocketServer();
@@ -185,6 +186,22 @@ void pingHandling() {
185186
GraphQlWebSocketMessage.subscribe("1", new DefaultGraphQlRequest("{Query1}")));
186187
}
187188

189+
@Test
190+
void pingSending() throws InterruptedException {
191+
192+
GraphQlRequest request = this.mockServer.expectOperation("{Sub1}").andStream(Flux.just(this.response1, response2));
193+
194+
StepVerifier.create(this.transport.executeSubscription(request))
195+
.expectNext(this.response1, response2).expectComplete()
196+
.verify(TIMEOUT);
197+
Thread.sleep(KEEPALIVE*1000 + 50); // wait for ping
198+
199+
assertActualClientMessages(
200+
GraphQlWebSocketMessage.connectionInit(null),
201+
GraphQlWebSocketMessage.subscribe("1", request),
202+
GraphQlWebSocketMessage.ping(null));
203+
}
204+
188205
@Test
189206
void start() {
190207
MockGraphQlWebSocketServer handler = new MockGraphQlWebSocketServer();
@@ -210,7 +227,7 @@ public Mono<Void> handleConnectionAck(Map<String, Object> ackPayload) {
210227

211228

212229
WebSocketGraphQlTransport transport = new WebSocketGraphQlTransport(
213-
URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(), interceptor);
230+
URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(), interceptor, KEEPALIVE);
214231

215232
transport.start().block(TIMEOUT);
216233

@@ -324,7 +341,7 @@ void errorDuringResponseHandling() {
324341
private static WebSocketGraphQlTransport createTransport(WebSocketClient client) {
325342
return new WebSocketGraphQlTransport(
326343
URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(),
327-
new WebSocketGraphQlClientInterceptor() { });
344+
new WebSocketGraphQlClientInterceptor() { }, KEEPALIVE);
328345
}
329346

330347
private void assertActualClientMessages(GraphQlWebSocketMessage... expectedMessages) {

0 commit comments

Comments
 (0)