Skip to content

Commit e09cf64

Browse files
authored
Merge pull request #47450 from michalvavrik/feature/make-oidc-required-claims-array
Make OIDC required claims support arrays
2 parents 338bee4 + 496d0f4 commit e09cf64

File tree

16 files changed

+349
-62
lines changed

16 files changed

+349
-62
lines changed

extensions/oidc/runtime/src/main/java/io/quarkus/oidc/OidcTenantConfig.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -2222,7 +2222,7 @@ public static Token fromAudience(String... audience) {
22222222
* Strings are the only supported types. Use {@linkplain SecurityIdentityAugmentor} to verify claims of other types or
22232223
* complex claims.
22242224
*/
2225-
public Map<String, String> requiredClaims = new HashMap<>();
2225+
public Map<String, Set<String>> requiredClaims = new HashMap<>();
22262226

22272227
/**
22282228
* Expected token type
@@ -2507,11 +2507,11 @@ public void setDecryptionKeyLocation(String decryptionKeyLocation) {
25072507
this.decryptionKeyLocation = Optional.of(decryptionKeyLocation);
25082508
}
25092509

2510-
public Map<String, String> getRequiredClaims() {
2510+
public Map<String, Set<String>> getRequiredClaims() {
25112511
return requiredClaims;
25122512
}
25132513

2514-
public void setRequiredClaims(Map<String, String> requiredClaims) {
2514+
public void setRequiredClaims(Map<String, Set<String>> requiredClaims) {
25152515
this.requiredClaims = requiredClaims;
25162516
}
25172517

@@ -2596,7 +2596,7 @@ public boolean subjectRequired() {
25962596
}
25972597

25982598
@Override
2599-
public Map<String, String> requiredClaims() {
2599+
public Map<String, Set<String>> requiredClaims() {
26002600
return requiredClaims;
26012601
}
26022602

extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProvider.java

+74-30
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package io.quarkus.oidc.runtime;
22

3+
import static java.util.Objects.requireNonNull;
4+
35
import java.io.Closeable;
46
import java.nio.charset.StandardCharsets;
57
import java.security.Key;
@@ -8,9 +10,11 @@
810
import java.util.Base64;
911
import java.util.List;
1012
import java.util.Map;
13+
import java.util.Set;
1114
import java.util.function.BiFunction;
1215
import java.util.function.Function;
1316

17+
import jakarta.json.JsonArray;
1418
import jakarta.json.JsonObject;
1519

1620
import org.eclipse.microprofile.jwt.Claims;
@@ -76,7 +80,7 @@ public class OidcProvider implements Closeable {
7680
final TokenCustomizer tokenCustomizer;
7781
final String issuer;
7882
final String[] audience;
79-
final Map<String, String> requiredClaims;
83+
final Map<String, Set<String>> requiredClaims;
8084
final Key tokenDecryptionKey;
8185
final AlgorithmConstraints requiredAlgorithmConstraints;
8286

@@ -160,8 +164,9 @@ private String[] checkAudienceProp() {
160164
return audienceProp != null ? audienceProp.toArray(new String[] {}) : null;
161165
}
162166

163-
private Map<String, String> checkRequiredClaimsProp() {
164-
return oidcConfig != null ? oidcConfig.token().requiredClaims() : null;
167+
private Map<String, Set<String>> checkRequiredClaimsProp() {
168+
return oidcConfig != null && !oidcConfig.token().requiredClaims().isEmpty() ? oidcConfig.token().requiredClaims()
169+
: null;
165170
}
166171

167172
public TokenVerificationResult verifySelfSignedJwtToken(String token, Key generatedInternalSignatureKey)
@@ -216,7 +221,7 @@ private TokenVerificationResult verifyJwtTokenInternal(String token,
216221
}
217222

218223
if (nonce != null) {
219-
builder.registerValidator(new CustomClaimsValidator(Map.of(OidcConstants.NONCE, nonce)));
224+
builder.registerValidator(new CustomClaimsValidator(Map.of(OidcConstants.NONCE, Set.of(nonce))));
220225
}
221226

222227
for (Validator customValidator : customValidators) {
@@ -241,7 +246,7 @@ private TokenVerificationResult verifyJwtTokenInternal(String token,
241246
} else {
242247
builder.setSkipDefaultAudienceValidation();
243248
}
244-
if (requiredClaims != null && !requiredClaims.isEmpty()) {
249+
if (requiredClaims != null) {
245250
builder.registerValidator(new CustomClaimsValidator(requiredClaims));
246251
}
247252

@@ -387,22 +392,46 @@ public TokenIntrospection apply(TokenIntrospection introspectionResult, Throwabl
387392
throw new AuthenticationFailedException(ex, tokenMap(token, idToken));
388393
}
389394

390-
if (requiredClaims != null && !requiredClaims.isEmpty()) {
391-
for (Map.Entry<String, String> requiredClaim : requiredClaims.entrySet()) {
392-
String introspectionClaimValue = null;
393-
try {
394-
introspectionClaimValue = introspectionResult.getString(requiredClaim.getKey());
395-
} catch (ClassCastException ex) {
396-
LOG.debugf("Introspection claim %s is not String", requiredClaim.getKey());
395+
if (requiredClaims != null) {
396+
for (Map.Entry<String, Set<String>> requiredClaim : requiredClaims.entrySet()) {
397+
final String requiredClaimName = requiredClaim.getKey();
398+
if (!introspectionResult.contains(requiredClaimName)) {
399+
LOG.debugf("Introspection claim %s is missing", requiredClaimName);
397400
throw new AuthenticationFailedException(tokenMap(token, idToken));
398401
}
399-
if (introspectionClaimValue == null) {
400-
LOG.debugf("Introspection claim %s is missing", requiredClaim.getKey());
402+
final Set<String> requiredClaimValues = requiredClaim.getValue();
403+
if (requiredClaimValues.size() == 1) {
404+
String introspectionClaimValue = null;
405+
try {
406+
introspectionClaimValue = introspectionResult.getString(requiredClaimName);
407+
} catch (ClassCastException ex) {
408+
LOG.debugf("Introspection claim %s is not String", requiredClaimName);
409+
}
410+
String requiredClaimValue = requiredClaimValues.iterator().next();
411+
if (requiredClaimValue.equals(introspectionClaimValue)) {
412+
continue;
413+
}
414+
}
415+
final JsonArray actualClaimValueArray;
416+
try {
417+
actualClaimValueArray = requireNonNull(introspectionResult.getArray(requiredClaimName));
418+
} catch (Exception ignored) {
419+
LOG.debugf("Introspection claim %s is neither string or array", requiredClaimName);
401420
throw new AuthenticationFailedException(tokenMap(token, idToken));
402421
}
403-
if (!introspectionClaimValue.equals(requiredClaim.getValue())) {
422+
requiredClaimValuesLoop: for (String requiredClaimValue : requiredClaimValues) {
423+
for (int i = 0; i < actualClaimValueArray.size(); i++) {
424+
try {
425+
String actualClaimValue = actualClaimValueArray.getString(i);
426+
if (requiredClaimValue.equals(actualClaimValue)) {
427+
continue requiredClaimValuesLoop;
428+
}
429+
} catch (Exception ignored) {
430+
// try next actual claim value
431+
}
432+
}
404433
LOG.debugf("Value of the introspection claim %s does not match required value of %s",
405-
requiredClaim.getKey(), requiredClaim.getValue());
434+
requiredClaimName, requiredClaimValue);
406435
throw new AuthenticationFailedException(tokenMap(token, idToken));
407436
}
408437
}
@@ -416,8 +445,7 @@ public TokenIntrospection apply(TokenIntrospection introspectionResult, Throwabl
416445

417446
private void verifyTokenExpiry(String token, boolean idToken, Long exp) {
418447
if (isTokenExpired(exp)) {
419-
String error = String.format("Token issued to client %s has expired",
420-
oidcConfig.clientId().get());
448+
String error = String.format("Token issued to client %s has expired", oidcConfig.clientId().get());
421449
LOG.debugf(error);
422450
throw new AuthenticationFailedException(
423451
new InvalidJwtException(error,
@@ -436,7 +464,7 @@ private int getLifespanGrace() {
436464
: 0;
437465
}
438466

439-
private static final long now() {
467+
private static long now() {
440468
return System.currentTimeMillis();
441469
}
442470

@@ -624,7 +652,7 @@ public Key resolveKey(JsonWebSignature jws, List<JsonWebStructure> nestingContex
624652
}
625653

626654
private Key initKey(Key generatedInternalSignatureKey) {
627-
String clientSecret = OidcCommonUtils.getClientOrJwtSecret(oidcConfig.credentials);
655+
String clientSecret = OidcCommonUtils.getClientOrJwtSecret(oidcConfig.credentials());
628656
if (clientSecret != null) {
629657
LOG.debug("Verifying internal ID token with a configured client secret");
630658
return KeyUtils.createSecretKeyFromSecret(clientSecret);
@@ -642,11 +670,11 @@ public OidcConfigurationMetadata getMetadata() {
642670
return client == null ? null : client.getMetadata();
643671
}
644672

645-
private static class CustomClaimsValidator implements Validator {
673+
private static final class CustomClaimsValidator implements Validator {
646674

647-
private final Map<String, String> customClaims;
675+
private final Map<String, Set<String>> customClaims;
648676

649-
public CustomClaimsValidator(Map<String, String> customClaims) {
677+
private CustomClaimsValidator(Map<String, Set<String>> customClaims) {
650678
this.customClaims = customClaims;
651679
}
652680

@@ -658,13 +686,29 @@ public String validate(JwtContext jwtContext) throws MalformedClaimException {
658686
if (!claims.hasClaim(claimName)) {
659687
return "claim " + claimName + " is missing";
660688
}
661-
if (!claims.isClaimValueString(claimName)) {
662-
throw new MalformedClaimException("expected claim " + claimName + " to be a string");
663-
}
664-
var claimValue = claims.getStringClaimValue(claimName);
665-
var targetValue = targetClaim.getValue();
666-
if (!claimValue.equals(targetValue)) {
667-
return "claim " + claimName + " does not match expected value of " + targetValue;
689+
Set<String> requiredClaimValues = targetClaim.getValue();
690+
if (claims.isClaimValueString(claimName)) {
691+
if (requiredClaimValues.size() == 1) {
692+
String actualClaimValue = claims.getStringClaimValue(claimName);
693+
String requiredClaimValue = requiredClaimValues.iterator().next();
694+
if (!requiredClaimValue.equals(actualClaimValue)) {
695+
return "claim " + claimName + " does not match expected value of " + requiredClaimValues;
696+
}
697+
} else {
698+
throw new MalformedClaimException("expected claim " + claimName + " must be a list of strings");
699+
}
700+
} else {
701+
if (claims.isClaimValueStringList(claimName)) {
702+
List<String> actualClaimValues = claims.getStringListClaimValue(claimName);
703+
for (String requiredClaimValue : requiredClaimValues) {
704+
if (!actualClaimValues.contains(requiredClaimValue)) {
705+
return "claim " + claimName + " does not match expected value of " + requiredClaimValues;
706+
}
707+
}
708+
} else {
709+
throw new MalformedClaimException(
710+
"expected claim " + claimName + " must be a list of strings or a string");
711+
}
668712
}
669713
}
670714
return null;

extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcTenantConfig.java

+6-5
Original file line numberDiff line numberDiff line change
@@ -1002,13 +1002,14 @@ interface Token {
10021002

10031003
/**
10041004
* A map of required claims and their expected values.
1005-
* For example, `quarkus.oidc.token.required-claims.org_id = org_xyz` would require tokens to have the `org_id` claim to
1006-
* be present and set to `org_xyz`.
1007-
* Strings are the only supported types. Use {@linkplain SecurityIdentityAugmentor} to verify claims of other types or
1008-
* complex claims.
1005+
* For example, `quarkus.oidc.token.required-claims.org_id = org_xyz` would require tokens to have the `org_id`
1006+
* claim to be present and set to `org_xyz`. On the other hand, if it was set to `org_xyz,org_abc`,
1007+
* the `org_id` claim would need to have both `org_xyz` and `org_abc` values.
1008+
* Strings and arrays of strings are currently the only supported types.
1009+
* Use {@linkplain SecurityIdentityAugmentor} to verify claims of other types or complex claims.
10091010
*/
10101011
@ConfigDocMapKey("claim-name")
1011-
Map<String, String> requiredClaims();
1012+
Map<String, Set<@WithConverter(TrimmedStringConverter.class) String>> requiredClaims();
10121013

10131014
/**
10141015
* Expected token type

extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/StaticTenantResolver.java

+37-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import java.util.HashMap;
77
import java.util.List;
88
import java.util.Map;
9+
import java.util.Set;
910
import java.util.concurrent.atomic.AtomicBoolean;
1011
import java.util.function.BiFunction;
1112

@@ -19,6 +20,7 @@
1920
import io.quarkus.oidc.common.runtime.OidcCommonUtils;
2021
import io.quarkus.vertx.http.runtime.security.ImmutablePathMatcher;
2122
import io.smallrye.mutiny.Uni;
23+
import io.vertx.core.json.JsonArray;
2224
import io.vertx.core.json.JsonObject;
2325
import io.vertx.ext.web.RoutingContext;
2426

@@ -246,9 +248,41 @@ private static String getTenantId(RoutingContext context, TenantConfigContext te
246248
return null;
247249
}
248250

249-
private static boolean requiredClaimsMatch(Map<String, String> requiredClaims, JsonObject tokenJson) {
250-
for (Map.Entry<String, String> entry : requiredClaims.entrySet()) {
251-
if (!entry.getValue().equals(tokenJson.getString(entry.getKey()))) {
251+
private static boolean requiredClaimsMatch(Map<String, Set<String>> requiredClaims, JsonObject tokenJson) {
252+
for (Map.Entry<String, Set<String>> entry : requiredClaims.entrySet()) {
253+
Set<String> requiredClaimSet = entry.getValue();
254+
String claimName = entry.getKey();
255+
if (requiredClaimSet.size() == 1) {
256+
String actualClaimValueAsStr;
257+
try {
258+
actualClaimValueAsStr = tokenJson.getString(claimName);
259+
} catch (Exception ex) {
260+
actualClaimValueAsStr = null;
261+
}
262+
if (actualClaimValueAsStr != null && requiredClaimSet.contains(actualClaimValueAsStr)) {
263+
continue;
264+
}
265+
}
266+
final JsonArray actualClaimValues;
267+
try {
268+
actualClaimValues = tokenJson.getJsonArray(claimName);
269+
} catch (Exception e) {
270+
return false;
271+
}
272+
if (actualClaimValues == null) {
273+
return false;
274+
}
275+
outer: for (String requiredClaimValue : requiredClaimSet) {
276+
for (int i = 0; i < actualClaimValues.size(); i++) {
277+
try {
278+
String actualClaimValue = actualClaimValues.getString(i);
279+
if (requiredClaimValue.equals(actualClaimValue)) {
280+
continue outer;
281+
}
282+
} catch (Exception ignored) {
283+
// try next actual claim value
284+
}
285+
}
252286
return false;
253287
}
254288
}

extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/builders/TokenConfigBuilder.java

+42-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
import java.util.Objects;
1010
import java.util.Optional;
1111
import java.util.OptionalInt;
12+
import java.util.Set;
13+
import java.util.function.Function;
14+
import java.util.stream.Collectors;
1215

1316
import io.quarkus.oidc.OidcTenantConfigBuilder;
1417
import io.quarkus.oidc.runtime.OidcTenantConfig;
@@ -20,7 +23,7 @@
2023
public final class TokenConfigBuilder {
2124

2225
private record TokenImpl(Optional<String> issuer, Optional<List<String>> audience, boolean subjectRequired,
23-
Map<String, String> requiredClaims, Optional<String> tokenType, OptionalInt lifespanGrace,
26+
Map<String, Set<String>> requiredClaims, Optional<String> tokenType, OptionalInt lifespanGrace,
2427
Optional<Duration> age, boolean issuedAtRequired, Optional<String> principalClaim, boolean refreshExpired,
2528
Optional<Duration> refreshTokenTimeSkew, Duration forcedJwkRefreshInterval, Optional<String> header,
2629
String authorizationScheme, Optional<OidcTenantConfig.SignatureAlgorithm> signatureAlgorithm,
@@ -30,7 +33,7 @@ private record TokenImpl(Optional<String> issuer, Optional<List<String>> audienc
3033
}
3134

3235
private final OidcTenantConfigBuilder builder;
33-
private final Map<String, String> requiredClaims = new HashMap<>();
36+
private final Map<String, Set<String>> requiredClaims = new HashMap<>();
3437
private final List<String> audience = new ArrayList<>();
3538
private Optional<String> issuer;
3639
private boolean subjectRequired;
@@ -103,7 +106,19 @@ public OidcTenantConfigBuilder end() {
103106
public TokenConfigBuilder requiredClaims(String requiredClaimName, String requiredClaimValue) {
104107
Objects.requireNonNull(requiredClaimName);
105108
Objects.requireNonNull(requiredClaimValue);
106-
this.requiredClaims.put(requiredClaimName, requiredClaimValue);
109+
this.requiredClaims.put(requiredClaimName, Set.of(requiredClaimValue));
110+
return this;
111+
}
112+
113+
/**
114+
* @param requiredClaimName {@link OidcTenantConfig.Token#requiredClaims()} name
115+
* @param requiredClaimValues {@link OidcTenantConfig.Token#requiredClaims()} value
116+
* @return this builder
117+
*/
118+
public TokenConfigBuilder requiredClaims(String requiredClaimName, Set<String> requiredClaimValues) {
119+
Objects.requireNonNull(requiredClaimName);
120+
Objects.requireNonNull(requiredClaimValues);
121+
this.requiredClaims.put(requiredClaimName, Set.copyOf(requiredClaimValues));
107122
return this;
108123
}
109124

@@ -112,6 +127,30 @@ public TokenConfigBuilder requiredClaims(String requiredClaimName, String requir
112127
* @return this builder
113128
*/
114129
public TokenConfigBuilder requiredClaims(Map<String, String> requiredClaims) {
130+
if (requiredClaims != null) {
131+
return this.setRequiredClaims(requiredClaims
132+
.entrySet()
133+
.stream()
134+
.collect(Collectors.toMap(new Function<Map.Entry<String, String>, String>() {
135+
@Override
136+
public String apply(Map.Entry<String, String> stringStringEntry) {
137+
return stringStringEntry.getKey();
138+
}
139+
}, new Function<Map.Entry<String, String>, Set<String>>() {
140+
@Override
141+
public Set<String> apply(Map.Entry<String, String> e) {
142+
return Set.of(e.getValue());
143+
}
144+
})));
145+
}
146+
return this;
147+
}
148+
149+
/**
150+
* @param requiredClaims {@link OidcTenantConfig.Token#requiredClaims()}
151+
* @return this builder
152+
*/
153+
public TokenConfigBuilder setRequiredClaims(Map<String, Set<String>> requiredClaims) {
115154
if (requiredClaims != null) {
116155
this.requiredClaims.putAll(requiredClaims);
117156
}

0 commit comments

Comments
 (0)