Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add error handler for transactional event #34146

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.transaction.TransactionManager;
import org.springframework.transaction.config.TransactionManagementConfigUtils;
import org.springframework.transaction.config.GlobalTransactionalEventErrorHandler;
import org.springframework.transaction.event.TransactionalEventListenerFactory;
import org.springframework.transaction.interceptor.RollbackRuleAttribute;
import org.springframework.transaction.interceptor.TransactionAttributeSource;
Expand Down Expand Up @@ -93,8 +94,11 @@ public TransactionAttributeSource transactionAttributeSource() {

@Bean(name = TransactionManagementConfigUtils.TRANSACTIONAL_EVENT_LISTENER_FACTORY_BEAN_NAME)
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
public static TransactionalEventListenerFactory transactionalEventListenerFactory() {
return new RestrictedTransactionalEventListenerFactory();
public static TransactionalEventListenerFactory transactionalEventListenerFactory(@Nullable GlobalTransactionalEventErrorHandler errorHandler) {
if (errorHandler == null) {
return new RestrictedTransactionalEventListenerFactory();
}
return new RestrictedTransactionalEventListenerFactory(errorHandler);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.springframework.context.ApplicationListener;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.transaction.config.GlobalTransactionalEventErrorHandler;
import org.springframework.transaction.event.TransactionalEventListenerFactory;

/**
Expand All @@ -35,6 +36,14 @@
*/
public class RestrictedTransactionalEventListenerFactory extends TransactionalEventListenerFactory {

public RestrictedTransactionalEventListenerFactory() {
super();
}

public RestrictedTransactionalEventListenerFactory(GlobalTransactionalEventErrorHandler errorHandler) {
super(errorHandler);
}

@Override
public ApplicationListener<?> createApplicationListener(String beanName, Class<?> type, Method method) {
Transactional txAnn = AnnotatedElementUtils.findMergedAnnotation(method, Transactional.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package org.springframework.transaction.config;

import org.jspecify.annotations.Nullable;
import org.springframework.context.ApplicationEvent;
import org.springframework.transaction.event.TransactionalApplicationListener;

public abstract class GlobalTransactionalEventErrorHandler implements TransactionalApplicationListener.SynchronizationCallback {

public abstract void handle(ApplicationEvent event, @Nullable Throwable ex);

@Override
public void postProcessEvent(ApplicationEvent event, @Nullable Throwable ex) {
if (ex != null) {
handle(event, ex);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@

import java.lang.reflect.Method;

import org.jspecify.annotations.Nullable;
import org.springframework.context.ApplicationListener;
import org.springframework.context.event.EventListenerFactory;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.transaction.config.GlobalTransactionalEventErrorHandler;

/**
* {@link EventListenerFactory} implementation that handles {@link TransactionalEventListener}
Expand All @@ -35,6 +37,13 @@ public class TransactionalEventListenerFactory implements EventListenerFactory,

private int order = 50;

private @Nullable GlobalTransactionalEventErrorHandler errorHandler;

public TransactionalEventListenerFactory() { }

public TransactionalEventListenerFactory(GlobalTransactionalEventErrorHandler errorHandler) {
this.errorHandler = errorHandler;
}

public void setOrder(int order) {
this.order = order;
Expand All @@ -53,7 +62,14 @@ public boolean supportsMethod(Method method) {

@Override
public ApplicationListener<?> createApplicationListener(String beanName, Class<?> type, Method method) {
return new TransactionalApplicationListenerMethodAdapter(beanName, type, method);
if (errorHandler == null) {
return new TransactionalApplicationListenerMethodAdapter(beanName, type, method);
}
else {
TransactionalApplicationListenerMethodAdapter listener = new TransactionalApplicationListenerMethodAdapter(beanName, type, method);
listener.addCallback(errorHandler);
return listener;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@
import java.util.List;
import java.util.Map;

import org.jspecify.annotations.Nullable;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
Expand All @@ -43,6 +45,7 @@
import org.springframework.transaction.annotation.EnableTransactionManagement;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.config.GlobalTransactionalEventErrorHandler;
import org.springframework.transaction.support.TransactionSynchronization;
import org.springframework.transaction.support.TransactionSynchronizationManager;
import org.springframework.transaction.support.TransactionTemplate;
Expand Down Expand Up @@ -99,12 +102,12 @@ void immediately() {
void immediatelyImpactsCurrentTransaction() {
load(ImmediateTestListener.class, BeforeCommitTestListener.class);
assertThatIllegalStateException().isThrownBy(() ->
this.transactionTemplate.execute(status -> {
getContext().publishEvent("FAIL");
throw new AssertionError("Should have thrown an exception at this point");
}))
.withMessageContaining("Test exception")
.withMessageContaining(EventCollector.IMMEDIATELY);
this.transactionTemplate.execute(status -> {
getContext().publishEvent("FAIL");
throw new AssertionError("Should have thrown an exception at this point");
}))
.withMessageContaining("Test exception")
.withMessageContaining(EventCollector.IMMEDIATELY);

getEventCollector().assertEvents(EventCollector.IMMEDIATELY, "FAIL");
getEventCollector().assertTotalEventsCount(1);
Expand Down Expand Up @@ -369,6 +372,45 @@ void conditionFoundOnMetaAnnotation() {
getEventCollector().assertNoEventReceived();
}

@Test
void afterCommitThrowException() {
doLoad(HandlerConfiguration.class, AfterCommitErrorHandlerTestListener.class);
this.transactionTemplate.execute(status -> {
getContext().publishEvent("test");
getEventCollector().assertNoEventReceived();
return null;
});
getEventCollector().assertEvents(EventCollector.AFTER_COMMIT, "test");
getEventCollector().assertEvents(EventCollector.HANDLE_ERROR, "HANDLE_ERROR");
getEventCollector().assertTotalEventsCount(2);
}

@Test
void afterRollbackThrowException() {
doLoad(HandlerConfiguration.class, AfterRollbackErrorHandlerTestListener.class);
this.transactionTemplate.execute(status -> {
getContext().publishEvent("test");
getEventCollector().assertNoEventReceived();
status.setRollbackOnly();
return null;
});
getEventCollector().assertEvents(EventCollector.AFTER_ROLLBACK, "test");
getEventCollector().assertEvents(EventCollector.HANDLE_ERROR, "HANDLE_ERROR");
getEventCollector().assertTotalEventsCount(2);
}

@Test
void afterCompletionThrowException() {
doLoad(HandlerConfiguration.class, AfterCompletionErrorHandlerTestListener.class);
this.transactionTemplate.execute(status -> {
getContext().publishEvent("test");
getEventCollector().assertNoEventReceived();
return null;
});
getEventCollector().assertEvents(EventCollector.AFTER_COMPLETION, "test");
getEventCollector().assertEvents(EventCollector.HANDLE_ERROR, "HANDLE_ERROR");
getEventCollector().assertTotalEventsCount(2);
}

protected EventCollector getEventCollector() {
return this.eventCollector;
Expand Down Expand Up @@ -442,6 +484,36 @@ public TransactionTemplate transactionTemplate() {
}
}

@Configuration
@EnableTransactionManagement
static class HandlerConfiguration {

@Bean
public EventCollector eventCollector() {
return new EventCollector();
}

@Bean
public TestBean testBean(ApplicationEventPublisher eventPublisher) {
return new TestBean(eventPublisher);
}

@Bean
public CallCountingTransactionManager transactionManager() {
return new CallCountingTransactionManager();
}

@Bean
public TransactionTemplate transactionTemplate() {
return new TransactionTemplate(transactionManager());
}

@Bean
public AfterRollbackErrorHandler errorHandler(ApplicationEventPublisher eventPublisher) {
return new AfterRollbackErrorHandler(eventPublisher);
}
}


@Configuration
static class MulticasterWithCustomExecutor {
Expand All @@ -467,7 +539,9 @@ static class EventCollector {

public static final String AFTER_ROLLBACK = "AFTER_ROLLBACK";

public static final String[] ALL_PHASES = {IMMEDIATELY, BEFORE_COMMIT, AFTER_COMMIT, AFTER_ROLLBACK};
public static final String HANDLE_ERROR = "HANDLE_ERROR";

public static final String[] ALL_PHASES = {IMMEDIATELY, BEFORE_COMMIT, AFTER_COMMIT, AFTER_ROLLBACK, HANDLE_ERROR};

private final MultiValueMap<String, Object> events = new LinkedMultiValueMap<>();

Expand All @@ -486,7 +560,7 @@ public void assertNoEventReceived(String... phases) {
for (String phase : phases) {
List<Object> eventsForPhase = getEvents(phase);
assertThat(eventsForPhase.size()).as("Expected no events for phase '" + phase + "' " +
"but got " + eventsForPhase + ":").isEqualTo(0);
"but got " + eventsForPhase + ":").isEqualTo(0);
}
}

Expand All @@ -504,7 +578,7 @@ public void assertTotalEventsCount(int number) {
size += entry.getValue().size();
}
assertThat(size).as("Wrong number of total events (" + this.events.size() + ") " +
"registered phase(s)").isEqualTo(number);
"registered phase(s)").isEqualTo(number);
}
}

Expand Down Expand Up @@ -677,6 +751,51 @@ public void handleAfterCommit(String data) {
}


@Component
static class AfterCommitErrorHandlerTestListener extends BaseTransactionalTestListener {

@TransactionalEventListener(phase = AFTER_COMMIT, condition = "!'HANDLE_ERROR'.equals(#data)")
public void handleBeforeCommit(String data) {
handleEvent(EventCollector.AFTER_COMMIT, data);
throw new IllegalStateException("test");
}

@EventListener(condition = "'HANDLE_ERROR'.equals(#data)")
public void handleImmediately(String data) {
handleEvent(EventCollector.HANDLE_ERROR, data);
}
}

@Component
static class AfterRollbackErrorHandlerTestListener extends BaseTransactionalTestListener {

@TransactionalEventListener(phase = AFTER_ROLLBACK, condition = "!'HANDLE_ERROR'.equals(#data)")
public void handleBeforeCommit(String data) {
handleEvent(EventCollector.AFTER_ROLLBACK, data);
throw new IllegalStateException("test");
}

@EventListener(condition = "'HANDLE_ERROR'.equals(#data)")
public void handleImmediately(String data) {
handleEvent(EventCollector.HANDLE_ERROR, data);
}
}

@Component
static class AfterCompletionErrorHandlerTestListener extends BaseTransactionalTestListener {

@TransactionalEventListener(phase = AFTER_COMPLETION, condition = "!'HANDLE_ERROR'.equals(#data)")
public void handleBeforeCommit(String data) {
handleEvent(EventCollector.AFTER_COMPLETION, data);
throw new IllegalStateException("test");
}

@EventListener(condition = "'HANDLE_ERROR'.equals(#data)")
public void handleImmediately(String data) {
handleEvent(EventCollector.HANDLE_ERROR, data);
}
}

static class EventTransactionSynchronization implements TransactionSynchronization {

private final int order;
Expand All @@ -691,4 +810,18 @@ public int getOrder() {
}
}

static class AfterRollbackErrorHandler extends GlobalTransactionalEventErrorHandler {

private final ApplicationEventPublisher eventPublisher;

AfterRollbackErrorHandler(ApplicationEventPublisher eventPublisher) {
this.eventPublisher = eventPublisher;
}

@Override
public void handle(ApplicationEvent event, @Nullable Throwable ex) {
eventPublisher.publishEvent("HANDLE_ERROR");
}
}

}