diff --git a/spring-statemachine-core/src/main/java/org/springframework/statemachine/StateMachine.java b/spring-statemachine-core/src/main/java/org/springframework/statemachine/StateMachine.java index cf6ae07e3..8c7cc3de8 100644 --- a/spring-statemachine-core/src/main/java/org/springframework/statemachine/StateMachine.java +++ b/spring-statemachine-core/src/main/java/org/springframework/statemachine/StateMachine.java @@ -1,5 +1,5 @@ /* - * Copyright 2015 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -64,5 +64,4 @@ public interface StateMachine extends Region { * @return true, if error has been set */ boolean hasStateMachineError(); - } diff --git a/spring-statemachine-core/src/main/java/org/springframework/statemachine/StateMachineEventResult.java b/spring-statemachine-core/src/main/java/org/springframework/statemachine/StateMachineEventResult.java index 0337b068e..13f5c1ef1 100644 --- a/spring-statemachine-core/src/main/java/org/springframework/statemachine/StateMachineEventResult.java +++ b/spring-statemachine-core/src/main/java/org/springframework/statemachine/StateMachineEventResult.java @@ -1,11 +1,11 @@ /* - * Copyright 2019-2020 the original author or authors. + * Copyright 2019-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -15,6 +15,7 @@ */ package org.springframework.statemachine; +import java.util.Optional; import org.springframework.messaging.Message; import org.springframework.statemachine.region.Region; @@ -59,6 +60,14 @@ public interface StateMachineEventResult { */ Mono complete(); + /** + * If there was an exception that caused the transition to be denied - return that + * @return Optional Throwable that caused the transition to be denied + */ + default Optional getDenialCause() { + return Optional.empty(); + }; + /** * Enumeration of a result type indicating whether a region accepted, denied or * deferred an event. @@ -82,7 +91,7 @@ public enum ResultType { */ public static StateMachineEventResult from(Region region, Message message, ResultType resultType) { - return new DefaultStateMachineEventResult<>(region, message, resultType, null); + return new DefaultStateMachineEventResult<>(region, message, resultType, null, null); } @@ -100,7 +109,24 @@ public static StateMachineEventResult from(Region region, Mes */ public static StateMachineEventResult from(Region region, Message message, ResultType resultType, Mono complete) { - return new DefaultStateMachineEventResult<>(region, message, resultType, complete); + return new DefaultStateMachineEventResult<>(region, message, resultType, complete, null); + } + + /** + * Create a {@link StateMachineEventResult} from a {@link Region}, + * {@link Message} and a {@link ResultType}. + * + * @param the type of state + * @param the type of event + * @param region the region + * @param message the message + * @param resultType the result type + * @param denialCause the throwable (that most likely caused transition denial) + * @return the state machine event result + */ + public static StateMachineEventResult from(Region region, Message message, + ResultType resultType, Throwable denialCause) { + return new DefaultStateMachineEventResult<>(region, message, resultType, null, denialCause); } static class DefaultStateMachineEventResult implements StateMachineEventResult { @@ -109,13 +135,15 @@ static class DefaultStateMachineEventResult implements StateMachineEventRe private final Message message; private final ResultType resultType; private Mono complete; + private Throwable denialCause; DefaultStateMachineEventResult(Region region, Message message, ResultType resultType, - Mono complete) { + Mono complete, Throwable denialCause) { this.region = region; this.message = message; this.resultType = resultType; this.complete = complete != null ? complete : Mono.empty(); + this.denialCause = denialCause; } @Override @@ -138,6 +166,11 @@ public Mono complete() { return complete; } + @Override + public Optional getDenialCause() { + return Optional.ofNullable(denialCause); + } + @Override public String toString() { return "DefaultStateMachineEventResult [region=" + region + ", message=" + message + ", resultType=" diff --git a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java index ec12d739b..7cd072aa3 100644 --- a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java +++ b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java @@ -620,12 +620,12 @@ public void setTransitionConflightPolicy(TransitionConflictPolicy transitionConf private Flux> handleEvent(Message message) { if (hasStateMachineError()) { - return Flux.just(StateMachineEventResult.from(this, message, ResultType.DENIED)); + return Flux.just(StateMachineEventResult.from(this, message, ResultType.DENIED, currentError.getCause())); } return Mono.just(message) .map(m -> getStateMachineInterceptors().preEvent(m, this)) .flatMapMany(m -> acceptEvent(m)) - .onErrorResume(error -> Flux.just(StateMachineEventResult.from(this, message, ResultType.DENIED))) + .onErrorResume(error -> Flux.just(StateMachineEventResult.from(this, message, ResultType.DENIED, error.getCause()))) .doOnNext(notifyOnDenied()); } @@ -668,7 +668,7 @@ private Flux> acceptEvent(Message message) { })) .onErrorResume(t -> { return Mono.defer(() -> { - return Mono.just(StateMachineEventResult.from(this, message, ResultType.DENIED)); + return Mono.just(StateMachineEventResult.from(this, message, ResultType.DENIED, t.getCause())); }); }); } else { diff --git a/spring-statemachine-core/src/test/java/org/springframework/statemachine/security/AbstractSecurityTests.java b/spring-statemachine-core/src/test/java/org/springframework/statemachine/security/AbstractSecurityTests.java index 718631f5f..eda5574a7 100644 --- a/spring-statemachine-core/src/test/java/org/springframework/statemachine/security/AbstractSecurityTests.java +++ b/spring-statemachine-core/src/test/java/org/springframework/statemachine/security/AbstractSecurityTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,10 +18,12 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.statemachine.TestUtils.doSendEventAndConsumeAll; import static org.springframework.statemachine.TestUtils.doSendEventAndConsumeResultAsDenied; +import static org.springframework.statemachine.TestUtils.doSendEventAndConsumeResultAsDeniedWithAccessDeniedException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import org.springframework.security.access.AccessDeniedException; import org.springframework.statemachine.AbstractStateMachineTests; import org.springframework.statemachine.StateMachine; import org.springframework.statemachine.config.StateMachineBuilder; @@ -53,7 +55,7 @@ protected static void assertTransitionDenied(StateMachine machin assertThat(machine.getState().getIds()).containsOnly(States.S0); listener.reset(1); - doSendEventAndConsumeAll(machine, Events.A); + doSendEventAndConsumeResultAsDeniedWithAccessDeniedException(machine, Events.A); assertThat(listener.stateChangedLatch.await(2, TimeUnit.SECONDS)).isFalse(); assertThat(listener.stateChangedCount).isZero(); assertThat(machine.getState().getIds()).containsOnly(States.S0); diff --git a/spring-statemachine-core/src/test/java/org/springframework/statemachine/security/EventSecurityTests.java b/spring-statemachine-core/src/test/java/org/springframework/statemachine/security/EventSecurityTests.java index def627da2..68b35196e 100644 --- a/spring-statemachine-core/src/test/java/org/springframework/statemachine/security/EventSecurityTests.java +++ b/spring-statemachine-core/src/test/java/org/springframework/statemachine/security/EventSecurityTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ public class EventSecurityTests extends AbstractSecurityTests { public void testNoSecurityContext() throws Exception { TestListener listener = new TestListener(); StateMachine machine = buildMachine(listener, "ROLE_ANONYMOUS", ComparisonType.ANY, null); - assertTransitionDeniedResultAsDenied(machine, listener); + assertTransitionDenied(machine, listener); } @Test diff --git a/spring-statemachine-core/src/testFixtures/java/org/springframework/statemachine/TestUtils.java b/spring-statemachine-core/src/testFixtures/java/org/springframework/statemachine/TestUtils.java index ef7b1c203..586360faf 100644 --- a/spring-statemachine-core/src/testFixtures/java/org/springframework/statemachine/TestUtils.java +++ b/spring-statemachine-core/src/testFixtures/java/org/springframework/statemachine/TestUtils.java @@ -28,6 +28,7 @@ import org.springframework.beans.factory.BeanFactory; import org.springframework.messaging.Message; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.security.access.AccessDeniedException; import org.springframework.statemachine.StateMachineEventResult.ResultType; import org.springframework.statemachine.action.Action; import org.springframework.statemachine.config.StateMachineFactory; @@ -125,6 +126,15 @@ public static void doSendEventAndConsumeResultAsDenied(StateMachine .verifyComplete(); } + public static void doSendEventAndConsumeResultAsDeniedWithAccessDeniedException(StateMachine stateMachine, E event) { + StepVerifier.create(stateMachine.sendEvent(eventAsMono(event))) + .consumeNextWith(result -> { + assertThat(result.getResultType()).isEqualTo(ResultType.DENIED); + assertThat(result.getDenialCause().map(t -> t instanceof AccessDeniedException).orElse(false)).isTrue(); + }) + .verifyComplete(); + } + public static void doSendEventAndConsumeResultAsDenied(StateMachine stateMachine, Message event) { StepVerifier.create(stateMachine.sendEvent(eventAsMono(event))) .consumeNextWith(result -> {