Skip to content

Commit

Permalink
Ensure ID Token is updated after refresh token
Browse files Browse the repository at this point in the history
Signed-off-by: Hao <[email protected]>
  • Loading branch information
yhao3 committed Feb 13, 2025
1 parent 9c51507 commit 8f7b48e
Show file tree
Hide file tree
Showing 9 changed files with 229 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.annotation.AnnotationBeanNameGenerator;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
Expand Down Expand Up @@ -160,7 +163,7 @@ private OAuth2AuthorizedClientManager getAuthorizedClientManager() {
* @since 6.2.0
*/
static final class OAuth2AuthorizedClientManagerRegistrar
implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware {
implements ApplicationContextAware, BeanDefinitionRegistryPostProcessor, BeanFactoryAware {

static final String BEAN_NAME = "authorizedClientManagerRegistrar";

Expand All @@ -179,6 +182,8 @@ static final class OAuth2AuthorizedClientManagerRegistrar

private final AnnotationBeanNameGenerator beanNameGenerator = new AnnotationBeanNameGenerator();

private ApplicationEventPublisher eventPublisher;

private ListableBeanFactory beanFactory;

@Override
Expand Down Expand Up @@ -302,6 +307,10 @@ private OAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider(
authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
}

if (this.eventPublisher != null) {
authorizedClientProvider.setApplicationEventPublisher(this.eventPublisher);
}

return authorizedClientProvider;
}

Expand Down Expand Up @@ -423,6 +432,11 @@ private <T> T getBeanOfType(ResolvableType resolvableType) {
return objectProvider.getIfAvailable();
}

@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.eventPublisher = applicationContext;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeAuthenticationProvider;
import org.springframework.security.oauth2.client.oidc.authentication.RefreshOidcIdTokenHandler;
import org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry;
import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation;
import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry;
Expand Down Expand Up @@ -393,6 +394,10 @@ public void init(B http) throws Exception {
oidcAuthorizationCodeAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper);
}
http.authenticationProvider(this.postProcess(oidcAuthorizationCodeAuthenticationProvider));

RefreshOidcIdTokenHandler refreshOidcIdTokenHandler = new RefreshOidcIdTokenHandler(
oidcAuthorizationCodeAuthenticationProvider);
registerDelegateApplicationListener(refreshOidcIdTokenHandler);
}
else {
http.authenticationProvider(new OidcAuthenticationRequestChecker());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.annotation.AnnotationBeanNameGenerator;
import org.springframework.core.ResolvableType;
import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider;
Expand Down Expand Up @@ -197,6 +198,12 @@ private OAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider(
authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
}

ApplicationEventPublisher applicationEventPublisher = getBeanOfType(
ResolvableType.forClass(ApplicationEventPublisher.class));
if (applicationEventPublisher != null) {
authorizedClientProvider.setApplicationEventPublisher(applicationEventPublisher);
}

return authorizedClientProvider;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Map;
import java.util.function.Consumer;

import org.springframework.context.ApplicationEventPublisher;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
Expand Down Expand Up @@ -359,6 +360,8 @@ public final class RefreshTokenGrantBuilder implements Builder {

private OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient;

private ApplicationEventPublisher eventPublisher;

private Duration clockSkew;

private Clock clock;
Expand All @@ -379,6 +382,17 @@ public RefreshTokenGrantBuilder accessTokenResponseClient(
return this;
}

/**
* Sets the {@link ApplicationEventPublisher} used when an access token is
* refreshed.
* @param eventPublisher the {@link ApplicationEventPublisher}
* @return the {@link RefreshTokenGrantBuilder}
*/
public RefreshTokenGrantBuilder eventPublisher(ApplicationEventPublisher eventPublisher) {
this.eventPublisher = eventPublisher;
return this;
}

/**
* Sets the maximum acceptable clock skew, which is used when checking the access
* token expiry. An access token is considered expired if
Expand Down Expand Up @@ -414,6 +428,9 @@ public OAuth2AuthorizedClientProvider build() {
if (this.accessTokenResponseClient != null) {
authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient);
}
if (this.eventPublisher != null) {
authorizedClientProvider.setApplicationEventPublisher(this.eventPublisher);
}
if (this.clockSkew != null) {
authorizedClientProvider.setClockSkew(this.clockSkew);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@
import java.util.HashSet;
import java.util.Set;

import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Token;
Expand All @@ -43,10 +46,13 @@
* @see OAuth2AuthorizedClientProvider
* @see DefaultRefreshTokenTokenResponseClient
*/
public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider {
public final class RefreshTokenOAuth2AuthorizedClientProvider
implements OAuth2AuthorizedClientProvider, ApplicationEventPublisherAware {

private OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient = new DefaultRefreshTokenTokenResponseClient();

private ApplicationEventPublisher eventPublisher;

private Duration clockSkew = Duration.ofSeconds(60);

private Clock clock = Clock.systemUTC();
Expand Down Expand Up @@ -91,8 +97,17 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) {
authorizedClient.getClientRegistration(), authorizedClient.getAccessToken(),
authorizedClient.getRefreshToken(), scopes);
OAuth2AccessTokenResponse tokenResponse = getTokenResponse(authorizedClient, refreshTokenGrantRequest);
return new OAuth2AuthorizedClient(context.getAuthorizedClient().getClientRegistration(),
context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken());

OAuth2AuthorizedClient updatedOAuth2AuthorizedClient = new OAuth2AuthorizedClient(
authorizedClient.getClientRegistration(), context.getPrincipal().getName(),
tokenResponse.getAccessToken(), tokenResponse.getRefreshToken());

if (this.eventPublisher != null) {
this.eventPublisher
.publishEvent(new OAuth2TokenRefreshedEvent(this, updatedOAuth2AuthorizedClient, tokenResponse));
}

return updatedOAuth2AuthorizedClient;
}

private OAuth2AccessTokenResponse getTokenResponse(OAuth2AuthorizedClient authorizedClient,
Expand Down Expand Up @@ -149,4 +164,9 @@ public void setClock(Clock clock) {
this.clock = clock;
}

@Override
public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
this.eventPublisher = applicationEventPublisher;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client.event;

import org.springframework.context.ApplicationEvent;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;

/**
* An event that is published when an OAuth2 access token is refreshed.
*/
public class OAuth2TokenRefreshedEvent extends ApplicationEvent {

private final OAuth2AuthorizedClient authorizedClient;

private final OAuth2AccessTokenResponse accessTokenResponse;

public OAuth2TokenRefreshedEvent(Object source, OAuth2AuthorizedClient authorizedClient,
OAuth2AccessTokenResponse accessTokenResponse) {
super(source);
this.authorizedClient = authorizedClient;
this.accessTokenResponse = accessTokenResponse;
}

public OAuth2AuthorizedClient getAuthorizedClient() {
return this.authorizedClient;
}

public OAuth2AccessTokenResponse getAccessTokenResponse() {
return this.accessTokenResponse;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ public boolean supports(Class<?> authentication) {
return OAuth2LoginAuthenticationToken.class.isAssignableFrom(authentication);
}

private OidcIdToken createOidcToken(ClientRegistration clientRegistration,
protected OidcIdToken createOidcToken(ClientRegistration clientRegistration,
OAuth2AccessTokenResponse accessTokenResponse) {
JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);
Jwt jwt = getJwt(accessTokenResponse, jwtDecoder);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client.oidc.authentication;

import org.springframework.context.ApplicationListener;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;

/**
* An {@link ApplicationListener} that listens for {@link OAuth2TokenRefreshedEvent}s
*/
public class RefreshOidcIdTokenHandler implements ApplicationListener<OAuth2TokenRefreshedEvent> {

private final OidcAuthorizationCodeAuthenticationProvider oidcAuthorizationCodeAuthenticationProvider;

public RefreshOidcIdTokenHandler(
OidcAuthorizationCodeAuthenticationProvider oidcAuthorizationCodeAuthenticationProvider) {
this.oidcAuthorizationCodeAuthenticationProvider = oidcAuthorizationCodeAuthenticationProvider;
}

@Override
public void onApplicationEvent(OAuth2TokenRefreshedEvent event) {
OAuth2AuthorizedClient authorizedClient = event.getAuthorizedClient();
OAuth2AccessTokenResponse accessTokenResponse = event.getAccessTokenResponse();
OidcIdToken refreshedOidcToken = this.oidcAuthorizationCodeAuthenticationProvider
.createOidcToken(authorizedClient.getClientRegistration(), accessTokenResponse);
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication instanceof OAuth2AuthenticationToken oauth2AuthenticationToken) {
if (authentication.getPrincipal() instanceof DefaultOidcUser defaultOidcUser) {
OidcUser oidcUser = new DefaultOidcUser(defaultOidcUser.getAuthorities(), refreshedOidcToken,
defaultOidcUser.getUserInfo(), StandardClaimNames.SUB);
SecurityContextHolder.getContext()
.setAuthentication(new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(),
oauth2AuthenticationToken.getAuthorizedClientRegistrationId()));
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;

import org.springframework.context.ApplicationEventPublisher;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
Expand Down Expand Up @@ -251,4 +253,55 @@ public void authorizeWhenAuthorizedAndInvalidRequestScopeProvidedThenThrowIllega
+ OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'");
}

@Test
public void shouldPublishEventWhenTokenRefreshed() {
OAuth2TokenRefreshedAwareEventPublisher eventPublisher = new OAuth2TokenRefreshedAwareEventPublisher();
this.authorizedClientProvider.setApplicationEventPublisher(eventPublisher);
// @formatter:off
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
.accessTokenResponse()
.refreshToken("new-refresh-token")
.build();
// @formatter:on
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
this.authorizedClientProvider.authorize(authorizationContext);
assertThat(eventPublisher.flag).isTrue();
}

@Test
public void shouldNotPublishEventWhenTokenNotRefreshed() {
OAuth2TokenRefreshedAwareEventPublisher eventPublisher = new OAuth2TokenRefreshedAwareEventPublisher();
this.authorizedClientProvider.setApplicationEventPublisher(eventPublisher);

OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken());
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
this.authorizedClientProvider.authorize(authorizationContext);
assertThat(eventPublisher.flag).isFalse();
}

private static class OAuth2TokenRefreshedAwareEventPublisher implements ApplicationEventPublisher {

Boolean flag = false;

@Override
public void publishEvent(Object event) {
if (OAuth2TokenRefreshedEvent.class.isAssignableFrom(event.getClass())) {
this.flag = true;
}
}

}

}

0 comments on commit 8f7b48e

Please sign in to comment.