Skip to content

Commit e7ab8b1

Browse files
committed
Move logic for populating Token Exchange claims
Issue spring-projectsgh-60
1 parent 7ef544f commit e7ab8b1

File tree

5 files changed

+178
-163
lines changed

5 files changed

+178
-163
lines changed

Diff for: oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/DelegatingOAuth2TokenCustomizer.java

-46
This file was deleted.

Diff for: oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ConfigurerUtils.java

+2-30
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
*/
1616
package org.springframework.security.oauth2.server.authorization.config.annotation.web.configurers;
1717

18-
import java.util.ArrayList;
19-
import java.util.List;
2018
import java.util.Map;
2119

2220
import com.nimbusds.jose.jwk.source.JWKSource;
@@ -172,39 +170,13 @@ static JWKSource<SecurityContext> getJwkSource(HttpSecurity httpSecurity) {
172170
}
173171

174172
private static OAuth2TokenCustomizer<JwtEncodingContext> getJwtCustomizer(HttpSecurity httpSecurity) {
175-
OAuth2TokenCustomizer<JwtEncodingContext> defaultTokenCustomizer = OAuth2TokenExchangeTokenCustomizers.jwt();
176173
ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2TokenCustomizer.class, JwtEncodingContext.class);
177-
OAuth2TokenCustomizer<JwtEncodingContext> userTokenCustomizer = getOptionalBean(httpSecurity, type);
178-
179-
OAuth2TokenCustomizer<JwtEncodingContext> tokenCustomizer;
180-
if (userTokenCustomizer != null) {
181-
List<OAuth2TokenCustomizer<JwtEncodingContext>> tokenCustomizers = new ArrayList<>();
182-
tokenCustomizers.add(defaultTokenCustomizer);
183-
tokenCustomizers.add(userTokenCustomizer);
184-
tokenCustomizer = new DelegatingOAuth2TokenCustomizer<>(tokenCustomizers);
185-
} else {
186-
tokenCustomizer = defaultTokenCustomizer;
187-
}
188-
189-
return tokenCustomizer;
174+
return getOptionalBean(httpSecurity, type);
190175
}
191176

192177
private static OAuth2TokenCustomizer<OAuth2TokenClaimsContext> getAccessTokenCustomizer(HttpSecurity httpSecurity) {
193-
OAuth2TokenCustomizer<OAuth2TokenClaimsContext> defaultTokenCustomizer = OAuth2TokenExchangeTokenCustomizers.accessToken();
194178
ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2TokenCustomizer.class, OAuth2TokenClaimsContext.class);
195-
OAuth2TokenCustomizer<OAuth2TokenClaimsContext> userTokenCustomizer = getOptionalBean(httpSecurity, type);
196-
197-
OAuth2TokenCustomizer<OAuth2TokenClaimsContext> tokenCustomizer;
198-
if (userTokenCustomizer != null) {
199-
List<OAuth2TokenCustomizer<OAuth2TokenClaimsContext>> tokenCustomizers = new ArrayList<>();
200-
tokenCustomizers.add(defaultTokenCustomizer);
201-
tokenCustomizers.add(userTokenCustomizer);
202-
tokenCustomizer = new DelegatingOAuth2TokenCustomizer<>(tokenCustomizers);
203-
} else {
204-
tokenCustomizer = defaultTokenCustomizer;
205-
}
206-
207-
return tokenCustomizer;
179+
return getOptionalBean(httpSecurity, type);
208180
}
209181

210182
static AuthorizationServerSettings getAuthorizationServerSettings(HttpSecurity httpSecurity) {

Diff for: oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2TokenExchangeTokenCustomizers.java

-84
This file was deleted.

Diff for: oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/DefaultOAuth2TokenClaimsConsumer.java

+45-3
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@
1717

1818
import java.security.MessageDigest;
1919
import java.security.cert.X509Certificate;
20+
import java.util.ArrayList;
2021
import java.util.Base64;
22+
import java.util.Collections;
2123
import java.util.HashMap;
24+
import java.util.LinkedHashMap;
25+
import java.util.List;
2226
import java.util.Map;
2327
import java.util.function.Consumer;
2428

@@ -28,6 +32,10 @@
2832
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
2933
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
3034
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
35+
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenExchangeActor;
36+
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenExchangeAuthenticationToken;
37+
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenExchangeCompositeAuthenticationToken;
38+
import org.springframework.util.CollectionUtils;
3139

3240
/**
3341
* @author Joe Grandja
@@ -46,10 +54,13 @@ final class DefaultOAuth2TokenClaimsConsumer implements Consumer<Map<String, Obj
4654

4755
@Override
4856
public void accept(Map<String, Object> claims) {
57+
if (!OAuth2TokenType.ACCESS_TOKEN.equals(this.context.getTokenType()) ||
58+
this.context.getAuthorizationGrant() == null) {
59+
return;
60+
}
61+
4962
// Add 'cnf' claim for Mutual-TLS Client Certificate-Bound Access Tokens
50-
if (OAuth2TokenType.ACCESS_TOKEN.equals(this.context.getTokenType()) &&
51-
this.context.getAuthorizationGrant() != null &&
52-
this.context.getAuthorizationGrant().getPrincipal() instanceof OAuth2ClientAuthenticationToken clientAuthentication) {
63+
if (this.context.getAuthorizationGrant().getPrincipal() instanceof OAuth2ClientAuthenticationToken clientAuthentication) {
5364

5465
if ((TLS_CLIENT_AUTH_AUTHENTICATION_METHOD.equals(clientAuthentication.getClientAuthenticationMethod()) ||
5566
SELF_SIGNED_TLS_CLIENT_AUTH_AUTHENTICATION_METHOD.equals(clientAuthentication.getClientAuthenticationMethod())) &&
@@ -68,6 +79,31 @@ public void accept(Map<String, Object> claims) {
6879
}
6980
}
7081
}
82+
83+
// Add claims for Token Exchange Grant
84+
if (this.context.getAuthorizationGrant() instanceof OAuth2TokenExchangeAuthenticationToken tokenExchangeAuthentication) {
85+
86+
// Append audience value(s) from request to 'aud' claim
87+
if (!CollectionUtils.isEmpty(tokenExchangeAuthentication.getAudiences())) {
88+
List<String> audiences = getAudienceClaim(claims);
89+
audiences.addAll(tokenExchangeAuthentication.getAudiences());
90+
claims.put(OAuth2TokenClaimNames.AUD, audiences);
91+
}
92+
93+
// Add 'act' claim for delegation. If more than one actor is present,
94+
// we create a chain of delegation by nesting "act" claims.
95+
if (this.context.getPrincipal() instanceof OAuth2TokenExchangeCompositeAuthenticationToken compositeAuthenticationToken) {
96+
Map<String, Object> currentClaims = claims;
97+
for (OAuth2TokenExchangeActor actor : compositeAuthenticationToken.getActors()) {
98+
Map<String, Object> actorClaims = actor.getClaims();
99+
Map<String, Object> actClaim = new LinkedHashMap<>();
100+
actClaim.put(OAuth2TokenClaimNames.ISS, actorClaims.get(OAuth2TokenClaimNames.ISS));
101+
actClaim.put(OAuth2TokenClaimNames.SUB, actorClaims.get(OAuth2TokenClaimNames.SUB));
102+
currentClaims.put("act", Collections.unmodifiableMap(actClaim));
103+
currentClaims = actClaim;
104+
}
105+
}
106+
}
71107
}
72108

73109
private static String computeSHA256Thumbprint(X509Certificate x509Certificate) throws Exception {
@@ -76,4 +112,10 @@ private static String computeSHA256Thumbprint(X509Certificate x509Certificate) t
76112
return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
77113
}
78114

115+
@SuppressWarnings("unchecked")
116+
private static List<String> getAudienceClaim(Map<String, Object> claims) {
117+
List<String> audiences = (List<String>) claims.getOrDefault(OAuth2TokenClaimNames.AUD, Collections.emptyList());
118+
return new ArrayList<>(audiences);
119+
}
120+
79121
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
* Copyright 2020-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.security.oauth2.server.authorization.token;
17+
18+
import java.util.Collections;
19+
import java.util.LinkedHashMap;
20+
import java.util.List;
21+
import java.util.Map;
22+
import java.util.Set;
23+
import java.util.function.Consumer;
24+
25+
import org.junit.jupiter.api.BeforeEach;
26+
import org.junit.jupiter.api.Test;
27+
28+
import org.springframework.security.authentication.TestingAuthenticationToken;
29+
import org.springframework.security.core.Authentication;
30+
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
31+
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenExchangeActor;
32+
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenExchangeAuthenticationToken;
33+
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenExchangeCompositeAuthenticationToken;
34+
35+
import static org.assertj.core.api.Assertions.assertThat;
36+
import static org.mockito.Mockito.mock;
37+
import static org.mockito.Mockito.when;
38+
39+
/**
40+
* Tests for {@link DefaultOAuth2TokenClaimsConsumer}.
41+
*
42+
* @author Steve Riesenberg
43+
*/
44+
public class DefaultOAuth2TokenClaimsConsumerTests {
45+
46+
private OAuth2TokenContext tokenContext;
47+
48+
private Consumer<Map<String, Object>> consumer;
49+
50+
@BeforeEach
51+
public void setUp() {
52+
this.tokenContext = mock(OAuth2TokenContext.class);
53+
this.consumer = new DefaultOAuth2TokenClaimsConsumer(this.tokenContext);
54+
}
55+
56+
@Test
57+
public void acceptWhenTokenTypeIsRefreshTokenThenNoClaimsAdded() {
58+
when(this.tokenContext.getTokenType()).thenReturn(OAuth2TokenType.REFRESH_TOKEN);
59+
Map<String, Object> claims = new LinkedHashMap<>();
60+
this.consumer.accept(claims);
61+
assertThat(claims).isEmpty();
62+
}
63+
64+
@Test
65+
public void acceptWhenAuthorizationGrantIsNullThenNoClaimsAdded() {
66+
when(this.tokenContext.getTokenType()).thenReturn(OAuth2TokenType.ACCESS_TOKEN);
67+
when(this.tokenContext.getAuthorizationGrant()).thenReturn(null);
68+
Map<String, Object> claims = new LinkedHashMap<>();
69+
this.consumer.accept(claims);
70+
assertThat(claims).isEmpty();
71+
}
72+
73+
@Test
74+
public void acceptWhenTokenExchangeAndAudiencesEmptyThenNoClaimsAdded() {
75+
OAuth2TokenExchangeAuthenticationToken tokenExchangeAuthentication = mock(
76+
OAuth2TokenExchangeAuthenticationToken.class);
77+
when(tokenExchangeAuthentication.getAudiences()).thenReturn(Collections.emptySet());
78+
when(this.tokenContext.getTokenType()).thenReturn(OAuth2TokenType.ACCESS_TOKEN);
79+
when(this.tokenContext.getAuthorizationGrant()).thenReturn(tokenExchangeAuthentication);
80+
Map<String, Object> claims = new LinkedHashMap<>();
81+
this.consumer.accept(claims);
82+
assertThat(claims).isEmpty();
83+
}
84+
85+
@Test
86+
public void acceptWhenTokenExchangeGrantAndAudiencesThenAudClaimAppended() {
87+
OAuth2TokenExchangeAuthenticationToken tokenExchangeAuthentication = mock(
88+
OAuth2TokenExchangeAuthenticationToken.class);
89+
when(tokenExchangeAuthentication.getAudiences()).thenReturn(Set.of("audience1", "audience2"));
90+
when(this.tokenContext.getTokenType()).thenReturn(OAuth2TokenType.ACCESS_TOKEN);
91+
when(this.tokenContext.getAuthorizationGrant()).thenReturn(tokenExchangeAuthentication);
92+
Map<String, Object> claims = new LinkedHashMap<>();
93+
claims.put(OAuth2TokenClaimNames.AUD, List.of("client1"));
94+
this.consumer.accept(claims);
95+
assertThat(claims).hasSize(1);
96+
assertThat(claims.get(OAuth2TokenClaimNames.AUD)).isNotNull();
97+
@SuppressWarnings("unchecked")
98+
List<String> audiences = (List<String>) claims.get(OAuth2TokenClaimNames.AUD);
99+
assertThat(audiences).containsExactly("client1", "audience1", "audience2");
100+
}
101+
102+
@Test
103+
public void acceptWhenTokenExchangeGrantAndDelegationThenActClaimAdded() {
104+
OAuth2TokenExchangeAuthenticationToken tokenExchangeAuthentication = mock(
105+
OAuth2TokenExchangeAuthenticationToken.class);
106+
when(tokenExchangeAuthentication.getAudiences()).thenReturn(Collections.emptySet());
107+
when(this.tokenContext.getTokenType()).thenReturn(OAuth2TokenType.ACCESS_TOKEN);
108+
when(this.tokenContext.getAuthorizationGrant()).thenReturn(tokenExchangeAuthentication);
109+
Authentication subject = new TestingAuthenticationToken("subject", null);
110+
OAuth2TokenExchangeActor actor1 = new OAuth2TokenExchangeActor(Map.of(OAuth2TokenClaimNames.ISS, "issuer1",
111+
OAuth2TokenClaimNames.SUB, "actor1"));
112+
OAuth2TokenExchangeActor actor2 = new OAuth2TokenExchangeActor(Map.of(OAuth2TokenClaimNames.ISS, "issuer2",
113+
OAuth2TokenClaimNames.SUB, "actor2"));
114+
OAuth2TokenExchangeCompositeAuthenticationToken principal = new OAuth2TokenExchangeCompositeAuthenticationToken(
115+
subject, List.of(actor1, actor2));
116+
when(this.tokenContext.getPrincipal()).thenReturn(principal);
117+
Map<String, Object> claims = new LinkedHashMap<>();
118+
this.consumer.accept(claims);
119+
assertThat(claims).hasSize(1);
120+
assertThat(claims.get("act")).isNotNull();
121+
@SuppressWarnings("unchecked")
122+
Map<String, Object> actClaim1 = (Map<String, Object>) claims.get("act");
123+
assertThat(actClaim1.get(OAuth2TokenClaimNames.ISS)).isEqualTo("issuer1");
124+
assertThat(actClaim1.get(OAuth2TokenClaimNames.SUB)).isEqualTo("actor1");
125+
@SuppressWarnings("unchecked")
126+
Map<String, Object> actClaim2 = (Map<String, Object>) actClaim1.get("act");
127+
assertThat(actClaim2.get(OAuth2TokenClaimNames.ISS)).isEqualTo("issuer2");
128+
assertThat(actClaim2.get(OAuth2TokenClaimNames.SUB)).isEqualTo("actor2");
129+
}
130+
131+
}

0 commit comments

Comments
 (0)