From c9d9848f9483119eb7d535b351914f283a0fa9c9 Mon Sep 17 00:00:00 2001 From: Hao Date: Fri, 14 Feb 2025 00:08:50 +0800 Subject: [PATCH] Ensure ID Token is updated after refresh token --- .../OAuth2ClientConfiguration.java | 16 ++++- .../oauth2/client/OAuth2LoginConfigurer.java | 5 ++ ...Auth2AuthorizedClientManagerRegistrar.java | 7 +++ ...OAuth2AuthorizedClientProviderBuilder.java | 17 ++++++ ...shTokenOAuth2AuthorizedClientProvider.java | 26 +++++++- .../event/OAuth2TokenRefreshedEvent.java | 47 ++++++++++++++ ...thorizationCodeAuthenticationProvider.java | 2 +- .../RefreshOidcIdTokenHandler.java | 61 +++++++++++++++++++ ...enOAuth2AuthorizedClientProviderTests.java | 53 ++++++++++++++++ 9 files changed, 229 insertions(+), 5 deletions(-) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2TokenRefreshedEvent.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandler.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index 13c9a1b3c07..55de62810d5 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -34,6 +34,9 @@ import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; +import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.annotation.AnnotationBeanNameGenerator; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -160,7 +163,7 @@ private OAuth2AuthorizedClientManager getAuthorizedClientManager() { * @since 6.2.0 */ static final class OAuth2AuthorizedClientManagerRegistrar - implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware { + implements ApplicationContextAware, BeanDefinitionRegistryPostProcessor, BeanFactoryAware { static final String BEAN_NAME = "authorizedClientManagerRegistrar"; @@ -179,6 +182,8 @@ static final class OAuth2AuthorizedClientManagerRegistrar private final AnnotationBeanNameGenerator beanNameGenerator = new AnnotationBeanNameGenerator(); + private ApplicationEventPublisher eventPublisher; + private ListableBeanFactory beanFactory; @Override @@ -302,6 +307,10 @@ private OAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider( authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); } + if (this.eventPublisher != null) { + authorizedClientProvider.setApplicationEventPublisher(this.eventPublisher); + } + return authorizedClientProvider; } @@ -423,6 +432,11 @@ private T getBeanOfType(ResolvableType resolvableType) { return objectProvider.getIfAvailable(); } + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.eventPublisher = applicationContext; + } + } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index 4c53b3293d0..68cc0dd0bf2 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -56,6 +56,7 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeAuthenticationProvider; +import org.springframework.security.oauth2.client.oidc.authentication.RefreshOidcIdTokenHandler; import org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry; import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation; import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry; @@ -393,6 +394,10 @@ public void init(B http) throws Exception { oidcAuthorizationCodeAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper); } http.authenticationProvider(this.postProcess(oidcAuthorizationCodeAuthenticationProvider)); + + RefreshOidcIdTokenHandler refreshOidcIdTokenHandler = new RefreshOidcIdTokenHandler( + oidcAuthorizationCodeAuthenticationProvider); + registerDelegateApplicationListener(refreshOidcIdTokenHandler); } else { http.authenticationProvider(new OidcAuthenticationRequestChecker()); diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java b/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java index 669d6f7f67f..d2252435f7b 100644 --- a/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java @@ -34,6 +34,7 @@ import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; +import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.annotation.AnnotationBeanNameGenerator; import org.springframework.core.ResolvableType; import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; @@ -197,6 +198,12 @@ private OAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider( authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient); } + ApplicationEventPublisher applicationEventPublisher = getBeanOfType( + ResolvableType.forClass(ApplicationEventPublisher.class)); + if (applicationEventPublisher != null) { + authorizedClientProvider.setApplicationEventPublisher(applicationEventPublisher); + } + return authorizedClientProvider; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java index c0c8bee93ee..bcd130063e6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java @@ -25,6 +25,7 @@ import java.util.Map; import java.util.function.Consumer; +import org.springframework.context.ApplicationEventPublisher; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; @@ -359,6 +360,8 @@ public final class RefreshTokenGrantBuilder implements Builder { private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private ApplicationEventPublisher eventPublisher; + private Duration clockSkew; private Clock clock; @@ -379,6 +382,17 @@ public RefreshTokenGrantBuilder accessTokenResponseClient( return this; } + /** + * Sets the {@link ApplicationEventPublisher} used when an access token is + * refreshed. + * @param eventPublisher the {@link ApplicationEventPublisher} + * @return the {@link RefreshTokenGrantBuilder} + */ + public RefreshTokenGrantBuilder eventPublisher(ApplicationEventPublisher eventPublisher) { + this.eventPublisher = eventPublisher; + return this; + } + /** * Sets the maximum acceptable clock skew, which is used when checking the access * token expiry. An access token is considered expired if @@ -414,6 +428,9 @@ public OAuth2AuthorizedClientProvider build() { if (this.accessTokenResponseClient != null) { authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); } + if (this.eventPublisher != null) { + authorizedClientProvider.setApplicationEventPublisher(this.eventPublisher); + } if (this.clockSkew != null) { authorizedClientProvider.setClockSkew(this.clockSkew); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java index 410a33fda18..17dc2ad16b9 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java @@ -24,10 +24,13 @@ import java.util.HashSet; import java.util.Set; +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.lang.Nullable; import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Token; @@ -43,10 +46,13 @@ * @see OAuth2AuthorizedClientProvider * @see DefaultRefreshTokenTokenResponseClient */ -public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { +public final class RefreshTokenOAuth2AuthorizedClientProvider + implements OAuth2AuthorizedClientProvider, ApplicationEventPublisherAware { private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); + private ApplicationEventPublisher eventPublisher; + private Duration clockSkew = Duration.ofSeconds(60); private Clock clock = Clock.systemUTC(); @@ -91,8 +97,17 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { authorizedClient.getClientRegistration(), authorizedClient.getAccessToken(), authorizedClient.getRefreshToken(), scopes); OAuth2AccessTokenResponse tokenResponse = getTokenResponse(authorizedClient, refreshTokenGrantRequest); - return new OAuth2AuthorizedClient(context.getAuthorizedClient().getClientRegistration(), - context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); + + OAuth2AuthorizedClient updatedOAuth2AuthorizedClient = new OAuth2AuthorizedClient( + authorizedClient.getClientRegistration(), context.getPrincipal().getName(), + tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); + + if (this.eventPublisher != null) { + this.eventPublisher + .publishEvent(new OAuth2TokenRefreshedEvent(this, updatedOAuth2AuthorizedClient, tokenResponse)); + } + + return updatedOAuth2AuthorizedClient; } private OAuth2AccessTokenResponse getTokenResponse(OAuth2AuthorizedClient authorizedClient, @@ -149,4 +164,9 @@ public void setClock(Clock clock) { this.clock = clock; } + @Override + public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) { + this.eventPublisher = applicationEventPublisher; + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2TokenRefreshedEvent.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2TokenRefreshedEvent.java new file mode 100644 index 00000000000..f92091d4cd7 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2TokenRefreshedEvent.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.event; + +import org.springframework.context.ApplicationEvent; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; + +/** + * An event that is published when an OAuth2 access token is refreshed. + */ +public class OAuth2TokenRefreshedEvent extends ApplicationEvent { + + private final OAuth2AuthorizedClient authorizedClient; + + private final OAuth2AccessTokenResponse accessTokenResponse; + + public OAuth2TokenRefreshedEvent(Object source, OAuth2AuthorizedClient authorizedClient, + OAuth2AccessTokenResponse accessTokenResponse) { + super(source); + this.authorizedClient = authorizedClient; + this.accessTokenResponse = accessTokenResponse; + } + + public OAuth2AuthorizedClient getAuthorizedClient() { + return this.authorizedClient; + } + + public OAuth2AccessTokenResponse getAccessTokenResponse() { + return this.accessTokenResponse; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java index 64cfba6816a..30ec3945ee5 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java @@ -232,7 +232,7 @@ public boolean supports(Class authentication) { return OAuth2LoginAuthenticationToken.class.isAssignableFrom(authentication); } - private OidcIdToken createOidcToken(ClientRegistration clientRegistration, + protected OidcIdToken createOidcToken(ClientRegistration clientRegistration, OAuth2AccessTokenResponse accessTokenResponse) { JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration); Jwt jwt = getJwt(accessTokenResponse, jwtDecoder); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandler.java new file mode 100644 index 00000000000..304c556babe --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandler.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.oidc.authentication; + +import org.springframework.context.ApplicationListener; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.StandardClaimNames; +import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; + +/** + * An {@link ApplicationListener} that listens for {@link OAuth2TokenRefreshedEvent}s + */ +public class RefreshOidcIdTokenHandler implements ApplicationListener { + + private final OidcAuthorizationCodeAuthenticationProvider oidcAuthorizationCodeAuthenticationProvider; + + public RefreshOidcIdTokenHandler( + OidcAuthorizationCodeAuthenticationProvider oidcAuthorizationCodeAuthenticationProvider) { + this.oidcAuthorizationCodeAuthenticationProvider = oidcAuthorizationCodeAuthenticationProvider; + } + + @Override + public void onApplicationEvent(OAuth2TokenRefreshedEvent event) { + OAuth2AuthorizedClient authorizedClient = event.getAuthorizedClient(); + OAuth2AccessTokenResponse accessTokenResponse = event.getAccessTokenResponse(); + OidcIdToken refreshedOidcToken = this.oidcAuthorizationCodeAuthenticationProvider + .createOidcToken(authorizedClient.getClientRegistration(), accessTokenResponse); + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + if (authentication instanceof OAuth2AuthenticationToken oauth2AuthenticationToken) { + if (authentication.getPrincipal() instanceof DefaultOidcUser defaultOidcUser) { + OidcUser oidcUser = new DefaultOidcUser(defaultOidcUser.getAuthorities(), refreshedOidcToken, + defaultOidcUser.getUserInfo(), StandardClaimNames.SUB); + SecurityContextHolder.getContext() + .setAuthentication(new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(), + oauth2AuthenticationToken.getAuthorizedClientRegistrationId())); + } + } + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java index 86ae003eff2..dc4c4232004 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java @@ -25,10 +25,12 @@ import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.springframework.context.ApplicationEventPublisher; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AccessToken; @@ -251,4 +253,55 @@ public void authorizeWhenAuthorizedAndInvalidRequestScopeProvidedThenThrowIllega + OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); } + @Test + public void shouldPublishEventWhenTokenRefreshed() { + OAuth2TokenRefreshedAwareEventPublisher eventPublisher = new OAuth2TokenRefreshedAwareEventPublisher(); + this.authorizedClientProvider.setApplicationEventPublisher(eventPublisher); + // @formatter:off + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses + .accessTokenResponse() + .refreshToken("new-refresh-token") + .build(); + // @formatter:on + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + this.authorizedClientProvider.authorize(authorizationContext); + assertThat(eventPublisher.flag).isTrue(); + } + + @Test + public void shouldNotPublishEventWhenTokenNotRefreshed() { + OAuth2TokenRefreshedAwareEventPublisher eventPublisher = new OAuth2TokenRefreshedAwareEventPublisher(); + this.authorizedClientProvider.setApplicationEventPublisher(eventPublisher); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken()); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + this.authorizedClientProvider.authorize(authorizationContext); + assertThat(eventPublisher.flag).isFalse(); + } + + private static class OAuth2TokenRefreshedAwareEventPublisher implements ApplicationEventPublisher { + + Boolean flag = false; + + @Override + public void publishEvent(Object event) { + if (OAuth2TokenRefreshedEvent.class.isAssignableFrom(event.getClass())) { + this.flag = true; + } + } + + } + }