Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure ID Token is updated after refresh token #16589

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.annotation.AnnotationBeanNameGenerator;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
Expand Down Expand Up @@ -160,7 +163,7 @@ private OAuth2AuthorizedClientManager getAuthorizedClientManager() {
* @since 6.2.0
*/
static final class OAuth2AuthorizedClientManagerRegistrar
implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware {
implements ApplicationContextAware, BeanDefinitionRegistryPostProcessor, BeanFactoryAware {

static final String BEAN_NAME = "authorizedClientManagerRegistrar";

Expand All @@ -179,6 +182,8 @@ static final class OAuth2AuthorizedClientManagerRegistrar

private final AnnotationBeanNameGenerator beanNameGenerator = new AnnotationBeanNameGenerator();

private ApplicationEventPublisher eventPublisher;

private ListableBeanFactory beanFactory;

@Override
Expand Down Expand Up @@ -302,6 +307,10 @@ private OAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider(
authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
}

if (this.eventPublisher != null) {
authorizedClientProvider.setApplicationEventPublisher(this.eventPublisher);
}

return authorizedClientProvider;
}

Expand Down Expand Up @@ -423,6 +432,11 @@ private <T> T getBeanOfType(ResolvableType resolvableType) {
return objectProvider.getIfAvailable();
}

@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.eventPublisher = applicationContext;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeAuthenticationProvider;
import org.springframework.security.oauth2.client.oidc.authentication.RefreshOidcIdTokenHandler;
import org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry;
import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation;
import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry;
Expand Down Expand Up @@ -393,6 +394,10 @@ public void init(B http) throws Exception {
oidcAuthorizationCodeAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper);
}
http.authenticationProvider(this.postProcess(oidcAuthorizationCodeAuthenticationProvider));

RefreshOidcIdTokenHandler refreshOidcIdTokenHandler = new RefreshOidcIdTokenHandler(
oidcAuthorizationCodeAuthenticationProvider);
registerDelegateApplicationListener(refreshOidcIdTokenHandler);
}
else {
http.authenticationProvider(new OidcAuthenticationRequestChecker());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.annotation.AnnotationBeanNameGenerator;
import org.springframework.core.ResolvableType;
import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider;
Expand Down Expand Up @@ -197,6 +198,12 @@ private OAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider(
authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
}

ApplicationEventPublisher applicationEventPublisher = getBeanOfType(
ResolvableType.forClass(ApplicationEventPublisher.class));
if (applicationEventPublisher != null) {
authorizedClientProvider.setApplicationEventPublisher(applicationEventPublisher);
}

return authorizedClientProvider;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Map;
import java.util.function.Consumer;

import org.springframework.context.ApplicationEventPublisher;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
Expand Down Expand Up @@ -359,6 +360,8 @@ public final class RefreshTokenGrantBuilder implements Builder {

private OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient;

private ApplicationEventPublisher eventPublisher;

private Duration clockSkew;

private Clock clock;
Expand All @@ -379,6 +382,17 @@ public RefreshTokenGrantBuilder accessTokenResponseClient(
return this;
}

/**
* Sets the {@link ApplicationEventPublisher} used when an access token is
* refreshed.
* @param eventPublisher the {@link ApplicationEventPublisher}
* @return the {@link RefreshTokenGrantBuilder}
*/
public RefreshTokenGrantBuilder eventPublisher(ApplicationEventPublisher eventPublisher) {
this.eventPublisher = eventPublisher;
return this;
}

/**
* Sets the maximum acceptable clock skew, which is used when checking the access
* token expiry. An access token is considered expired if
Expand Down Expand Up @@ -414,6 +428,9 @@ public OAuth2AuthorizedClientProvider build() {
if (this.accessTokenResponseClient != null) {
authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient);
}
if (this.eventPublisher != null) {
authorizedClientProvider.setApplicationEventPublisher(this.eventPublisher);
}
if (this.clockSkew != null) {
authorizedClientProvider.setClockSkew(this.clockSkew);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@
import java.util.HashSet;
import java.util.Set;

import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Token;
Expand All @@ -43,10 +46,13 @@
* @see OAuth2AuthorizedClientProvider
* @see DefaultRefreshTokenTokenResponseClient
*/
public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider {
public final class RefreshTokenOAuth2AuthorizedClientProvider
implements OAuth2AuthorizedClientProvider, ApplicationEventPublisherAware {

private OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient = new DefaultRefreshTokenTokenResponseClient();

private ApplicationEventPublisher eventPublisher;

private Duration clockSkew = Duration.ofSeconds(60);

private Clock clock = Clock.systemUTC();
Expand Down Expand Up @@ -91,8 +97,17 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) {
authorizedClient.getClientRegistration(), authorizedClient.getAccessToken(),
authorizedClient.getRefreshToken(), scopes);
OAuth2AccessTokenResponse tokenResponse = getTokenResponse(authorizedClient, refreshTokenGrantRequest);
return new OAuth2AuthorizedClient(context.getAuthorizedClient().getClientRegistration(),
context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken());

OAuth2AuthorizedClient updatedOAuth2AuthorizedClient = new OAuth2AuthorizedClient(
authorizedClient.getClientRegistration(), context.getPrincipal().getName(),
tokenResponse.getAccessToken(), tokenResponse.getRefreshToken());

if (this.eventPublisher != null) {
this.eventPublisher
.publishEvent(new OAuth2TokenRefreshedEvent(this, updatedOAuth2AuthorizedClient, tokenResponse));
}

return updatedOAuth2AuthorizedClient;
}

private OAuth2AccessTokenResponse getTokenResponse(OAuth2AuthorizedClient authorizedClient,
Expand Down Expand Up @@ -149,4 +164,9 @@ public void setClock(Clock clock) {
this.clock = clock;
}

@Override
public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
this.eventPublisher = applicationEventPublisher;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client.event;

import org.springframework.context.ApplicationEvent;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;

/**
* An event that is published when an OAuth2 access token is refreshed.
*/
public class OAuth2TokenRefreshedEvent extends ApplicationEvent {

private final OAuth2AuthorizedClient authorizedClient;

private final OAuth2AccessTokenResponse accessTokenResponse;

public OAuth2TokenRefreshedEvent(Object source, OAuth2AuthorizedClient authorizedClient,
OAuth2AccessTokenResponse accessTokenResponse) {
super(source);
this.authorizedClient = authorizedClient;
this.accessTokenResponse = accessTokenResponse;
}

public OAuth2AuthorizedClient getAuthorizedClient() {
return this.authorizedClient;
}

public OAuth2AccessTokenResponse getAccessTokenResponse() {
return this.accessTokenResponse;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ public boolean supports(Class<?> authentication) {
return OAuth2LoginAuthenticationToken.class.isAssignableFrom(authentication);
}

private OidcIdToken createOidcToken(ClientRegistration clientRegistration,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will want to do something else here instead of reuse this method directly. We don't want to couple the event handler with the existing AuthenticationProvider. For now, I think it's perfectly fine to duplicate this code in the listener and see how it looks.

protected OidcIdToken createOidcToken(ClientRegistration clientRegistration,
OAuth2AccessTokenResponse accessTokenResponse) {
JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);
Jwt jwt = getJwt(accessTokenResponse, jwtDecoder);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client.oidc.authentication;

import org.springframework.context.ApplicationListener;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;

/**
* An {@link ApplicationListener} that listens for {@link OAuth2TokenRefreshedEvent}s
*/
public class RefreshOidcIdTokenHandler implements ApplicationListener<OAuth2TokenRefreshedEvent> {

private final OidcAuthorizationCodeAuthenticationProvider oidcAuthorizationCodeAuthenticationProvider;

public RefreshOidcIdTokenHandler(
OidcAuthorizationCodeAuthenticationProvider oidcAuthorizationCodeAuthenticationProvider) {
this.oidcAuthorizationCodeAuthenticationProvider = oidcAuthorizationCodeAuthenticationProvider;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above comment. We don't want to couple the listener with the AuthenticationProvider. Can you please adjust this to directly include the necessary code from the createOidcToken() method?


@Override
public void onApplicationEvent(OAuth2TokenRefreshedEvent event) {
OAuth2AuthorizedClient authorizedClient = event.getAuthorizedClient();
OAuth2AccessTokenResponse accessTokenResponse = event.getAccessTokenResponse();
OidcIdToken refreshedOidcToken = this.oidcAuthorizationCodeAuthenticationProvider
.createOidcToken(authorizedClient.getClientRegistration(), accessTokenResponse);
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please adjust this to use a private SecurityContextHolderStrategy securityContextHolderStrategy field instead? It can be initialized with SecurityContextHolder.getContextHolderStrategy() and also include a setter for a custom strategy.

if (authentication instanceof OAuth2AuthenticationToken oauth2AuthenticationToken) {
if (authentication.getPrincipal() instanceof DefaultOidcUser defaultOidcUser) {
OidcUser oidcUser = new DefaultOidcUser(defaultOidcUser.getAuthorities(), refreshedOidcToken,
defaultOidcUser.getUserInfo(), StandardClaimNames.SUB);
SecurityContextHolder.getContext()
.setAuthentication(new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(),
oauth2AuthenticationToken.getAuthorizedClientRegistrationId()));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please adjust this to create a new empty SecurityContext via this.securityContextHolderStrategy.createEmptyContext() and then set the new context via this.securityContextHolderStrategy.setContext()?

}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;

import org.springframework.context.ApplicationEventPublisher;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
Expand Down Expand Up @@ -251,4 +253,55 @@ public void authorizeWhenAuthorizedAndInvalidRequestScopeProvidedThenThrowIllega
+ OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'");
}

@Test
public void shouldPublishEventWhenTokenRefreshed() {
OAuth2TokenRefreshedAwareEventPublisher eventPublisher = new OAuth2TokenRefreshedAwareEventPublisher();
this.authorizedClientProvider.setApplicationEventPublisher(eventPublisher);
// @formatter:off
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
.accessTokenResponse()
.refreshToken("new-refresh-token")
.build();
// @formatter:on
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
this.authorizedClientProvider.authorize(authorizationContext);
assertThat(eventPublisher.flag).isTrue();
}

@Test
public void shouldNotPublishEventWhenTokenNotRefreshed() {
OAuth2TokenRefreshedAwareEventPublisher eventPublisher = new OAuth2TokenRefreshedAwareEventPublisher();
this.authorizedClientProvider.setApplicationEventPublisher(eventPublisher);

OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken());
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
this.authorizedClientProvider.authorize(authorizationContext);
assertThat(eventPublisher.flag).isFalse();
}

private static class OAuth2TokenRefreshedAwareEventPublisher implements ApplicationEventPublisher {

Boolean flag = false;

@Override
public void publishEvent(Object event) {
if (OAuth2TokenRefreshedEvent.class.isAssignableFrom(event.getClass())) {
this.flag = true;
}
}

}

}