Skip to content

Commit d3cd569

Browse files
committed
SSE and WS handlers cancel when client disconnects
Prior to this commit, Server Sent Events and WebSocket handlers for MVC would not behave properly in case the client disconnects while the response is written to. This would throw an `IOException` processed by the `BaseSubscriber` handlers, but would not actively cancel the upstream publisher. This means the publisher would keep publishing values even though it is not possible to write to the connection anymore. This commit ensures that any exception triggers a cancel signal sent to the upstream publisher to avoid such cases. Fixes gh-1060
1 parent c340d85 commit d3cd569

File tree

6 files changed

+78
-7
lines changed

6 files changed

+78
-7
lines changed

spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlSseHandler.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ private void writeResult(Map<String, Object> value) {
106106
this.sseBuilder.data(value);
107107
}
108108
catch (IOException exception) {
109-
onError(exception);
109+
cancel();
110+
hookOnError(exception);
110111
}
111112
}
112113

spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler.java

+1
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,7 @@ protected void hookOnNext(TextMessage nextMessage) {
609609
request(1);
610610
}
611611
catch (IOException ex) {
612+
cancel();
612613
ExceptionWebSocketHandlerDecorator.tryCloseWithError(this.session, ex, logger);
613614
}
614615
}

spring-graphql/src/test/java/org/springframework/graphql/server/WebSocketHandlerTestSupport.java

+17-1
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,15 @@
1616

1717
package org.springframework.graphql.server;
1818

19+
import java.io.IOException;
20+
import java.util.concurrent.atomic.AtomicBoolean;
21+
1922
import reactor.core.publisher.Flux;
2023

2124
import org.springframework.graphql.BookSource;
2225
import org.springframework.graphql.GraphQlSetup;
26+
import org.springframework.graphql.server.webmvc.TestWebSocketSession;
27+
import org.springframework.web.socket.WebSocketMessage;
2328

2429
public abstract class WebSocketHandlerTestSupport {
2530

@@ -29,6 +34,8 @@ public abstract class WebSocketHandlerTestSupport {
2934

3035
protected static final String BOOK_QUERY_PAYLOAD;
3136

37+
protected static AtomicBoolean SUBSCRIPTION_CANCELLED = new AtomicBoolean();
38+
3239
static {
3340
BOOK_QUERY_PAYLOAD = "{\"query\": \"" +
3441
" query TestQuery {" +
@@ -73,10 +80,19 @@ protected WebGraphQlHandler initHandler(WebGraphQlInterceptor... interceptors) {
7380
.subscriptionFetcher("bookSearch", environment -> {
7481
String author = environment.getArgument("author");
7582
return Flux.fromIterable(BookSource.books())
76-
.filter((book) -> book.getAuthor().getFullName().contains(author));
83+
.filter((book) -> book.getAuthor().getFullName().contains(author))
84+
.doOnCancel(() -> SUBSCRIPTION_CANCELLED.set(true));
7785
})
7886
.interceptor(interceptors)
7987
.toWebGraphQlHandler();
8088
}
8189

90+
public class BrokenPipeSession extends TestWebSocketSession {
91+
92+
@Override
93+
public void sendMessage(WebSocketMessage<?> message) throws IOException {
94+
throw new IOException("broken pipe");
95+
}
96+
}
97+
8298
}

spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlSseHandlerTests.java

+35-1
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@
2121
import java.nio.charset.StandardCharsets;
2222
import java.time.Duration;
2323
import java.util.List;
24+
import java.util.concurrent.atomic.AtomicBoolean;
2425

2526
import graphql.schema.DataFetcher;
2627
import jakarta.servlet.ServletException;
28+
import jakarta.servlet.ServletOutputStream;
29+
import jakarta.servlet.http.HttpServletResponse;
2730
import org.junit.jupiter.api.Test;
2831
import reactor.core.publisher.Flux;
2932

@@ -40,6 +43,10 @@
4043

4144
import static org.assertj.core.api.Assertions.assertThat;
4245
import static org.awaitility.Awaitility.await;
46+
import static org.mockito.ArgumentMatchers.any;
47+
import static org.mockito.BDDMockito.given;
48+
import static org.mockito.BDDMockito.willThrow;
49+
import static org.mockito.Mockito.mock;
4350

4451
/**
4552
* Tests for {@link GraphQlSseHandler}.
@@ -51,9 +58,13 @@ class GraphQlSseHandlerTests {
5158
private static final List<HttpMessageConverter<?>> MESSAGE_READERS =
5259
List.of(new MappingJackson2HttpMessageConverter());
5360

61+
private static final AtomicBoolean DATA_FETCHER_CANCELLED = new AtomicBoolean();
62+
5463
private static final DataFetcher<?> SEARCH_DATA_FETCHER = env -> {
5564
String author = env.getArgument("author");
56-
return Flux.fromIterable(BookSource.books()).filter((book) -> book.getAuthor().getFullName().contains(author));
65+
return Flux.fromIterable(BookSource.books())
66+
.filter((book) -> book.getAuthor().getFullName().contains(author))
67+
.doOnCancel(() -> DATA_FETCHER_CANCELLED.set(true));
5768
};
5869

5970

@@ -122,6 +133,29 @@ void shouldWriteEventsAndTerminalError() throws Exception {
122133
""");
123134
}
124135

136+
@Test
137+
void shouldCancelDataFetcherPublisherWhenWritingFails() throws Exception {
138+
GraphQlSseHandler handler = createSseHandler(SEARCH_DATA_FETCHER);
139+
MockHttpServletRequest servletRequest = createServletRequest("""
140+
{ "query": "subscription TestSubscription { bookSearch(author:\\\"Orwell\\\") { id name } }" }
141+
""");
142+
HttpServletResponse servletResponse = mock(HttpServletResponse.class);
143+
ServletOutputStream outputStream = mock(ServletOutputStream.class);
144+
145+
willThrow(new IOException("broken pipe")).given(outputStream).write(any());
146+
given(servletResponse.getOutputStream()).willReturn(outputStream);
147+
148+
ServerRequest request = ServerRequest.create(servletRequest, MESSAGE_READERS);
149+
ServerResponse response = handler.handleRequest(request);
150+
if (response instanceof AsyncServerResponse asyncResponse) {
151+
asyncResponse.block();
152+
}
153+
154+
response.writeTo(servletRequest, servletResponse, new DefaultContext());
155+
await().atMost(Duration.ofMillis(500)).until(DATA_FETCHER_CANCELLED::get);
156+
157+
}
158+
125159
private GraphQlSseHandler createSseHandler(DataFetcher<?> dataFetcher) {
126160
return new GraphQlSseHandler(GraphQlSetup.schemaResource(BookSource.schema)
127161
.queryFetcher("bookById", (env) -> BookSource.getBookWithoutAuthor(1L))

spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandlerTests.java

+21-3
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,10 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
7777

7878
private static final Duration TIMEOUT = Duration.ofSeconds(5);
7979

80-
81-
private final TestWebSocketSession session = new TestWebSocketSession();
82-
8380
private final GraphQlWebSocketHandler handler = initWebSocketHandler();
8481

82+
private TestWebSocketSession session = new TestWebSocketSession();
83+
8584

8685
@Test
8786
void query() throws Exception {
@@ -130,6 +129,25 @@ void subscription() throws Exception {
130129
.verify(TIMEOUT);
131130
}
132131

132+
@Test
133+
void brokenPipeShouldCancelPublisher() throws Exception {
134+
this.session = new BrokenPipeSession();
135+
handle(this.handler, new TextMessage("{\"type\":\"connection_init\"}"), new TextMessage(BOOK_SUBSCRIPTION));
136+
137+
BiConsumer<WebSocketMessage<?>, String> bookPayloadAssertion = (message, bookId) -> {
138+
GraphQlWebSocketMessage actual = decode(message);
139+
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
140+
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT);
141+
assertThat(actual.<Map<String, Object>>getPayload())
142+
.extractingByKey("data", as(InstanceOfAssertFactories.map(String.class, Object.class)))
143+
.extractingByKey("bookSearch", as(InstanceOfAssertFactories.map(String.class, Object.class)))
144+
.containsEntry("id", bookId);
145+
};
146+
147+
StepVerifier.create(session.getOutput()).verifyComplete();
148+
assertThat(SUBSCRIPTION_CANCELLED).isTrue();
149+
}
150+
133151
@Test
134152
void keepAlive() throws Exception {
135153
GraphQlWebSocketHandler webSocketHandler =

spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/TestWebSocketSession.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.graphql.server.webmvc;
1818

19+
import java.io.IOException;
1920
import java.net.InetSocketAddress;
2021
import java.net.URI;
2122
import java.security.Principal;
@@ -119,7 +120,7 @@ public List<WebSocketExtension> getExtensions() {
119120
}
120121

121122
@Override
122-
public void sendMessage(WebSocketMessage<?> message) {
123+
public void sendMessage(WebSocketMessage<?> message) throws IOException {
123124
emitMessagesSignal(this.messagesSink.tryEmitNext(message));
124125
}
125126

0 commit comments

Comments
 (0)