diff --git a/spring-tx/src/main/java/org/springframework/transaction/annotation/AbstractTransactionManagementConfiguration.java b/spring-tx/src/main/java/org/springframework/transaction/annotation/AbstractTransactionManagementConfiguration.java index 8f2dfb724a1b..e3a06ff5e130 100644 --- a/spring-tx/src/main/java/org/springframework/transaction/annotation/AbstractTransactionManagementConfiguration.java +++ b/spring-tx/src/main/java/org/springframework/transaction/annotation/AbstractTransactionManagementConfiguration.java @@ -30,6 +30,7 @@ import org.springframework.core.type.AnnotationMetadata; import org.springframework.transaction.TransactionManager; import org.springframework.transaction.config.TransactionManagementConfigUtils; +import org.springframework.transaction.config.GlobalTransactionalEventErrorHandler; import org.springframework.transaction.event.TransactionalEventListenerFactory; import org.springframework.transaction.interceptor.RollbackRuleAttribute; import org.springframework.transaction.interceptor.TransactionAttributeSource; @@ -93,8 +94,11 @@ public TransactionAttributeSource transactionAttributeSource() { @Bean(name = TransactionManagementConfigUtils.TRANSACTIONAL_EVENT_LISTENER_FACTORY_BEAN_NAME) @Role(BeanDefinition.ROLE_INFRASTRUCTURE) - public static TransactionalEventListenerFactory transactionalEventListenerFactory() { - return new RestrictedTransactionalEventListenerFactory(); + public static TransactionalEventListenerFactory transactionalEventListenerFactory(@Nullable GlobalTransactionalEventErrorHandler errorHandler) { + if (errorHandler == null) { + return new RestrictedTransactionalEventListenerFactory(); + } + return new RestrictedTransactionalEventListenerFactory(errorHandler); } } diff --git a/spring-tx/src/main/java/org/springframework/transaction/annotation/RestrictedTransactionalEventListenerFactory.java b/spring-tx/src/main/java/org/springframework/transaction/annotation/RestrictedTransactionalEventListenerFactory.java index e1473eb411e1..ca51ec46d91c 100644 --- a/spring-tx/src/main/java/org/springframework/transaction/annotation/RestrictedTransactionalEventListenerFactory.java +++ b/spring-tx/src/main/java/org/springframework/transaction/annotation/RestrictedTransactionalEventListenerFactory.java @@ -20,6 +20,7 @@ import org.springframework.context.ApplicationListener; import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.transaction.config.GlobalTransactionalEventErrorHandler; import org.springframework.transaction.event.TransactionalEventListenerFactory; /** @@ -35,6 +36,14 @@ */ public class RestrictedTransactionalEventListenerFactory extends TransactionalEventListenerFactory { + public RestrictedTransactionalEventListenerFactory() { + super(); + } + + public RestrictedTransactionalEventListenerFactory(GlobalTransactionalEventErrorHandler errorHandler) { + super(errorHandler); + } + @Override public ApplicationListener createApplicationListener(String beanName, Class type, Method method) { Transactional txAnn = AnnotatedElementUtils.findMergedAnnotation(method, Transactional.class); diff --git a/spring-tx/src/main/java/org/springframework/transaction/config/GlobalTransactionalEventErrorHandler.java b/spring-tx/src/main/java/org/springframework/transaction/config/GlobalTransactionalEventErrorHandler.java new file mode 100644 index 000000000000..cbc8bf929122 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/config/GlobalTransactionalEventErrorHandler.java @@ -0,0 +1,18 @@ +package org.springframework.transaction.config; + +import org.jspecify.annotations.Nullable; +import org.springframework.context.ApplicationEvent; +import org.springframework.transaction.event.TransactionalApplicationListener; + +public abstract class GlobalTransactionalEventErrorHandler implements TransactionalApplicationListener.SynchronizationCallback { + + public abstract void handle(ApplicationEvent event, @Nullable Throwable ex); + + @Override + public void postProcessEvent(ApplicationEvent event, @Nullable Throwable ex) { + if (ex != null) { + handle(event, ex); + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/event/TransactionalEventListenerFactory.java b/spring-tx/src/main/java/org/springframework/transaction/event/TransactionalEventListenerFactory.java index 305d260b4af1..21523d603926 100644 --- a/spring-tx/src/main/java/org/springframework/transaction/event/TransactionalEventListenerFactory.java +++ b/spring-tx/src/main/java/org/springframework/transaction/event/TransactionalEventListenerFactory.java @@ -18,10 +18,12 @@ import java.lang.reflect.Method; +import org.jspecify.annotations.Nullable; import org.springframework.context.ApplicationListener; import org.springframework.context.event.EventListenerFactory; import org.springframework.core.Ordered; import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.transaction.config.GlobalTransactionalEventErrorHandler; /** * {@link EventListenerFactory} implementation that handles {@link TransactionalEventListener} @@ -35,6 +37,13 @@ public class TransactionalEventListenerFactory implements EventListenerFactory, private int order = 50; + private @Nullable GlobalTransactionalEventErrorHandler errorHandler; + + public TransactionalEventListenerFactory() { } + + public TransactionalEventListenerFactory(GlobalTransactionalEventErrorHandler errorHandler) { + this.errorHandler = errorHandler; + } public void setOrder(int order) { this.order = order; @@ -53,7 +62,14 @@ public boolean supportsMethod(Method method) { @Override public ApplicationListener createApplicationListener(String beanName, Class type, Method method) { - return new TransactionalApplicationListenerMethodAdapter(beanName, type, method); + if (errorHandler == null) { + return new TransactionalApplicationListenerMethodAdapter(beanName, type, method); + } + else { + TransactionalApplicationListenerMethodAdapter listener = new TransactionalApplicationListenerMethodAdapter(beanName, type, method); + listener.addCallback(errorHandler); + return listener; + } } } diff --git a/spring-tx/src/test/java/org/springframework/transaction/event/TransactionalEventListenerTests.java b/spring-tx/src/test/java/org/springframework/transaction/event/TransactionalEventListenerTests.java index 4cf080a82fe3..7254e582bb61 100644 --- a/spring-tx/src/test/java/org/springframework/transaction/event/TransactionalEventListenerTests.java +++ b/spring-tx/src/test/java/org/springframework/transaction/event/TransactionalEventListenerTests.java @@ -26,10 +26,12 @@ import java.util.List; import java.util.Map; +import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationEvent; import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ConfigurableApplicationContext; import org.springframework.context.annotation.AnnotationConfigApplicationContext; @@ -43,6 +45,7 @@ import org.springframework.transaction.annotation.EnableTransactionManagement; import org.springframework.transaction.annotation.Propagation; import org.springframework.transaction.annotation.Transactional; +import org.springframework.transaction.config.GlobalTransactionalEventErrorHandler; import org.springframework.transaction.support.TransactionSynchronization; import org.springframework.transaction.support.TransactionSynchronizationManager; import org.springframework.transaction.support.TransactionTemplate; @@ -99,12 +102,12 @@ void immediately() { void immediatelyImpactsCurrentTransaction() { load(ImmediateTestListener.class, BeforeCommitTestListener.class); assertThatIllegalStateException().isThrownBy(() -> - this.transactionTemplate.execute(status -> { - getContext().publishEvent("FAIL"); - throw new AssertionError("Should have thrown an exception at this point"); - })) - .withMessageContaining("Test exception") - .withMessageContaining(EventCollector.IMMEDIATELY); + this.transactionTemplate.execute(status -> { + getContext().publishEvent("FAIL"); + throw new AssertionError("Should have thrown an exception at this point"); + })) + .withMessageContaining("Test exception") + .withMessageContaining(EventCollector.IMMEDIATELY); getEventCollector().assertEvents(EventCollector.IMMEDIATELY, "FAIL"); getEventCollector().assertTotalEventsCount(1); @@ -369,6 +372,45 @@ void conditionFoundOnMetaAnnotation() { getEventCollector().assertNoEventReceived(); } + @Test + void afterCommitThrowException() { + doLoad(HandlerConfiguration.class, AfterCommitErrorHandlerTestListener.class); + this.transactionTemplate.execute(status -> { + getContext().publishEvent("test"); + getEventCollector().assertNoEventReceived(); + return null; + }); + getEventCollector().assertEvents(EventCollector.AFTER_COMMIT, "test"); + getEventCollector().assertEvents(EventCollector.HANDLE_ERROR, "HANDLE_ERROR"); + getEventCollector().assertTotalEventsCount(2); + } + + @Test + void afterRollbackThrowException() { + doLoad(HandlerConfiguration.class, AfterRollbackErrorHandlerTestListener.class); + this.transactionTemplate.execute(status -> { + getContext().publishEvent("test"); + getEventCollector().assertNoEventReceived(); + status.setRollbackOnly(); + return null; + }); + getEventCollector().assertEvents(EventCollector.AFTER_ROLLBACK, "test"); + getEventCollector().assertEvents(EventCollector.HANDLE_ERROR, "HANDLE_ERROR"); + getEventCollector().assertTotalEventsCount(2); + } + + @Test + void afterCompletionThrowException() { + doLoad(HandlerConfiguration.class, AfterCompletionErrorHandlerTestListener.class); + this.transactionTemplate.execute(status -> { + getContext().publishEvent("test"); + getEventCollector().assertNoEventReceived(); + return null; + }); + getEventCollector().assertEvents(EventCollector.AFTER_COMPLETION, "test"); + getEventCollector().assertEvents(EventCollector.HANDLE_ERROR, "HANDLE_ERROR"); + getEventCollector().assertTotalEventsCount(2); + } protected EventCollector getEventCollector() { return this.eventCollector; @@ -442,6 +484,36 @@ public TransactionTemplate transactionTemplate() { } } + @Configuration + @EnableTransactionManagement + static class HandlerConfiguration { + + @Bean + public EventCollector eventCollector() { + return new EventCollector(); + } + + @Bean + public TestBean testBean(ApplicationEventPublisher eventPublisher) { + return new TestBean(eventPublisher); + } + + @Bean + public CallCountingTransactionManager transactionManager() { + return new CallCountingTransactionManager(); + } + + @Bean + public TransactionTemplate transactionTemplate() { + return new TransactionTemplate(transactionManager()); + } + + @Bean + public AfterRollbackErrorHandler errorHandler(ApplicationEventPublisher eventPublisher) { + return new AfterRollbackErrorHandler(eventPublisher); + } + } + @Configuration static class MulticasterWithCustomExecutor { @@ -467,7 +539,9 @@ static class EventCollector { public static final String AFTER_ROLLBACK = "AFTER_ROLLBACK"; - public static final String[] ALL_PHASES = {IMMEDIATELY, BEFORE_COMMIT, AFTER_COMMIT, AFTER_ROLLBACK}; + public static final String HANDLE_ERROR = "HANDLE_ERROR"; + + public static final String[] ALL_PHASES = {IMMEDIATELY, BEFORE_COMMIT, AFTER_COMMIT, AFTER_ROLLBACK, HANDLE_ERROR}; private final MultiValueMap events = new LinkedMultiValueMap<>(); @@ -486,7 +560,7 @@ public void assertNoEventReceived(String... phases) { for (String phase : phases) { List eventsForPhase = getEvents(phase); assertThat(eventsForPhase.size()).as("Expected no events for phase '" + phase + "' " + - "but got " + eventsForPhase + ":").isEqualTo(0); + "but got " + eventsForPhase + ":").isEqualTo(0); } } @@ -504,7 +578,7 @@ public void assertTotalEventsCount(int number) { size += entry.getValue().size(); } assertThat(size).as("Wrong number of total events (" + this.events.size() + ") " + - "registered phase(s)").isEqualTo(number); + "registered phase(s)").isEqualTo(number); } } @@ -677,6 +751,51 @@ public void handleAfterCommit(String data) { } + @Component + static class AfterCommitErrorHandlerTestListener extends BaseTransactionalTestListener { + + @TransactionalEventListener(phase = AFTER_COMMIT, condition = "!'HANDLE_ERROR'.equals(#data)") + public void handleBeforeCommit(String data) { + handleEvent(EventCollector.AFTER_COMMIT, data); + throw new IllegalStateException("test"); + } + + @EventListener(condition = "'HANDLE_ERROR'.equals(#data)") + public void handleImmediately(String data) { + handleEvent(EventCollector.HANDLE_ERROR, data); + } + } + + @Component + static class AfterRollbackErrorHandlerTestListener extends BaseTransactionalTestListener { + + @TransactionalEventListener(phase = AFTER_ROLLBACK, condition = "!'HANDLE_ERROR'.equals(#data)") + public void handleBeforeCommit(String data) { + handleEvent(EventCollector.AFTER_ROLLBACK, data); + throw new IllegalStateException("test"); + } + + @EventListener(condition = "'HANDLE_ERROR'.equals(#data)") + public void handleImmediately(String data) { + handleEvent(EventCollector.HANDLE_ERROR, data); + } + } + + @Component + static class AfterCompletionErrorHandlerTestListener extends BaseTransactionalTestListener { + + @TransactionalEventListener(phase = AFTER_COMPLETION, condition = "!'HANDLE_ERROR'.equals(#data)") + public void handleBeforeCommit(String data) { + handleEvent(EventCollector.AFTER_COMPLETION, data); + throw new IllegalStateException("test"); + } + + @EventListener(condition = "'HANDLE_ERROR'.equals(#data)") + public void handleImmediately(String data) { + handleEvent(EventCollector.HANDLE_ERROR, data); + } + } + static class EventTransactionSynchronization implements TransactionSynchronization { private final int order; @@ -691,4 +810,18 @@ public int getOrder() { } } + static class AfterRollbackErrorHandler extends GlobalTransactionalEventErrorHandler { + + private final ApplicationEventPublisher eventPublisher; + + AfterRollbackErrorHandler(ApplicationEventPublisher eventPublisher) { + this.eventPublisher = eventPublisher; + } + + @Override + public void handle(ApplicationEvent event, @Nullable Throwable ex) { + eventPublisher.publishEvent("HANDLE_ERROR"); + } + } + }