Skip to content

Commit 521cda0

Browse files
committed
Use non-blocking thread in WebFlux controller with RequestBody parameter
This commit ensures that `InvocableHandlerMethod` executes the method on the desired thread if a non-blocking thread is specified, even in the case where arguments resolution happens on a different thread. This is notably the case if the method body is resolved as an input argument to the controller method (`@RequestBody`). Closes gh-32502
1 parent 08c9b5c commit 521cda0

File tree

9 files changed

+216
-23
lines changed

9 files changed

+216
-23
lines changed

spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java

+24-2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import kotlin.reflect.jvm.KCallablesJvm;
3737
import kotlin.reflect.jvm.ReflectJvmMapping;
3838
import reactor.core.publisher.Mono;
39+
import reactor.core.scheduler.Scheduler;
3940

4041
import org.springframework.core.CoroutinesUtils;
4142
import org.springframework.core.DefaultParameterNameDiscoverer;
@@ -59,6 +60,12 @@
5960
* Extension of {@link HandlerMethod} that invokes the underlying method with
6061
* argument values resolved from the current HTTP request through a list of
6162
* {@link HandlerMethodArgumentResolver}.
63+
* <p>By default, the method invocation happens on the thread from which the
64+
* {@code Mono} was subscribed to, or in some cases the thread that emitted one
65+
* of the resolved arguments (e.g. when the request body needs to be decoded).
66+
* To ensure a predictable thread for the underlying method's invocation,
67+
* a {@link Scheduler} can optionally be provided via
68+
* {@link #setInvocationScheduler(Scheduler)}.
6269
*
6370
* @author Rossen Stoyanchev
6471
* @author Juergen Hoeller
@@ -85,6 +92,9 @@ public class InvocableHandlerMethod extends HandlerMethod {
8592

8693
private Class<?>[] validationGroups = EMPTY_GROUPS;
8794

95+
@Nullable
96+
private Scheduler invocationScheduler;
97+
8898

8999
/**
90100
* Create an instance from a {@code HandlerMethod}.
@@ -153,6 +163,13 @@ public void setMethodValidator(@Nullable MethodValidator methodValidator) {
153163
methodValidator.determineValidationGroups(getBean(), getBridgedMethod()) : EMPTY_GROUPS);
154164
}
155165

166+
/**
167+
* Set the {@link Scheduler} on which to perform the method invocation.
168+
* @since 6.1.6
169+
*/
170+
public void setInvocationScheduler(@Nullable Scheduler invocationScheduler) {
171+
this.invocationScheduler = invocationScheduler;
172+
}
156173

157174
/**
158175
* Invoke the method for the given exchange.
@@ -165,7 +182,7 @@ public void setMethodValidator(@Nullable MethodValidator methodValidator) {
165182
public Mono<HandlerResult> invoke(
166183
ServerWebExchange exchange, BindingContext bindingContext, Object... providedArgs) {
167184

168-
return getMethodArgumentValues(exchange, bindingContext, providedArgs).flatMap(args -> {
185+
return getMethodArgumentValuesOnScheduler(exchange, bindingContext, providedArgs).flatMap(args -> {
169186
if (shouldValidateArguments() && this.methodValidator != null) {
170187
this.methodValidator.applyArgumentValidation(
171188
getBean(), getBridgedMethod(), getMethodParameters(), args, this.validationGroups);
@@ -217,14 +234,19 @@ public Mono<HandlerResult> invoke(
217234
});
218235
}
219236

237+
private Mono<Object[]> getMethodArgumentValuesOnScheduler(
238+
ServerWebExchange exchange, BindingContext bindingContext, Object... providedArgs) {
239+
Mono<Object[]> argumentValuesMono = getMethodArgumentValues(exchange, bindingContext, providedArgs);
240+
return this.invocationScheduler != null ? argumentValuesMono.publishOn(this.invocationScheduler) : argumentValuesMono;
241+
}
242+
220243
private Mono<Object[]> getMethodArgumentValues(
221244
ServerWebExchange exchange, BindingContext bindingContext, Object... providedArgs) {
222245

223246
MethodParameter[] parameters = getMethodParameters();
224247
if (ObjectUtils.isEmpty(parameters)) {
225248
return EMPTY_ARGS;
226249
}
227-
228250
List<Mono<Object>> argMonos = new ArrayList<>(parameters.length);
229251
for (MethodParameter parameter : parameters) {
230252
parameter.initParameterNameDiscovery(this.parameterNameDiscoverer);

spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ControllerMethodResolver.java

+30-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import org.apache.commons.logging.Log;
3030
import org.apache.commons.logging.LogFactory;
31+
import reactor.core.scheduler.Scheduler;
3132

3233
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
3334
import org.springframework.context.ApplicationContext;
@@ -121,11 +122,19 @@ class ControllerMethodResolver {
121122

122123
private final Map<Class<?>, SessionAttributesHandler> sessionAttributesHandlerCache = new ConcurrentHashMap<>(64);
123124

125+
@Nullable
126+
private final Scheduler invocationScheduler;
127+
128+
@Nullable
129+
private final Predicate<? super HandlerMethod> blockingMethodPredicate;
130+
124131

125132
ControllerMethodResolver(
126133
ArgumentResolverConfigurer customResolvers, ReactiveAdapterRegistry adapterRegistry,
127134
ConfigurableApplicationContext context, List<HttpMessageReader<?>> readers,
128-
@Nullable WebBindingInitializer webBindingInitializer) {
135+
@Nullable WebBindingInitializer webBindingInitializer,
136+
@Nullable Scheduler invocationScheduler,
137+
@Nullable Predicate<? super HandlerMethod> blockingMethodPredicate) {
129138

130139
Assert.notNull(customResolvers, "ArgumentResolverConfigurer is required");
131140
Assert.notNull(adapterRegistry, "ReactiveAdapterRegistry is required");
@@ -137,6 +146,8 @@ class ControllerMethodResolver {
137146
this.requestMappingResolvers = requestMappingResolvers(customResolvers, adapterRegistry, context, readers);
138147
this.exceptionHandlerResolvers = exceptionHandlerResolvers(customResolvers, adapterRegistry, context);
139148
this.reactiveAdapterRegistry = adapterRegistry;
149+
this.invocationScheduler = invocationScheduler;
150+
this.blockingMethodPredicate = blockingMethodPredicate;
140151

141152
if (BEAN_VALIDATION_PRESENT) {
142153
this.methodValidator = HandlerMethodValidator.from(webBindingInitializer, null,
@@ -287,6 +298,21 @@ private static Predicate<MethodParameter> methodParamPredicate(
287298
};
288299
}
289300

301+
/**
302+
* Return a {@link Scheduler} for the given method if it is considered
303+
* blocking by the underlying blocking method predicate, or null if no
304+
* particular scheduler should be used for this method invocation.
305+
*/
306+
@Nullable
307+
public Scheduler getSchedulerFor(HandlerMethod handlerMethod) {
308+
if (this.invocationScheduler != null) {
309+
Assert.state(this.blockingMethodPredicate != null, "Expected HandlerMethod Predicate");
310+
if (this.blockingMethodPredicate.test(handlerMethod)) {
311+
return this.invocationScheduler;
312+
}
313+
}
314+
return null;
315+
}
290316

291317
/**
292318
* Return an {@link InvocableHandlerMethod} for the given
@@ -297,6 +323,9 @@ public InvocableHandlerMethod getRequestMappingMethod(HandlerMethod handlerMetho
297323
invocable.setArgumentResolvers(this.requestMappingResolvers);
298324
invocable.setReactiveAdapterRegistry(this.reactiveAdapterRegistry);
299325
invocable.setMethodValidator(this.methodValidator);
326+
//getSchedulerFor returns null if not applicable, which is ok here
327+
invocable.setInvocationScheduler(getSchedulerFor(handlerMethod));
328+
300329
return invocable;
301330
}
302331

spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerAdapter.java

+5-6
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,8 @@ public void afterPropertiesSet() throws Exception {
225225

226226
this.methodResolver = new ControllerMethodResolver(
227227
this.argumentResolverConfigurer, this.reactiveAdapterRegistry, this.applicationContext,
228-
this.messageReaders, this.webBindingInitializer);
228+
this.messageReaders, this.webBindingInitializer,
229+
this.scheduler, this.blockingMethodPredicate);
229230

230231
this.modelInitializer = new ModelInitializer(this.methodResolver, this.reactiveAdapterRegistry);
231232
}
@@ -260,11 +261,9 @@ public Mono<HandlerResult> handle(ServerWebExchange exchange, Object handler) {
260261
.doOnNext(result -> result.setExceptionHandler(exceptionHandler))
261262
.onErrorResume(ex -> exceptionHandler.handleError(exchange, ex));
262263

263-
if (this.scheduler != null) {
264-
Assert.state(this.blockingMethodPredicate != null, "Expected HandlerMethod Predicate");
265-
if (this.blockingMethodPredicate.test(handlerMethod)) {
266-
resultMono = resultMono.subscribeOn(this.scheduler);
267-
}
264+
Scheduler optionalScheduler = this.methodResolver.getSchedulerFor(handlerMethod);
265+
if (optionalScheduler != null) {
266+
return resultMono.subscribeOn(optionalScheduler);
268267
}
269268

270269
return resultMono;

spring-webflux/src/test/java/org/springframework/web/reactive/result/method/InvocableHandlerMethodTests.java

+46-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import org.junit.jupiter.api.Test;
2727
import reactor.core.publisher.Flux;
2828
import reactor.core.publisher.Mono;
29+
import reactor.core.scheduler.Scheduler;
30+
import reactor.core.scheduler.Schedulers;
2931
import reactor.test.StepVerifier;
3032

3133
import org.springframework.core.io.buffer.DataBuffer;
@@ -75,6 +77,15 @@ void resolveArg() {
7577
assertHandlerResultValue(mono, "success:value1");
7678
}
7779

80+
@Test
81+
void resolveArgOnSchedulerThread() {
82+
this.resolvers.add(stubResolver(Mono.<Object>just("success").publishOn(Schedulers.newSingle("wrong"))));
83+
Method method = ResolvableMethod.on(TestController.class).mockCall(o -> o.singleArgThread(null)).method();
84+
Mono<HandlerResult> mono = invokeOnScheduler(Schedulers.newSingle("good"), new TestController(), method);
85+
86+
assertHandlerResultValue(mono, "success on thread: good-", false);
87+
}
88+
7889
@Test
7990
void resolveNoArgValue() {
8091
this.resolvers.add(stubResolver(Mono.empty()));
@@ -92,6 +103,14 @@ void resolveNoArgs() {
92103
assertHandlerResultValue(mono, "success");
93104
}
94105

106+
@Test
107+
void resolveNoArgsOnSchedulerThread() {
108+
Method method = ResolvableMethod.on(TestController.class).mockCall(o -> o.noArgsThread()).method();
109+
Mono<HandlerResult> mono = invokeOnScheduler(Schedulers.newSingle("good"), new TestController(), method);
110+
111+
assertHandlerResultValue(mono, "on thread: good-", false);
112+
}
113+
95114
@Test
96115
void cannotResolveArg() {
97116
Method method = ResolvableMethod.on(TestController.class).mockCall(o -> o.singleArg(null)).method();
@@ -229,6 +248,13 @@ private Mono<HandlerResult> invoke(Object handler, Method method, Object... prov
229248
return invocable.invoke(this.exchange, new BindingContext(), providedArgs);
230249
}
231250

251+
private Mono<HandlerResult> invokeOnScheduler(Scheduler scheduler, Object handler, Method method, Object... providedArgs) {
252+
InvocableHandlerMethod invocable = new InvocableHandlerMethod(handler, method);
253+
invocable.setArgumentResolvers(this.resolvers);
254+
invocable.setInvocationScheduler(scheduler);
255+
return invocable.invoke(this.exchange, new BindingContext(), providedArgs);
256+
}
257+
232258
private HandlerMethodArgumentResolver stubResolver(Object stubValue) {
233259
return stubResolver(Mono.just(stubValue));
234260
}
@@ -241,8 +267,19 @@ private HandlerMethodArgumentResolver stubResolver(Mono<Object> stubValue) {
241267
}
242268

243269
private void assertHandlerResultValue(Mono<HandlerResult> mono, String expected) {
270+
this.assertHandlerResultValue(mono, expected, true);
271+
}
272+
273+
private void assertHandlerResultValue(Mono<HandlerResult> mono, String expected, boolean strict) {
244274
StepVerifier.create(mono)
245-
.consumeNextWith(result -> assertThat(result.getReturnValue()).isEqualTo(expected))
275+
.assertNext(result -> {
276+
if (strict) {
277+
assertThat(result.getReturnValue()).isEqualTo(expected);
278+
}
279+
else {
280+
assertThat(String.valueOf(result.getReturnValue())).startsWith(expected);
281+
}
282+
})
246283
.expectComplete()
247284
.verify();
248285
}
@@ -259,6 +296,14 @@ String noArgs() {
259296
return "success";
260297
}
261298

299+
String singleArgThread(String q) {
300+
return q + " on thread: " + Thread.currentThread().getName();
301+
}
302+
303+
String noArgsThread() {
304+
return "on thread: " + Thread.currentThread().getName();
305+
}
306+
262307
void exceptionMethod() {
263308
throw new IllegalStateException("boo");
264309
}

0 commit comments

Comments
 (0)