diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java index 3679b7e36e5..aa1535c2054 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java @@ -58,6 +58,7 @@ * @author Joe Grandja * @author Rafael Dominguez * @author Mark Heckler + * @author Ivan Golovko * @since 5.2 * @see JwtDecoderFactory * @see ClientRegistration @@ -78,7 +79,7 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory jwtDecoders = new ConcurrentHashMap<>(); + private final Map jwtDecoders; private Function> jwtValidatorFactory = new DefaultOidcIdTokenValidatorFactory(); @@ -88,6 +89,19 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory, Map>> claimTypeConverterFactory = ( clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER; + public OidcIdTokenDecoderFactory() { + this(true); + } + + public OidcIdTokenDecoderFactory(boolean withCache) { + if (withCache) { + this.jwtDecoders = new ConcurrentHashMap<>(); + } + else { + this.jwtDecoders = null; + } + } + /** * Returns the default {@link Converter}'s used for type conversion of claim values * for an {@link OidcIdToken}. @@ -135,16 +149,24 @@ public static ClaimTypeConverter createDefaultClaimTypeConverter() { @Override public JwtDecoder createDecoder(ClientRegistration clientRegistration) { Assert.notNull(clientRegistration, "clientRegistration cannot be null"); - return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), (key) -> { - NimbusJwtDecoder jwtDecoder = buildDecoder(clientRegistration); - jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration)); - Converter, Map> claimTypeConverter = this.claimTypeConverterFactory - .apply(clientRegistration); - if (claimTypeConverter != null) { - jwtDecoder.setClaimSetConverter(claimTypeConverter); - } - return jwtDecoder; - }); + if (this.jwtDecoders != null) { + return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), + (key) -> createFreshDecoder(clientRegistration)); + } + else { + return createFreshDecoder(clientRegistration); + } + } + + private JwtDecoder createFreshDecoder(ClientRegistration clientRegistration) { + NimbusJwtDecoder jwtDecoder = buildDecoder(clientRegistration); + jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration)); + Converter, Map> claimTypeConverter = this.claimTypeConverterFactory + .apply(clientRegistration); + if (claimTypeConverter != null) { + jwtDecoder.setClaimSetConverter(claimTypeConverter); + } + return jwtDecoder; } private NimbusJwtDecoder buildDecoder(ClientRegistration clientRegistration) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java index 5c066d3bacd..d7ba23b2f1b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java @@ -59,6 +59,7 @@ * @author Rafael Dominguez * @author Mark Heckler * @author Ubaid ur Rehman + * @author Ivan Golovko * @since 5.2 * @see ReactiveJwtDecoderFactory * @see ClientRegistration @@ -80,7 +81,7 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod private static final ClaimTypeConverter DEFAULT_CLAIM_TYPE_CONVERTER = new ClaimTypeConverter( createDefaultClaimTypeConverters()); - private final Map jwtDecoders = new ConcurrentHashMap<>(); + private final Map jwtDecoders; private Function> jwtValidatorFactory = new DefaultOidcIdTokenValidatorFactory(); @@ -90,6 +91,19 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod private Function, Map>> claimTypeConverterFactory = ( clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER; + public ReactiveOidcIdTokenDecoderFactory() { + this(true); + } + + public ReactiveOidcIdTokenDecoderFactory(boolean withCache) { + if (withCache) { + this.jwtDecoders = new ConcurrentHashMap<>(); + } + else { + this.jwtDecoders = null; + } + } + /** * Returns the default {@link Converter}'s used for type conversion of claim values * for an {@link OidcIdToken}. @@ -126,16 +140,24 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod @Override public ReactiveJwtDecoder createDecoder(ClientRegistration clientRegistration) { Assert.notNull(clientRegistration, "clientRegistration cannot be null"); - return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), (key) -> { - NimbusReactiveJwtDecoder jwtDecoder = buildDecoder(clientRegistration); - jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration)); - Converter, Map> claimTypeConverter = this.claimTypeConverterFactory - .apply(clientRegistration); - if (claimTypeConverter != null) { - jwtDecoder.setClaimSetConverter(claimTypeConverter); - } - return jwtDecoder; - }); + if (this.jwtDecoders != null) { + return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), + (key) -> createFreshDecoder(clientRegistration)); + } + else { + return createFreshDecoder(clientRegistration); + } + } + + private ReactiveJwtDecoder createFreshDecoder(ClientRegistration clientRegistration) { + NimbusReactiveJwtDecoder jwtDecoder = buildDecoder(clientRegistration); + jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration)); + Converter, Map> claimTypeConverter = this.claimTypeConverterFactory + .apply(clientRegistration); + if (claimTypeConverter != null) { + jwtDecoder.setClaimSetConverter(claimTypeConverter); + } + return jwtDecoder; } private NimbusReactiveJwtDecoder buildDecoder(ClientRegistration clientRegistration) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java index 33663bac650..7a3944d23ce 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java @@ -34,6 +34,7 @@ import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtDecoder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -46,6 +47,7 @@ /** * @author Joe Grandja * @author Rafael Dominguez + * @author Ivan Golovko * @since 5.2 */ public class OidcIdTokenDecoderFactoryTests { @@ -177,4 +179,21 @@ public void createDecoderWhenCustomClaimTypeConverterFactorySetThenApplied() { verify(customClaimTypeConverterFactory).apply(same(clientRegistration)); } + @Test + public void createDecoderTwiceWithCaching() { + ClientRegistration clientRegistration = this.registration.build(); + JwtDecoder decoder1 = this.idTokenDecoderFactory.createDecoder(clientRegistration); + JwtDecoder decoder2 = this.idTokenDecoderFactory.createDecoder(clientRegistration); + assertThat(decoder1).isSameAs(decoder2); + } + + @Test + public void createDecoderTwiceWithoutCaching() { + this.idTokenDecoderFactory = new OidcIdTokenDecoderFactory(false); + ClientRegistration clientRegistration = this.registration.build(); + JwtDecoder decoder1 = this.idTokenDecoderFactory.createDecoder(clientRegistration); + JwtDecoder decoder2 = this.idTokenDecoderFactory.createDecoder(clientRegistration); + assertThat(decoder1).isNotSameAs(decoder2); + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java index 8c5b70ea494..2a4e31aa6dc 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java @@ -34,6 +34,7 @@ import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -47,6 +48,7 @@ * @author Joe Grandja * @author Rafael Dominguez * @author Ubaid ur Rehman + * @author Ivan Golovko * @since 5.2 */ public class ReactiveOidcIdTokenDecoderFactoryTests { @@ -177,4 +179,21 @@ public void createDecoderWhenCustomClaimTypeConverterFactorySetThenApplied() { verify(customClaimTypeConverterFactory).apply(same(clientRegistration)); } + @Test + public void createDecoderTwiceWithCaching() { + ClientRegistration clientRegistration = this.registration.build(); + ReactiveJwtDecoder decoder1 = this.idTokenDecoderFactory.createDecoder(clientRegistration); + ReactiveJwtDecoder decoder2 = this.idTokenDecoderFactory.createDecoder(clientRegistration); + assertThat(decoder1).isSameAs(decoder2); + } + + @Test + public void createDecoderTwiceWithoutCaching() { + this.idTokenDecoderFactory = new ReactiveOidcIdTokenDecoderFactory(false); + ClientRegistration clientRegistration = this.registration.build(); + ReactiveJwtDecoder decoder1 = this.idTokenDecoderFactory.createDecoder(clientRegistration); + ReactiveJwtDecoder decoder2 = this.idTokenDecoderFactory.createDecoder(clientRegistration); + assertThat(decoder1).isNotSameAs(decoder2); + } + }