Skip to content

Commit 80ef960

Browse files
committed
WebSocket handlers support keepalive PING messages
Closes gh-534
1 parent c8573cb commit 80ef960

File tree

5 files changed

+178
-28
lines changed

5 files changed

+178
-28
lines changed

spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandler.java

+39-12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -44,6 +44,7 @@
4444
import org.springframework.graphql.server.support.GraphQlWebSocketMessage;
4545
import org.springframework.http.HttpHeaders;
4646
import org.springframework.http.codec.CodecConfigurer;
47+
import org.springframework.lang.Nullable;
4748
import org.springframework.util.Assert;
4849
import org.springframework.util.CollectionUtils;
4950
import org.springframework.web.reactive.socket.CloseStatus;
@@ -72,10 +73,13 @@ public class GraphQlWebSocketHandler implements WebSocketHandler {
7273

7374
private final WebSocketGraphQlInterceptor webSocketInterceptor;
7475

75-
private final WebSocketCodecDelegate webSocketCodecDelegate;
76+
private final WebSocketCodecDelegate codecDelegate;
7677

7778
private final Duration initTimeoutDuration;
7879

80+
@Nullable
81+
private final Duration keepAliveDuration;
82+
7983

8084
/**
8185
* Create a new instance.
@@ -87,12 +91,30 @@ public class GraphQlWebSocketHandler implements WebSocketHandler {
8791
public GraphQlWebSocketHandler(
8892
WebGraphQlHandler graphQlHandler, CodecConfigurer codecConfigurer, Duration connectionInitTimeout) {
8993

94+
this(graphQlHandler, codecConfigurer, connectionInitTimeout, null);
95+
}
96+
97+
/**
98+
* Create a new instance.
99+
* @param graphQlHandler common handler for GraphQL over WebSocket requests
100+
* @param codecConfigurer codec configurer for JSON encoding and decoding
101+
* @param connectionInitTimeout how long to wait after the establishment of
102+
* the WebSocket for the {@code "connection_ini"} message from the client.
103+
* @param keepAliveDuration how frequently to send ping messages; if not
104+
* set then ping messages are not sent.
105+
* @since 1.3
106+
*/
107+
public GraphQlWebSocketHandler(
108+
WebGraphQlHandler graphQlHandler, CodecConfigurer codecConfigurer,
109+
Duration connectionInitTimeout, @Nullable Duration keepAliveDuration) {
110+
90111
Assert.notNull(graphQlHandler, "WebGraphQlHandler is required");
91112

92113
this.graphQlHandler = graphQlHandler;
93114
this.webSocketInterceptor = this.graphQlHandler.getWebSocketInterceptor();
94-
this.webSocketCodecDelegate = new WebSocketCodecDelegate(codecConfigurer);
115+
this.codecDelegate = new WebSocketCodecDelegate(codecConfigurer);
95116
this.initTimeoutDuration = connectionInitTimeout;
117+
this.keepAliveDuration = keepAliveDuration;
96118
}
97119

98120

@@ -137,7 +159,7 @@ public Mono<Void> handle(WebSocketSession session) {
137159
.subscribe();
138160

139161
return session.send(session.receive().flatMap((webSocketMessage) -> {
140-
GraphQlWebSocketMessage message = this.webSocketCodecDelegate.decode(webSocketMessage);
162+
GraphQlWebSocketMessage message = this.codecDelegate.decode(webSocketMessage);
141163
String id = message.getId();
142164
Map<String, Object> payload = message.getPayload();
143165
switch (message.resolvedType()) {
@@ -159,7 +181,7 @@ public Mono<Void> handle(WebSocketSession session) {
159181
.doOnTerminate(() -> subscriptions.remove(id));
160182
}
161183
case PING -> {
162-
return Flux.just(this.webSocketCodecDelegate.encode(session, GraphQlWebSocketMessage.pong(null)));
184+
return Flux.just(this.codecDelegate.encode(session, GraphQlWebSocketMessage.pong(null)));
163185
}
164186
case COMPLETE -> {
165187
if (id != null) {
@@ -176,11 +198,16 @@ public Mono<Void> handle(WebSocketSession session) {
176198
if (!connectionInitPayloadRef.compareAndSet(null, payload)) {
177199
return GraphQlStatus.close(session, GraphQlStatus.TOO_MANY_INIT_REQUESTS_STATUS);
178200
}
179-
return this.webSocketInterceptor.handleConnectionInitialization(sessionInfo, payload)
201+
Flux<WebSocketMessage> flux = this.webSocketInterceptor.handleConnectionInitialization(sessionInfo, payload)
180202
.defaultIfEmpty(Collections.emptyMap())
181-
.map((ackPayload) -> this.webSocketCodecDelegate.encodeConnectionAck(session, ackPayload))
182-
.flux()
183-
.onErrorResume((ex) -> GraphQlStatus.close(session, GraphQlStatus.UNAUTHORIZED_STATUS));
203+
.map((ackPayload) -> this.codecDelegate.encodeConnectionAck(session, ackPayload))
204+
.flux();
205+
if (this.keepAliveDuration != null) {
206+
flux = flux.mergeWith(Flux.interval(this.keepAliveDuration, this.keepAliveDuration)
207+
.filter((aLong) -> !this.codecDelegate.checkMessagesEncodedAndClear())
208+
.map((aLong) -> this.codecDelegate.encode(session, GraphQlWebSocketMessage.ping(null))));
209+
}
210+
return flux.onErrorResume((ex) -> GraphQlStatus.close(session, GraphQlStatus.UNAUTHORIZED_STATUS));
184211
}
185212
default -> {
186213
return GraphQlStatus.close(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
@@ -218,14 +245,14 @@ private Flux<WebSocketMessage> handleResponse(WebSocketSession session, String i
218245
}
219246

220247
return responseFlux
221-
.map((responseMap) -> this.webSocketCodecDelegate.encodeNext(session, id, responseMap))
222-
.concatWith(Mono.fromCallable(() -> this.webSocketCodecDelegate.encodeComplete(session, id)))
248+
.map((responseMap) -> this.codecDelegate.encodeNext(session, id, responseMap))
249+
.concatWith(Mono.fromCallable(() -> this.codecDelegate.encodeComplete(session, id)))
223250
.onErrorResume((ex) -> {
224251
if (ex instanceof SubscriptionExistsException) {
225252
CloseStatus status = new CloseStatus(4409, "Subscriber for " + id + " already exists");
226253
return GraphQlStatus.close(session, status);
227254
}
228-
return Mono.fromCallable(() -> this.webSocketCodecDelegate.encodeError(session, id, ex));
255+
return Mono.fromCallable(() -> this.codecDelegate.encodeError(session, id, ex));
229256
});
230257
}
231258

spring-graphql/src/main/java/org/springframework/graphql/server/webflux/WebSocketCodecDelegate.java

+10
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ final class WebSocketCodecDelegate {
5454

5555
private final Encoder<?> encoder;
5656

57+
private boolean messagesEncoded;
58+
5759

5860
WebSocketCodecDelegate(CodecConfigurer codecConfigurer) {
5961
Assert.notNull(codecConfigurer, "CodecConfigurer is required");
@@ -84,6 +86,8 @@ <T> WebSocketMessage encode(WebSocketSession session, GraphQlWebSocketMessage me
8486
DataBuffer buffer = ((Encoder<T>) this.encoder).encodeValue(
8587
(T) message, session.bufferFactory(), MESSAGE_TYPE, MimeTypeUtils.APPLICATION_JSON, null);
8688

89+
this.messagesEncoded = true;
90+
8791
return new WebSocketMessage(WebSocketMessage.Type.TEXT, buffer);
8892
}
8993

@@ -115,4 +119,10 @@ WebSocketMessage encodeComplete(WebSocketSession session, String id) {
115119
return encode(session, GraphQlWebSocketMessage.complete(id));
116120
}
117121

122+
boolean checkMessagesEncodedAndClear() {
123+
boolean result = this.messagesEncoded;
124+
this.messagesEncoded = false;
125+
return result;
126+
}
127+
118128
}

spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler.java

+70-1
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,12 @@ public class GraphQlWebSocketHandler extends TextWebSocketHandler implements Sub
104104

105105
private final HttpMessageConverter<?> converter;
106106

107+
@Nullable
108+
private final Duration keepAliveDuration;
109+
107110
private final Map<String, SessionState> sessionInfoMap = new ConcurrentHashMap<>();
108111

112+
109113
/**
110114
* Create a new instance.
111115
* @param graphQlHandler common handler for GraphQL over WebSocket requests
@@ -116,6 +120,23 @@ public class GraphQlWebSocketHandler extends TextWebSocketHandler implements Sub
116120
public GraphQlWebSocketHandler(
117121
WebGraphQlHandler graphQlHandler, HttpMessageConverter<?> converter, Duration connectionInitTimeout) {
118122

123+
this(graphQlHandler, converter, connectionInitTimeout, null);
124+
}
125+
126+
/**
127+
* Create a new instance.
128+
* @param graphQlHandler common handler for GraphQL over WebSocket requests
129+
* @param converter for JSON encoding and decoding
130+
* @param connectionInitTimeout how long to wait after the establishment of
131+
* the WebSocket for the {@code "connection_ini"} message from the client.
132+
* @param keepAliveDuration how frequently to send ping messages; if not
133+
* set then ping messages are not sent.
134+
* @since 1.3
135+
*/
136+
public GraphQlWebSocketHandler(
137+
WebGraphQlHandler graphQlHandler, HttpMessageConverter<?> converter,
138+
Duration connectionInitTimeout, @Nullable Duration keepAliveDuration) {
139+
119140
Assert.notNull(graphQlHandler, "WebGraphQlHandler is required");
120141
Assert.notNull(converter, "HttpMessageConverter for JSON is required");
121142

@@ -124,8 +145,10 @@ public GraphQlWebSocketHandler(
124145
this.webSocketGraphQlInterceptor = this.graphQlHandler.getWebSocketInterceptor();
125146
this.initTimeoutDuration = connectionInitTimeout;
126147
this.converter = converter;
148+
this.keepAliveDuration = keepAliveDuration;
127149
}
128150

151+
129152
@Override
130153
public List<String> getSubProtocols() {
131154
return SUB_PROTOCOL_LIST;
@@ -257,6 +280,21 @@ private void handleInternal(WebSocketSession session, TextMessage webSocketMessa
257280
return Mono.empty();
258281
})
259282
.block(Duration.ofSeconds(10));
283+
284+
if (this.keepAliveDuration != null) {
285+
Flux.interval(this.keepAliveDuration, this.keepAliveDuration)
286+
.filter((aLong) -> true)
287+
.publishOn(state.getScheduler()) // Serial blocking send via single thread
288+
.doOnNext((aLong) -> {
289+
try {
290+
session.sendMessage(encode(GraphQlWebSocketMessage.ping(null)));
291+
}
292+
catch (IOException ex) {
293+
ExceptionWebSocketHandlerDecorator.tryCloseWithError(session, ex, logger);
294+
}
295+
})
296+
.subscribe(state.getKeepAliveSubscriber());
297+
}
260298
}
261299
default -> GraphQlStatus.closeSession(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
262300
}
@@ -444,9 +482,12 @@ private static class SessionState {
444482

445483
private final Scheduler scheduler;
446484

447-
SessionState(String graphQlSessionId, WebSocketSessionInfo sessionInfo) {
485+
private final KeepAliveSubscriber keepAliveSubscriber;
486+
487+
SessionState(String graphQlSessionId, WebMvcSessionInfo sessionInfo) {
448488
this.sessionInfo = sessionInfo;
449489
this.scheduler = Schedulers.newSingle("GraphQL-WsSession-" + graphQlSessionId);
490+
this.keepAliveSubscriber = new KeepAliveSubscriber(sessionInfo.getSession());
450491
}
451492

452493
WebSocketSessionInfo getSessionInfo() {
@@ -462,12 +503,16 @@ boolean setConnectionInitPayload(Map<String, Object> payload) {
462503
return this.connectionInitPayloadRef.compareAndSet(null, payload);
463504
}
464505

506+
KeepAliveSubscriber getKeepAliveSubscriber() {
507+
return this.keepAliveSubscriber;
508+
}
465509

466510
Map<String, Subscription> getSubscriptions() {
467511
return this.subscriptions;
468512
}
469513

470514
void dispose() {
515+
this.keepAliveSubscriber.cancel();
471516
for (Map.Entry<String, Subscription> entry : this.subscriptions.entrySet()) {
472517
try {
473518
entry.getValue().cancel();
@@ -525,6 +570,10 @@ public Mono<Principal> getPrincipal() {
525570
public InetSocketAddress getRemoteAddress() {
526571
return this.session.getRemoteAddress();
527572
}
573+
574+
WebSocketSession getSession() {
575+
return this.session;
576+
}
528577
}
529578

530579

@@ -567,9 +616,29 @@ public void hookOnError(Throwable ex) {
567616
public void hookOnComplete() {
568617
this.sessionState.getSubscriptions().remove(this.subscriptionId);
569618
}
619+
}
620+
621+
622+
private static class KeepAliveSubscriber extends BaseSubscriber<Long> {
623+
624+
private final WebSocketSession session;
625+
626+
KeepAliveSubscriber(WebSocketSession session) {
627+
this.session = session;
628+
}
570629

630+
@Override
631+
protected void hookOnSubscribe(Subscription subscription) {
632+
subscription.request(Integer.MAX_VALUE);
633+
}
634+
635+
@Override
636+
public void hookOnError(Throwable ex) {
637+
ExceptionWebSocketHandlerDecorator.tryCloseWithError(this.session, ex, logger);
638+
}
571639
}
572640

641+
573642
@SuppressWarnings("serial")
574643
private static final class SubscriptionExistsException extends RuntimeException {
575644
}

0 commit comments

Comments
 (0)