Skip to content

Commit 5c2164c

Browse files
committed
Propagate cancel signal from transport to data fetchers
Prior to this commit, a WebSocket/SSE client disconnecting from the data stream would cause a CANCEL signal to be sent to upstream publishers. This signal would flow from the transport layer up to the `ExecutionGraphQlService`. Because the `GraphQL` engine itself relies on `CompletableFuture`, the CANCEL signal would not flow through and reactive data fetchers would not receive it. This means that costly reactive operations would not be cancelled and this could cause write failures as publishers would still produce values. This commit adds at the service level a Reactor `Sink` to the `GraphQLContext` that can be picked up by the `ContextDataFetcherDecorator` when decorating reactive data fetchers. This allows us to manually cancel publishers when the CANCEL signal is received at the transport level. Fixes gh-1149
1 parent 9294300 commit 5c2164c

File tree

5 files changed

+111
-7
lines changed

5 files changed

+111
-7
lines changed

spring-graphql/src/main/java/org/springframework/graphql/ExecutionGraphQlRequest.java

+6
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@
3636
*/
3737
public interface ExecutionGraphQlRequest extends GraphQlRequest {
3838

39+
/**
40+
* Key of the GraphQL context entry that holds a {@code Mono<Void>} that completes
41+
* when the inbound GraphQL request is cancelled at the transport level.
42+
*/
43+
String CANCEL_PUBLISHER_CONTEXT_KEY = ExecutionGraphQlRequest.class.getName() + ".cancelled";
44+
3945
/**
4046
* Return the transport assigned id for the request that in turn sets
4147
* {@link ExecutionInput.Builder#executionId(ExecutionId) executionId}.

spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java

+12-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2024 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -38,6 +38,7 @@
3838
import reactor.core.publisher.Flux;
3939
import reactor.core.publisher.Mono;
4040

41+
import org.springframework.graphql.ExecutionGraphQlRequest;
4142
import org.springframework.util.Assert;
4243

4344
/**
@@ -79,29 +80,35 @@ public Object get(DataFetchingEnvironment env) throws Exception {
7980

8081
GraphQLContext graphQlContext = env.getGraphQlContext();
8182
ContextSnapshotFactory snapshotFactory = ContextSnapshotFactoryHelper.getInstance(graphQlContext);
82-
8383
ContextSnapshot snapshot = (env.getLocalContext() instanceof GraphQLContext localContext) ?
8484
snapshotFactory.captureFrom(graphQlContext, localContext) :
8585
snapshotFactory.captureFrom(graphQlContext);
86+
Mono<Void> cancelledRequest = graphQlContext.get(ExecutionGraphQlRequest.CANCEL_PUBLISHER_CONTEXT_KEY);
8687

8788
Object value = snapshot.wrap(() -> this.delegate.get(env)).call();
8889

8990
if (this.subscription) {
90-
return ReactiveAdapterRegistryHelper.toSubscriptionFlux(value)
91+
Flux<?> subscriptionResult = ReactiveAdapterRegistryHelper.toSubscriptionFlux(value)
9192
.onErrorResume((exception) -> {
9293
// Already handled, e.g. controller methods?
9394
if (exception instanceof SubscriptionPublisherException) {
9495
return Mono.error(exception);
9596
}
9697
return this.subscriptionExceptionResolver.resolveException(exception)
9798
.flatMap((errors) -> Mono.error(new SubscriptionPublisherException(errors, exception)));
98-
})
99-
.contextWrite(snapshot::updateContext);
99+
});
100+
if (cancelledRequest != null) {
101+
subscriptionResult = subscriptionResult.takeUntilOther(cancelledRequest);
102+
}
103+
return subscriptionResult.contextWrite(snapshot::updateContext);
100104
}
101105

102106
value = ReactiveAdapterRegistryHelper.toMonoIfReactive(value);
103107

104108
if (value instanceof Mono<?> mono) {
109+
if (cancelledRequest != null) {
110+
mono = mono.takeUntilOther(cancelledRequest);
111+
}
105112
value = mono.contextWrite(snapshot::updateContext).toFuture();
106113
}
107114

spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultExecutionGraphQlService.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import io.micrometer.context.ContextSnapshotFactory;
3232
import org.dataloader.DataLoaderRegistry;
3333
import reactor.core.publisher.Mono;
34+
import reactor.core.publisher.Sinks;
3435

3536
import org.springframework.graphql.ExecutionGraphQlRequest;
3637
import org.springframework.graphql.ExecutionGraphQlResponse;
@@ -101,12 +102,15 @@ public final Mono<ExecutionGraphQlResponse> execute(ExecutionGraphQlRequest requ
101102
ContextSnapshotFactoryHelper.saveInstance(factory, graphQLContext);
102103
factory.captureFrom(contextView).updateContext(graphQLContext);
103104

105+
Sinks.Empty<Void> requestCancelled = Sinks.empty();
106+
graphQLContext.put(ExecutionGraphQlRequest.CANCEL_PUBLISHER_CONTEXT_KEY, requestCancelled.asMono());
104107
ExecutionInput executionInputToUse = registerDataLoaders(executionInput);
105108

106109
return Mono.fromFuture(this.graphQlSource.graphQl().executeAsync(executionInputToUse))
107110
.onErrorResume((ex) -> ex instanceof GraphQLError, (ex) ->
108111
Mono.just(ExecutionResult.newExecutionResult().addError((GraphQLError) ex).build()))
109-
.map((result) -> new DefaultExecutionGraphQlResponse(executionInputToUse, result));
112+
.map((result) -> new DefaultExecutionGraphQlResponse(executionInputToUse, result))
113+
.doOnCancel(requestCancelled::tryEmitEmpty);
110114
});
111115
}
112116

spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java

+68-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2024 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -19,7 +19,9 @@
1919
import java.time.Duration;
2020
import java.util.Collections;
2121
import java.util.List;
22+
import java.util.Map;
2223
import java.util.concurrent.CompletableFuture;
24+
import java.util.concurrent.atomic.AtomicBoolean;
2325
import java.util.function.BiConsumer;
2426

2527
import graphql.ExecutionInput;
@@ -41,13 +43,16 @@
4143
import org.junit.jupiter.api.Test;
4244
import reactor.core.publisher.Flux;
4345
import reactor.core.publisher.Mono;
46+
import reactor.core.publisher.Sinks;
4447
import reactor.test.StepVerifier;
4548

49+
import org.springframework.graphql.ExecutionGraphQlRequest;
4650
import org.springframework.graphql.GraphQlSetup;
4751
import org.springframework.graphql.ResponseHelper;
4852
import org.springframework.graphql.TestThreadLocalAccessor;
4953

5054
import static org.assertj.core.api.Assertions.assertThat;
55+
import static org.awaitility.Awaitility.await;
5156

5257
/**
5358
* Tests for {@link ContextDataFetcherDecorator}.
@@ -257,4 +262,66 @@ void trivialDataFetcherIsNotDecorated() {
257262
assertThat(dataFetcher).isInstanceOf(TrivialDataFetcher.class);
258263
}
259264

265+
@Test
266+
void cancelMonoDataFetcherWhenRequestCancelled() throws Exception {
267+
AtomicBoolean dataFetcherCancelled = new AtomicBoolean();
268+
GraphQL graphQl = GraphQlSetup.schemaContent(SCHEMA_CONTENT)
269+
.queryFetcher("greeting", (env) ->
270+
Mono.just("Hello")
271+
.delayElement(Duration.ofSeconds(1))
272+
.doOnCancel(() -> dataFetcherCancelled.set(true))
273+
)
274+
.toGraphQl();
275+
276+
Sinks.Empty<Void> requestCancelled = Sinks.empty();
277+
ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greeting }")
278+
.graphQLContext(Map.of(ExecutionGraphQlRequest.CANCEL_PUBLISHER_CONTEXT_KEY, requestCancelled.asMono())).build();
279+
280+
CompletableFuture<ExecutionResult> asyncResult = graphQl.executeAsync(input);
281+
requestCancelled.tryEmitEmpty();
282+
await().atMost(Duration.ofSeconds(2)).until(dataFetcherCancelled::get);
283+
}
284+
285+
@Test
286+
void cancelFluxDataFetcherWhenRequestCancelled() throws Exception {
287+
AtomicBoolean dataFetcherCancelled = new AtomicBoolean();
288+
GraphQL graphQl = GraphQlSetup.schemaContent(SCHEMA_CONTENT)
289+
.queryFetcher("greeting", (env) ->
290+
Flux.just("Hello")
291+
.delayElements(Duration.ofSeconds(1))
292+
.doOnCancel(() -> dataFetcherCancelled.set(true))
293+
)
294+
.toGraphQl();
295+
296+
Sinks.Empty<Void> requestCancelled = Sinks.empty();
297+
ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greeting }")
298+
.graphQLContext(Map.of(ExecutionGraphQlRequest.CANCEL_PUBLISHER_CONTEXT_KEY, requestCancelled.asMono())).build();
299+
300+
CompletableFuture<ExecutionResult> asyncResult = graphQl.executeAsync(input);
301+
requestCancelled.tryEmitEmpty();
302+
await().atMost(Duration.ofSeconds(2)).until(dataFetcherCancelled::get);
303+
}
304+
305+
@Test
306+
void cancelFluxDataFetcherSubscriptionWhenRequestCancelled() throws Exception {
307+
AtomicBoolean dataFetcherCancelled = new AtomicBoolean();
308+
GraphQL graphQl = GraphQlSetup.schemaContent(SCHEMA_CONTENT)
309+
.subscriptionFetcher("greetings", (env) ->
310+
Flux.just("Hi", "Bonjour", "Hola")
311+
.delayElements(Duration.ofSeconds(1))
312+
.doOnCancel(() -> dataFetcherCancelled.set(true))
313+
)
314+
.toGraphQl();
315+
Sinks.Empty<Void> requestCancelled = Sinks.empty();
316+
ExecutionInput input = ExecutionInput.newExecutionInput().query("subscription { greetings }")
317+
.graphQLContext(Map.of(ExecutionGraphQlRequest.CANCEL_PUBLISHER_CONTEXT_KEY, requestCancelled.asMono())).build();
318+
319+
ExecutionResult executionResult = graphQl.executeAsync(input).get();
320+
ResponseHelper.forSubscription(executionResult).subscribe();
321+
322+
requestCancelled.tryEmitEmpty();
323+
await().atMost(Duration.ofSeconds(2)).until(dataFetcherCancelled::get);
324+
assertThat(dataFetcherCancelled).isTrue();
325+
}
326+
260327
}

spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultExecutionGraphQlServiceTests.java

+20
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,16 @@
1616

1717
package org.springframework.graphql.execution;
1818

19+
import java.time.Duration;
1920
import java.util.Map;
21+
import java.util.concurrent.atomic.AtomicBoolean;
2022

2123
import graphql.ErrorType;
2224
import org.dataloader.DataLoaderRegistry;
2325
import org.junit.jupiter.api.Test;
2426
import reactor.core.publisher.Flux;
27+
import reactor.core.publisher.Mono;
28+
import reactor.test.StepVerifier;
2529

2630
import org.springframework.graphql.Author;
2731
import org.springframework.graphql.Book;
@@ -77,4 +81,20 @@ void shouldHandleGraphQlErrors() {
7781
.hasFieldOrPropertyWithValue("errorType", ErrorType.ValidationError);
7882
}
7983

84+
@Test
85+
void cancellationSupport() {
86+
AtomicBoolean cancelled = new AtomicBoolean();
87+
Mono<String> greetingMono = Mono.just("hi")
88+
.delayElement(Duration.ofSeconds(3))
89+
.doOnCancel(() -> cancelled.set(true));
90+
91+
Mono<ExecutionGraphQlResponse> execution = GraphQlSetup.schemaContent("type Query { greeting: String }")
92+
.queryFetcher("greeting", (env) -> greetingMono)
93+
.toGraphQlService()
94+
.execute(TestExecutionRequest.forDocument("{ greeting }"));
95+
96+
StepVerifier.create(execution).thenCancel().verify();
97+
assertThat(cancelled).isTrue();
98+
}
99+
80100
}

0 commit comments

Comments
 (0)