Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure ID Token is updated after refresh token #16589

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.annotation.AnnotationBeanNameGenerator;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
Expand Down Expand Up @@ -160,7 +163,7 @@ private OAuth2AuthorizedClientManager getAuthorizedClientManager() {
* @since 6.2.0
*/
static final class OAuth2AuthorizedClientManagerRegistrar
implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware {
implements ApplicationContextAware, BeanDefinitionRegistryPostProcessor, BeanFactoryAware {

static final String BEAN_NAME = "authorizedClientManagerRegistrar";

Expand All @@ -179,6 +182,8 @@ static final class OAuth2AuthorizedClientManagerRegistrar

private final AnnotationBeanNameGenerator beanNameGenerator = new AnnotationBeanNameGenerator();

private ApplicationEventPublisher eventPublisher;

private ListableBeanFactory beanFactory;

@Override
Expand Down Expand Up @@ -302,6 +307,10 @@ private OAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider(
authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
}

if (this.eventPublisher != null) {
authorizedClientProvider.setApplicationEventPublisher(this.eventPublisher);
}

return authorizedClientProvider;
}

Expand Down Expand Up @@ -423,6 +432,11 @@ private <T> T getBeanOfType(ResolvableType resolvableType) {
return objectProvider.getIfAvailable();
}

@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.eventPublisher = applicationContext;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeAuthenticationProvider;
import org.springframework.security.oauth2.client.oidc.authentication.RefreshOidcIdTokenHandler;
import org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry;
import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation;
import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry;
Expand Down Expand Up @@ -393,6 +394,15 @@ public void init(B http) throws Exception {
oidcAuthorizationCodeAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper);
}
http.authenticationProvider(this.postProcess(oidcAuthorizationCodeAuthenticationProvider));

RefreshOidcIdTokenHandler refreshOidcIdTokenHandler = new RefreshOidcIdTokenHandler();
if (this.getSecurityContextHolderStrategy() != null) {
refreshOidcIdTokenHandler.setSecurityContextHolderStrategy(this.getSecurityContextHolderStrategy());
}
if (jwtDecoderFactory != null) {
refreshOidcIdTokenHandler.setJwtDecoderFactory(jwtDecoderFactory);
}
registerDelegateApplicationListener(refreshOidcIdTokenHandler);
}
else {
http.authenticationProvider(new OidcAuthenticationRequestChecker());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.annotation.AnnotationBeanNameGenerator;
import org.springframework.core.ResolvableType;
import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider;
Expand Down Expand Up @@ -197,6 +198,12 @@ private OAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider(
authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
}

ApplicationEventPublisher applicationEventPublisher = getBeanOfType(
ResolvableType.forClass(ApplicationEventPublisher.class));
if (applicationEventPublisher != null) {
authorizedClientProvider.setApplicationEventPublisher(applicationEventPublisher);
}

return authorizedClientProvider;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Map;
import java.util.function.Consumer;

import org.springframework.context.ApplicationEventPublisher;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
Expand Down Expand Up @@ -359,6 +360,8 @@ public final class RefreshTokenGrantBuilder implements Builder {

private OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient;

private ApplicationEventPublisher eventPublisher;

private Duration clockSkew;

private Clock clock;
Expand All @@ -379,6 +382,17 @@ public RefreshTokenGrantBuilder accessTokenResponseClient(
return this;
}

/**
* Sets the {@link ApplicationEventPublisher} used when an access token is
* refreshed.
* @param eventPublisher the {@link ApplicationEventPublisher}
* @return the {@link RefreshTokenGrantBuilder}
*/
public RefreshTokenGrantBuilder eventPublisher(ApplicationEventPublisher eventPublisher) {
this.eventPublisher = eventPublisher;
return this;
}

/**
* Sets the maximum acceptable clock skew, which is used when checking the access
* token expiry. An access token is considered expired if
Expand Down Expand Up @@ -414,6 +428,9 @@ public OAuth2AuthorizedClientProvider build() {
if (this.accessTokenResponseClient != null) {
authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient);
}
if (this.eventPublisher != null) {
authorizedClientProvider.setApplicationEventPublisher(this.eventPublisher);
}
if (this.clockSkew != null) {
authorizedClientProvider.setClockSkew(this.clockSkew);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@
import java.util.HashSet;
import java.util.Set;

import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Token;
Expand All @@ -43,10 +46,13 @@
* @see OAuth2AuthorizedClientProvider
* @see DefaultRefreshTokenTokenResponseClient
*/
public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider {
public final class RefreshTokenOAuth2AuthorizedClientProvider
implements OAuth2AuthorizedClientProvider, ApplicationEventPublisherAware {

private OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient = new DefaultRefreshTokenTokenResponseClient();

private ApplicationEventPublisher eventPublisher;

private Duration clockSkew = Duration.ofSeconds(60);

private Clock clock = Clock.systemUTC();
Expand Down Expand Up @@ -91,8 +97,17 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) {
authorizedClient.getClientRegistration(), authorizedClient.getAccessToken(),
authorizedClient.getRefreshToken(), scopes);
OAuth2AccessTokenResponse tokenResponse = getTokenResponse(authorizedClient, refreshTokenGrantRequest);
return new OAuth2AuthorizedClient(context.getAuthorizedClient().getClientRegistration(),
context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken());

OAuth2AuthorizedClient updatedOAuth2AuthorizedClient = new OAuth2AuthorizedClient(
authorizedClient.getClientRegistration(), context.getPrincipal().getName(),
tokenResponse.getAccessToken(), tokenResponse.getRefreshToken());

if (this.eventPublisher != null) {
this.eventPublisher
.publishEvent(new OAuth2TokenRefreshedEvent(this, updatedOAuth2AuthorizedClient, tokenResponse));
}

return updatedOAuth2AuthorizedClient;
}

private OAuth2AccessTokenResponse getTokenResponse(OAuth2AuthorizedClient authorizedClient,
Expand Down Expand Up @@ -149,4 +164,9 @@ public void setClock(Clock clock) {
this.clock = clock;
}

@Override
public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
this.eventPublisher = applicationEventPublisher;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client.event;

import org.springframework.context.ApplicationEvent;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;

/**
* An event that is published when an OAuth2 access token is refreshed.
*/
public class OAuth2TokenRefreshedEvent extends ApplicationEvent {

private final OAuth2AuthorizedClient authorizedClient;

private final OAuth2AccessTokenResponse accessTokenResponse;

public OAuth2TokenRefreshedEvent(Object source, OAuth2AuthorizedClient authorizedClient,
OAuth2AccessTokenResponse accessTokenResponse) {
super(source);
this.authorizedClient = authorizedClient;
this.accessTokenResponse = accessTokenResponse;
}

public OAuth2AuthorizedClient getAuthorizedClient() {
return this.authorizedClient;
}

public OAuth2AccessTokenResponse getAccessTokenResponse() {
return this.accessTokenResponse;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client.oidc.authentication;

import java.util.Map;

import org.springframework.context.ApplicationListener;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.util.Assert;

/**
* An {@link ApplicationListener} that listens for {@link OAuth2TokenRefreshedEvent}s
*/
public class RefreshOidcIdTokenHandler implements ApplicationListener<OAuth2TokenRefreshedEvent> {

private static final String MISSING_ID_TOKEN_ERROR_CODE = "missing_id_token";

private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token";

private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();

private JwtDecoderFactory<ClientRegistration> jwtDecoderFactory = new OidcIdTokenDecoderFactory();

@Override
public void onApplicationEvent(OAuth2TokenRefreshedEvent event) {
OAuth2AuthorizedClient authorizedClient = event.getAuthorizedClient();

if (!authorizedClient.getClientRegistration().getScopes().contains(OidcScopes.OPENID)) {
return;
}

Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
if (!(authentication instanceof OAuth2AuthenticationToken oauth2Authentication)) {
return;
}
if (!(authentication.getPrincipal() instanceof DefaultOidcUser defaultOidcUser)) {
return;
}

OAuth2AccessTokenResponse accessTokenResponse = event.getAccessTokenResponse();

String idToken = (String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN);
if (idToken == null || idToken.isBlank()) {
OAuth2Error missingIdTokenError = new OAuth2Error(MISSING_ID_TOKEN_ERROR_CODE,
"ID token is missing in the token response", null);
throw new OAuth2AuthenticationException(missingIdTokenError, missingIdTokenError.toString());
}

ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
OidcIdToken refreshedOidcToken = createOidcToken(clientRegistration, accessTokenResponse);
updateSecurityContext(oauth2Authentication, defaultOidcUser, refreshedOidcToken);
}

/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
this.securityContextHolderStrategy = securityContextHolderStrategy;
}

/**
* Sets the {@link JwtDecoderFactory} used for {@link OidcIdToken} signature
* verification. The factory returns a {@link JwtDecoder} associated to the provided
* {@link ClientRegistration}.
* @param jwtDecoderFactory the {@link JwtDecoderFactory} used for {@link OidcIdToken}
* signature verification
*/
public final void setJwtDecoderFactory(JwtDecoderFactory<ClientRegistration> jwtDecoderFactory) {
Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null");
this.jwtDecoderFactory = jwtDecoderFactory;
}

private void updateSecurityContext(OAuth2AuthenticationToken oauth2Authentication, DefaultOidcUser defaultOidcUser,
OidcIdToken refreshedOidcToken) {
OidcUser oidcUser = new DefaultOidcUser(defaultOidcUser.getAuthorities(), refreshedOidcToken,
defaultOidcUser.getUserInfo(), StandardClaimNames.SUB);

SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
context.setAuthentication(new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(),
oauth2Authentication.getAuthorizedClientRegistrationId()));

this.securityContextHolderStrategy.setContext(context);
}

private OidcIdToken createOidcToken(ClientRegistration clientRegistration,
OAuth2AccessTokenResponse accessTokenResponse) {
JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);
Jwt jwt = getJwt(accessTokenResponse, jwtDecoder);
return new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims());
}

private Jwt getJwt(OAuth2AccessTokenResponse accessTokenResponse, JwtDecoder jwtDecoder) {
try {
Map<String, Object> parameters = accessTokenResponse.getAdditionalParameters();
return jwtDecoder.decode((String) parameters.get(OidcParameterNames.ID_TOKEN));
}
catch (JwtException ex) {
OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(), null);
throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex);
}
}

}
Loading