Skip to content

Commit 577d392

Browse files
committed
Provide more flexibility on when to display consent page
1 parent d151568 commit 577d392

File tree

3 files changed

+198
-22
lines changed

3 files changed

+198
-22
lines changed

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationContext.java

+60
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import java.util.function.Consumer;
2222

2323
import org.springframework.lang.Nullable;
24+
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
25+
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent;
2426
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
2527
import org.springframework.util.Assert;
2628

@@ -63,6 +65,27 @@ public RegisteredClient getRegisteredClient() {
6365
return get(RegisteredClient.class);
6466
}
6567

68+
/**
69+
* Returns the {@link OAuth2AuthorizationRequest oauth2 authorization request}.
70+
*
71+
* @return the {@link OAuth2AuthorizationRequest}
72+
*/
73+
@Nullable
74+
public OAuth2AuthorizationRequest getOAuth2AuthorizationRequest() {
75+
return get(OAuth2AuthorizationRequest.class);
76+
}
77+
78+
/**
79+
* Returns the {@link OAuth2AuthorizationConsent oauth2 authorization consent}.
80+
*
81+
* @return the {@link OAuth2AuthorizationConsent}
82+
*/
83+
@Nullable
84+
public OAuth2AuthorizationConsent getOAuth2AuthorizationConsent() {
85+
return get(OAuth2AuthorizationConsent.class);
86+
}
87+
88+
6689
/**
6790
* Constructs a new {@link Builder} with the provided {@link OAuth2AuthorizationCodeRequestAuthenticationToken}.
6891
*
@@ -78,6 +101,21 @@ public static Builder with(OAuth2AuthorizationCodeRequestAuthenticationToken aut
78101
*/
79102
public static final class Builder extends AbstractBuilder<OAuth2AuthorizationCodeRequestAuthenticationContext, Builder> {
80103

104+
/**
105+
* Associates an attribute.
106+
*
107+
* @param key the key for the attribute
108+
* @param value the value of the attribute
109+
* @return the {@link Builder} for further configuration
110+
* @since 1.3.0
111+
*/
112+
@Override
113+
public Builder put(Object key, Object value) {
114+
Assert.notNull(key, "key cannot be null");
115+
getContext().put(key, value);
116+
return getThis();
117+
}
118+
81119
private Builder(OAuth2AuthorizationCodeRequestAuthenticationToken authentication) {
82120
super(authentication);
83121
}
@@ -92,6 +130,28 @@ public Builder registeredClient(RegisteredClient registeredClient) {
92130
return put(RegisteredClient.class, registeredClient);
93131
}
94132

133+
/**
134+
* Sets the {@link OAuth2AuthorizationRequest oauth2 authorization request}.
135+
*
136+
* @param authorizationRequest the {@link OAuth2AuthorizationRequest}
137+
* @return the {@link Builder} for further configuration
138+
* @since 1.3.0
139+
*/
140+
public Builder authorizationRequest(OAuth2AuthorizationRequest authorizationRequest) {
141+
return put(OAuth2AuthorizationRequest.class, authorizationRequest);
142+
}
143+
144+
/**
145+
* Sets the {@link OAuth2AuthorizationConsent oauth2 authorization consent}.
146+
*
147+
* @param authorizationConsent the {@link OAuth2AuthorizationConsent}
148+
* @return the {@link Builder} for further configuration
149+
* @since 1.3.0
150+
*/
151+
public Builder authorizationConsent(OAuth2AuthorizationConsent authorizationConsent) {
152+
return put(OAuth2AuthorizationConsent.class, authorizationConsent);
153+
}
154+
95155
/**
96156
* Builds a new {@link OAuth2AuthorizationCodeRequestAuthenticationContext}.
97157
*

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java

+53-22
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.Base64;
2020
import java.util.Set;
2121
import java.util.function.Consumer;
22+
import java.util.function.Predicate;
2223

2324
import org.apache.commons.logging.Log;
2425
import org.apache.commons.logging.LogFactory;
@@ -80,6 +81,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
8081
private OAuth2TokenGenerator<OAuth2AuthorizationCode> authorizationCodeGenerator = new OAuth2AuthorizationCodeGenerator();
8182
private Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authenticationValidator =
8283
new OAuth2AuthorizationCodeRequestAuthenticationValidator();
84+
private Predicate<OAuth2AuthorizationCodeRequestAuthenticationContext> requiresAuthorizationConsent;
8385

8486
/**
8587
* Constructs an {@code OAuth2AuthorizationCodeRequestAuthenticationProvider} using the provided parameters.
@@ -96,6 +98,7 @@ public OAuth2AuthorizationCodeRequestAuthenticationProvider(RegisteredClientRepo
9698
this.registeredClientRepository = registeredClientRepository;
9799
this.authorizationService = authorizationService;
98100
this.authorizationConsentService = authorizationConsentService;
101+
this.requiresAuthorizationConsent = this::requireAuthorizationConsent;
99102
}
100103

101104
@Override
@@ -171,7 +174,14 @@ public Authentication authenticate(Authentication authentication) throws Authent
171174
OAuth2AuthorizationConsent currentAuthorizationConsent = this.authorizationConsentService.findById(
172175
registeredClient.getId(), principal.getName());
173176

174-
if (requireAuthorizationConsent(registeredClient, authorizationRequest, currentAuthorizationConsent)) {
177+
OAuth2AuthorizationCodeRequestAuthenticationContext contextWithAuthorizationRequestAndAuthorizationConsent =
178+
OAuth2AuthorizationCodeRequestAuthenticationContext.with(authorizationCodeRequestAuthentication)
179+
.registeredClient(registeredClient)
180+
.authorizationConsent(currentAuthorizationConsent)
181+
.authorizationRequest(authorizationRequest)
182+
.build();
183+
184+
if (requiresAuthorizationConsent.test(contextWithAuthorizationRequestAndAuthorizationConsent)) {
175185
String state = DEFAULT_STATE_GENERATOR.generateKey();
176186
OAuth2Authorization authorization = authorizationBuilder(registeredClient, principal, authorizationRequest)
177187
.attribute(OAuth2ParameterNames.STATE, state)
@@ -264,7 +274,48 @@ public void setAuthenticationValidator(Consumer<OAuth2AuthorizationCodeRequestAu
264274
this.authenticationValidator = authenticationValidator;
265275
}
266276

267-
private static OAuth2Authorization.Builder authorizationBuilder(RegisteredClient registeredClient, Authentication principal,
277+
/**
278+
* Sets the {@link Predicate} used to determine if authorization consent is required.
279+
*
280+
* <p>
281+
* The {@link OAuth2AuthorizationCodeRequestAuthenticationContext} gives the predicate access to the {@link OAuth2AuthorizationCodeRequestAuthenticationToken},
282+
* as well as, the following context attributes:
283+
* {@link OAuth2AuthorizationCodeRequestAuthenticationContext#getRegisteredClient()} containing {@link RegisteredClient} used to make the request.
284+
* {@link OAuth2AuthorizationCodeRequestAuthenticationContext#getOAuth2AuthorizationRequest()} containing {@link OAuth2AuthorizationRequest}.
285+
* {@link OAuth2AuthorizationCodeRequestAuthenticationContext#getOAuth2AuthorizationConsent()} containing {@link OAuth2AuthorizationConsent} granted in the request.
286+
*
287+
* @param requiresAuthorizationConsent the {@link Predicate} that determines if authorization consent is required.
288+
* @since 1.3.0
289+
*/
290+
public void setRequiresAuthorizationConsent(Predicate<OAuth2AuthorizationCodeRequestAuthenticationContext> requiresAuthorizationConsent) {
291+
Assert.notNull(requiresAuthorizationConsent, "requiresAuthorizationConsent cannot be null");
292+
this.requiresAuthorizationConsent = requiresAuthorizationConsent;
293+
}
294+
295+
private boolean requireAuthorizationConsent(OAuth2AuthorizationCodeRequestAuthenticationContext context) {
296+
RegisteredClient registeredClient = context.getRegisteredClient();
297+
if (!registeredClient.getClientSettings().isRequireAuthorizationConsent()) {
298+
return false;
299+
}
300+
301+
OAuth2AuthorizationRequest authorizationRequest = context.getOAuth2AuthorizationRequest();
302+
// 'openid' scope does not require consent
303+
if (authorizationRequest.getScopes().contains(OidcScopes.OPENID) &&
304+
authorizationRequest.getScopes().size() == 1) {
305+
return false;
306+
}
307+
308+
OAuth2AuthorizationConsent authorizationConsent = context.getOAuth2AuthorizationConsent();
309+
if (authorizationConsent != null &&
310+
authorizationConsent.getScopes().containsAll(authorizationRequest.getScopes())) {
311+
return false;
312+
}
313+
314+
return true;
315+
}
316+
317+
private static OAuth2Authorization.Builder authorizationBuilder(RegisteredClient registeredClient,
318+
Authentication principal,
268319
OAuth2AuthorizationRequest authorizationRequest) {
269320
return OAuth2Authorization.withRegisteredClient(registeredClient)
270321
.principalName(principal.getName())
@@ -295,26 +346,6 @@ private static OAuth2TokenContext createAuthorizationCodeTokenContext(
295346
return tokenContextBuilder.build();
296347
}
297348

298-
private static boolean requireAuthorizationConsent(RegisteredClient registeredClient,
299-
OAuth2AuthorizationRequest authorizationRequest, OAuth2AuthorizationConsent authorizationConsent) {
300-
301-
if (!registeredClient.getClientSettings().isRequireAuthorizationConsent()) {
302-
return false;
303-
}
304-
// 'openid' scope does not require consent
305-
if (authorizationRequest.getScopes().contains(OidcScopes.OPENID) &&
306-
authorizationRequest.getScopes().size() == 1) {
307-
return false;
308-
}
309-
310-
if (authorizationConsent != null &&
311-
authorizationConsent.getScopes().containsAll(authorizationRequest.getScopes())) {
312-
return false;
313-
}
314-
315-
return true;
316-
}
317-
318349
private static boolean isPrincipalAuthenticated(Authentication principal) {
319350
return principal != null &&
320351
!AnonymousAuthenticationToken.class.isAssignableFrom(principal.getClass()) &&

oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java

+85
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.Map;
2222
import java.util.Set;
2323
import java.util.function.Consumer;
24+
import java.util.function.Predicate;
2425

2526
import org.junit.jupiter.api.BeforeEach;
2627
import org.junit.jupiter.api.Test;
@@ -72,6 +73,7 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
7273
private OAuth2AuthorizationConsentService authorizationConsentService;
7374
private OAuth2AuthorizationCodeRequestAuthenticationProvider authenticationProvider;
7475
private TestingAuthenticationToken principal;
76+
private Predicate<OAuth2AuthorizationCodeRequestAuthenticationContext> requiresAuthorizationConsent;
7577

7678
@BeforeEach
7779
public void setUp() {
@@ -129,6 +131,13 @@ public void setAuthenticationValidatorWhenNullThenThrowIllegalArgumentException(
129131
.hasMessage("authenticationValidator cannot be null");
130132
}
131133

134+
@Test
135+
public void setRequiresAuthorizationConsentWhenNullThenThrowIllegalArgumentException() {
136+
assertThatThrownBy(() -> this.authenticationProvider.setRequiresAuthorizationConsent(null))
137+
.isInstanceOf(IllegalArgumentException.class)
138+
.hasMessage("requiresAuthorizationConsent cannot be null");
139+
}
140+
132141
@Test
133142
public void authenticateWhenInvalidClientIdThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() {
134143
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
@@ -443,6 +452,82 @@ public void authenticateWhenRequireAuthorizationConsentThenReturnAuthorizationCo
443452
assertThat(authenticationResult.isAuthenticated()).isTrue();
444453
}
445454

455+
@Test
456+
public void authenticateWhenRequireAuthorizationConsentAndRequiresAuthorizationConsentPredicateTrueThenReturnAuthorizationConsent() {
457+
this.authenticationProvider.setRequiresAuthorizationConsent((authenticationContext) -> true);
458+
459+
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
460+
.clientSettings(ClientSettings.builder().requireAuthorizationConsent(true).build())
461+
.build();
462+
when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
463+
.thenReturn(registeredClient);
464+
465+
String redirectUri = registeredClient.getRedirectUris().toArray(new String[0])[0];
466+
OAuth2AuthorizationCodeRequestAuthenticationToken authentication =
467+
new OAuth2AuthorizationCodeRequestAuthenticationToken(
468+
AUTHORIZATION_URI, registeredClient.getClientId(), principal,
469+
redirectUri, STATE, registeredClient.getScopes(), null);
470+
471+
OAuth2AuthorizationConsentAuthenticationToken authenticationResult =
472+
(OAuth2AuthorizationConsentAuthenticationToken) this.authenticationProvider.authenticate(authentication);
473+
474+
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
475+
verify(this.authorizationService).save(authorizationCaptor.capture());
476+
OAuth2Authorization authorization = authorizationCaptor.getValue();
477+
478+
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName());
479+
assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
480+
assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE);
481+
assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo(authentication.getAuthorizationUri());
482+
assertThat(authorizationRequest.getClientId()).isEqualTo(registeredClient.getClientId());
483+
assertThat(authorizationRequest.getRedirectUri()).isEqualTo(authentication.getRedirectUri());
484+
assertThat(authorizationRequest.getScopes()).isEqualTo(authentication.getScopes());
485+
assertThat(authorizationRequest.getState()).isEqualTo(authentication.getState());
486+
assertThat(authorizationRequest.getAdditionalParameters()).isEqualTo(authentication.getAdditionalParameters());
487+
488+
assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
489+
assertThat(authorization.getPrincipalName()).isEqualTo(this.principal.getName());
490+
assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
491+
assertThat(authorization.<Authentication>getAttribute(Principal.class.getName())).isEqualTo(this.principal);
492+
String state = authorization.getAttribute(OAuth2ParameterNames.STATE);
493+
assertThat(state).isNotNull();
494+
assertThat(state).isNotEqualTo(authentication.getState());
495+
496+
assertThat(authenticationResult.getClientId()).isEqualTo(registeredClient.getClientId());
497+
assertThat(authenticationResult.getPrincipal()).isEqualTo(this.principal);
498+
assertThat(authenticationResult.getAuthorizationUri()).isEqualTo(authorizationRequest.getAuthorizationUri());
499+
assertThat(authenticationResult.getScopes()).isEmpty();
500+
assertThat(authenticationResult.getState()).isEqualTo(state);
501+
assertThat(authenticationResult.isAuthenticated()).isTrue();
502+
}
503+
504+
@Test
505+
public void authenticateWhenRequireAuthorizationConsentAndRequiresAuthorizationConsentPredicateFalseThenAuthorizationConsentNotRequired() {
506+
this.authenticationProvider.setRequiresAuthorizationConsent((authenticationContext) -> false);
507+
508+
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
509+
.clientSettings(ClientSettings.builder().requireAuthorizationConsent(true).build())
510+
.scopes(scopes -> {
511+
scopes.clear();
512+
scopes.add(OidcScopes.OPENID);
513+
scopes.add(OidcScopes.EMAIL);
514+
})
515+
.build();
516+
when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
517+
.thenReturn(registeredClient);
518+
519+
String redirectUri = registeredClient.getRedirectUris().toArray(new String[0])[1];
520+
OAuth2AuthorizationCodeRequestAuthenticationToken authentication =
521+
new OAuth2AuthorizationCodeRequestAuthenticationToken(
522+
AUTHORIZATION_URI, registeredClient.getClientId(), principal,
523+
redirectUri, STATE, registeredClient.getScopes(), null);
524+
525+
OAuth2AuthorizationCodeRequestAuthenticationToken authenticationResult =
526+
(OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationProvider.authenticate(authentication);
527+
528+
assertAuthorizationCodeRequestWithAuthorizationCodeResult(registeredClient, authentication, authenticationResult);
529+
}
530+
446531
@Test
447532
public void authenticateWhenRequireAuthorizationConsentAndOnlyOpenidScopeRequestedThenAuthorizationConsentNotRequired() {
448533
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()

0 commit comments

Comments
 (0)