Skip to content

GH-1210: Add Kotlin suspend functions support (#2460) #2519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ ext {
jaywayJsonPathVersion = '2.4.0'
junit4Version = '4.13.2'
junitJupiterVersion = '5.8.2'
kotlinCoroutinesVersion = '1.6.4'
log4jVersion = '2.17.2'
logbackVersion = '1.2.3'
lz4Version = '1.8.0'
Expand Down Expand Up @@ -379,6 +380,7 @@ project('spring-rabbit') {
}
optionalApi "com.jayway.jsonpath:json-path:$jaywayJsonPathVersion"
optionalApi "org.apache.commons:commons-pool2:$commonsPoolVersion"
optionalApi "org.jetbrains.kotlinx:kotlinx-coroutines-reactor:$kotlinCoroutinesVersion"

testApi project(':spring-rabbit-junit')
testImplementation("com.willowtreeapps.assertk:assertk-jvm:$assertkVersion")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
Expand All @@ -30,7 +28,6 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
Expand All @@ -56,6 +53,7 @@
import org.springframework.amqp.rabbit.listener.RabbitListenerContainerFactory;
import org.springframework.amqp.rabbit.listener.RabbitListenerEndpointRegistrar;
import org.springframework.amqp.rabbit.listener.RabbitListenerEndpointRegistry;
import org.springframework.amqp.rabbit.listener.adapter.AmqpMessageHandlerMethodFactory;
import org.springframework.amqp.rabbit.listener.adapter.ReplyPostProcessor;
import org.springframework.amqp.rabbit.listener.api.RabbitListenerErrorHandler;
import org.springframework.amqp.support.converter.MessageConverter;
Expand All @@ -76,7 +74,6 @@
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.context.EnvironmentAware;
import org.springframework.context.expression.StandardBeanExpressionResolver;
import org.springframework.core.MethodParameter;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.annotation.MergedAnnotations;
Expand All @@ -88,21 +85,16 @@
import org.springframework.core.task.TaskExecutor;
import org.springframework.format.support.DefaultFormattingConversionService;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.converter.GenericMessageConverter;
import org.springframework.messaging.handler.annotation.support.DefaultMessageHandlerMethodFactory;
import org.springframework.messaging.handler.annotation.support.MessageHandlerMethodFactory;
import org.springframework.messaging.handler.annotation.support.MethodArgumentNotValidException;
import org.springframework.messaging.handler.annotation.support.PayloadMethodArgumentResolver;
import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver;
import org.springframework.messaging.handler.invocation.InvocableHandlerMethod;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.validation.BindingResult;
import org.springframework.validation.ObjectError;
import org.springframework.validation.Validator;

/**
Expand Down Expand Up @@ -440,14 +432,10 @@ protected Collection<Declarable> processListener(MethodRabbitListenerEndpoint en
List<Object> resolvedQueues = resolveQueues(rabbitListener, declarables);
if (!resolvedQueues.isEmpty()) {
if (resolvedQueues.get(0) instanceof String) {
endpoint.setQueueNames(resolvedQueues.stream()
.map(o -> (String) o)
.collect(Collectors.toList()).toArray(new String[0]));
endpoint.setQueueNames(resolvedQueues.stream().map(o -> (String) o).toArray(String[]::new));
}
else {
endpoint.setQueues(resolvedQueues.stream()
.map(o -> (Queue) o)
.collect(Collectors.toList()).toArray(new Queue[0]));
endpoint.setQueues(resolvedQueues.stream().map(o -> (Queue) o).toArray(Queue[]::new));
}
}
endpoint.setConcurrency(resolveExpressionAsStringOrInteger(rabbitListener.concurrency(), "concurrency"));
Expand Down Expand Up @@ -664,12 +652,10 @@ private List<Object> resolveQueues(RabbitListener rabbitListener, Collection<Dec
String[] queues = rabbitListener.queues();
QueueBinding[] bindings = rabbitListener.bindings();
org.springframework.amqp.rabbit.annotation.Queue[] queuesToDeclare = rabbitListener.queuesToDeclare();
List<String> queueNames = new ArrayList<String>();
List<Queue> queueBeans = new ArrayList<Queue>();
if (queues.length > 0) {
for (int i = 0; i < queues.length; i++) {
resolveQueues(queues[i], queueNames, queueBeans);
}
List<String> queueNames = new ArrayList<>();
List<Queue> queueBeans = new ArrayList<>();
for (String queue : queues) {
resolveQueues(queue, queueNames, queueBeans);
}
if (!queueNames.isEmpty()) {
// revert to the previous behavior of just using the name when there is mixture of String and Queue
Expand All @@ -681,8 +667,8 @@ private List<Object> resolveQueues(RabbitListener rabbitListener, Collection<Dec
throw new BeanInitializationException(
"@RabbitListener can have only one of 'queues', 'queuesToDeclare', or 'bindings'");
}
for (int i = 0; i < queuesToDeclare.length; i++) {
queueNames.add(declareQueue(queuesToDeclare[i], declarables));
for (org.springframework.amqp.rabbit.annotation.Queue queue : queuesToDeclare) {
queueNames.add(declareQueue(queue, declarables));
}
}
if (bindings.length > 0) {
Expand Down Expand Up @@ -752,7 +738,7 @@ private String[] registerBeansForDeclaration(RabbitListener rabbitListener, Coll
declareExchangeAndBinding(binding, queueName, declarables);
}
}
return queues.toArray(new String[queues.size()]);
return queues.toArray(new String[0]);
}

private String declareQueue(org.springframework.amqp.rabbit.annotation.Queue bindingQueue,
Expand Down Expand Up @@ -859,7 +845,7 @@ private void registerBindings(QueueBinding binding, String queueName, String exc
}

private Map<String, Object> resolveArguments(Argument[] arguments) {
Map<String, Object> map = new HashMap<String, Object>();
Map<String, Object> map = new HashMap<>();
for (Argument arg : arguments) {
String key = resolveExpressionAsString(arg.name(), "@Argument.name");
if (StringUtils.hasText(key)) {
Expand Down Expand Up @@ -1025,7 +1011,7 @@ private MessageHandlerMethodFactory getFactory() {
}

private MessageHandlerMethodFactory createDefaultMessageHandlerMethodFactory() {
DefaultMessageHandlerMethodFactory defaultFactory = new DefaultMessageHandlerMethodFactory();
DefaultMessageHandlerMethodFactory defaultFactory = new AmqpMessageHandlerMethodFactory();
Validator validator = RabbitListenerAnnotationBeanPostProcessor.this.registrar.getValidator();
if (validator != null) {
defaultFactory.setValidator(validator);
Expand All @@ -1038,74 +1024,14 @@ private MessageHandlerMethodFactory createDefaultMessageHandlerMethodFactory() {
List<HandlerMethodArgumentResolver> customArgumentsResolver = new ArrayList<>(
RabbitListenerAnnotationBeanPostProcessor.this.registrar.getCustomMethodArgumentResolvers());
defaultFactory.setCustomArgumentResolvers(customArgumentsResolver);
GenericMessageConverter messageConverter = new GenericMessageConverter(
this.defaultFormattingConversionService);
defaultFactory.setMessageConverter(messageConverter);
// Has to be at the end - look at PayloadMethodArgumentResolver documentation
customArgumentsResolver.add(new OptionalEmptyAwarePayloadArgumentResolver(messageConverter, validator));
defaultFactory.setMessageConverter(new GenericMessageConverter(this.defaultFormattingConversionService));

defaultFactory.afterPropertiesSet();
return defaultFactory;
}

}

private static class OptionalEmptyAwarePayloadArgumentResolver extends PayloadMethodArgumentResolver {

OptionalEmptyAwarePayloadArgumentResolver(
org.springframework.messaging.converter.MessageConverter messageConverter,
@Nullable Validator validator) {

super(messageConverter, validator);
}

@Override
public Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception { // NOSONAR
Object resolved = null;
try {
resolved = super.resolveArgument(parameter, message);
}
catch (MethodArgumentNotValidException ex) {
Type type = parameter.getGenericParameterType();
if (isOptional(message, type)) {
BindingResult bindingResult = ex.getBindingResult();
if (bindingResult != null) {
List<ObjectError> allErrors = bindingResult.getAllErrors();
if (allErrors.size() == 1) {
String defaultMessage = allErrors.get(0).getDefaultMessage();
if ("Payload value must not be empty".equals(defaultMessage)) {
return Optional.empty();
}
}
}
}
throw ex;
}
/*
* Replace Optional.empty() list elements with null.
*/
if (resolved instanceof List) {
List<?> list = ((List<?>) resolved);
for (int i = 0; i < list.size(); i++) {
if (list.get(i).equals(Optional.empty())) {
list.set(i, null);
}
}
}
return resolved;
}

private boolean isOptional(Message<?> message, Type type) {
return (Optional.class.equals(type) || (type instanceof ParameterizedType
&& Optional.class.equals(((ParameterizedType) type).getRawType())))
&& message.getPayload().equals(Optional.empty());
}

@Override
protected boolean isEmptyPayload(Object payload) {
return payload == null || payload.equals(Optional.empty());
}

}
/**
* The metadata holder of the class with {@link RabbitListener}
* and {@link RabbitHandler} annotations.
Expand Down Expand Up @@ -1145,6 +1071,9 @@ private TypeMetadata() {

/**
* A method annotated with {@link RabbitListener}, together with the annotations.
*
* @param method the method with annotations
* @param annotations on the method
*/
private static class ListenerMethod {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ protected void handleResult(InvocationResult resultArg, Message request, Channel
* response message back.
* @param resultArg the result object to handle (never <code>null</code>)
* @param request the original request message
* @param channel the Rabbit channel to operate on (may be <code>null</code>)
* @param channel the Rabbit channel to operate on (maybe <code>null</code>)
* @param source the source data for the method invocation - e.g.
* {@code o.s.messaging.Message<?>}; may be null
* @see #buildMessage
Expand Down Expand Up @@ -404,8 +404,8 @@ else if (resultArg.getReturnValue() instanceof CompletableFuture) {
}
else if (monoPresent && MonoHandler.isMono(resultArg.getReturnValue())) {
if (!this.isManualAck) {
this.logger.warn("Container AcknowledgeMode must be MANUAL for a Mono<?> return type; "
+ "otherwise the container will ack the message immediately");
this.logger.warn("Container AcknowledgeMode must be MANUAL for a Mono<?> return type" +
"(or Kotlin suspend function); otherwise the container will ack the message immediately");
}
MonoHandler.subscribe(resultArg.getReturnValue(),
r -> asyncSuccess(resultArg, request, channel, source, r),
Expand Down Expand Up @@ -461,7 +461,7 @@ private void basicAck(Message request, Channel channel) {
}

private void asyncFailure(Message request, Channel channel, Throwable t) {
this.logger.error("Future or Mono was completed with an exception for " + request, t);
this.logger.error("Future, Mono, or suspend function was completed with an exception for " + request, t);
try {
channel.basicNack(request.getMessageProperties().getDeliveryTag(), false,
ContainerUtils.shouldRequeue(this.defaultRequeueRejected, t, this.logger));
Expand Down
Loading