Skip to content

Commit 984882f

Browse files
authored
Specify clientRegistrationId in TokenRelay filter (#2922)
1 parent 6f95267 commit 984882f

File tree

4 files changed

+187
-24
lines changed

4 files changed

+187
-24
lines changed

docs/src/main/asciidoc/spring-cloud-gateway.adoc

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2032,7 +2032,46 @@ consumer can be a pure Client (like an SSO application) or a Resource
20322032
Server.
20332033

20342034
Spring Cloud Gateway can forward OAuth2 access tokens downstream to the services
2035-
it is proxying. To add this functionality to the gateway, you need to add the `TokenRelayGatewayFilterFactory` like this:
2035+
it is proxying using the `TokenRelay` `GatewayFilter`.
2036+
2037+
The `TokenRelay` `GatewayFilter` takes one optional parameter, `clientRegistrationId`.
2038+
The following example configures a `TokenRelay` `GatewayFilter`:
2039+
2040+
.App.java
2041+
[source,java]
2042+
----
2043+
2044+
@Bean
2045+
public RouteLocator customRouteLocator(RouteLocatorBuilder builder) {
2046+
return builder.routes()
2047+
.route("resource", r -> r.path("/resource")
2048+
.filters(f -> f.tokenRelay("myregistrationid"))
2049+
.uri("http://localhost:9000"))
2050+
.build();
2051+
}
2052+
----
2053+
2054+
or this
2055+
2056+
.application.yaml
2057+
[source,yaml]
2058+
----
2059+
spring:
2060+
cloud:
2061+
gateway:
2062+
routes:
2063+
- id: resource
2064+
uri: http://localhost:9000
2065+
predicates:
2066+
- Path=/resource
2067+
filters:
2068+
- TokenRelay=myregistrationid
2069+
----
2070+
2071+
The example above specifies a `clientRegistrationId`, which can be used to obtain and forward an OAuth2 access token for any available `ClientRegistration`.
2072+
2073+
Spring Cloud Gateway can also forward the OAuth2 access token of the currently authenticated user `oauth2Login()` is used to authenticate the user.
2074+
To add this functionality to the gateway, you can omit the `clientRegistrationId` parameter like this:
20362075

20372076
.App.java
20382077
[source,java]
@@ -2073,10 +2112,10 @@ To enable this for Spring Cloud Gateway add the following dependencies
20732112

20742113
- `org.springframework.boot:spring-boot-starter-oauth2-client`
20752114

2076-
How does it work? The
2077-
{githubmaster}/src/main/java/org/springframework/cloud/gateway/security/TokenRelayGatewayFilterFactory.java[filter]
2078-
extracts an access token from the currently authenticated user,
2079-
and puts it in a request header for the downstream requests.
2115+
How does it work? The {github-code}/src/main/java/org/springframework/cloud/gateway/security/TokenRelayGatewayFilterFactory.java[filter]
2116+
extracts an OAuth2 access token from the currently authenticated user for the provided `clientRegistrationId`.
2117+
If no `clientRegistrationId` is provided, the currently authenticated user's own access token (obtained during login) is used.
2118+
In either case, the extracted access token is placed in a request header for the downstream requests.
20802119

20812120
For a full working sample see https://github.com/spring-cloud-samples/sample-gateway-oauth2login[this project].
20822121

spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/factory/TokenRelayGatewayFilterFactory.java

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2018 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,10 +16,14 @@
1616

1717
package org.springframework.cloud.gateway.filter.factory;
1818

19+
import java.util.Collections;
20+
import java.util.List;
21+
1922
import reactor.core.publisher.Mono;
2023

2124
import org.springframework.beans.factory.ObjectProvider;
2225
import org.springframework.cloud.gateway.filter.GatewayFilter;
26+
import org.springframework.security.core.Authentication;
2327
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
2428
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
2529
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager;
@@ -29,37 +33,51 @@
2933

3034
/**
3135
* @author Joe Grandja
36+
* @author Steve Riesenberg
3237
*/
33-
public class TokenRelayGatewayFilterFactory extends AbstractGatewayFilterFactory<Object> {
38+
public class TokenRelayGatewayFilterFactory
39+
extends AbstractGatewayFilterFactory<AbstractGatewayFilterFactory.NameConfig> {
3440

3541
private final ObjectProvider<ReactiveOAuth2AuthorizedClientManager> clientManagerProvider;
3642

3743
public TokenRelayGatewayFilterFactory(ObjectProvider<ReactiveOAuth2AuthorizedClientManager> clientManagerProvider) {
38-
super(Object.class);
44+
super(NameConfig.class);
3945
this.clientManagerProvider = clientManagerProvider;
4046
}
4147

48+
@Override
49+
public List<String> shortcutFieldOrder() {
50+
return Collections.singletonList(NAME_KEY);
51+
}
52+
4253
public GatewayFilter apply() {
43-
return apply((Object) null);
54+
return apply((NameConfig) null);
4455
}
4556

4657
@Override
47-
public GatewayFilter apply(Object config) {
58+
public GatewayFilter apply(NameConfig config) {
59+
String defaultClientRegistrationId = (config == null) ? null : config.getName();
4860
return (exchange, chain) -> exchange.getPrincipal()
4961
// .log("token-relay-filter")
50-
.filter(principal -> principal instanceof OAuth2AuthenticationToken)
51-
.cast(OAuth2AuthenticationToken.class)
52-
.flatMap(authentication -> authorizedClient(exchange, authentication))
53-
.map(OAuth2AuthorizedClient::getAccessToken).map(token -> withBearerAuth(exchange, token))
62+
.filter(principal -> principal instanceof Authentication).cast(Authentication.class)
63+
.flatMap(principal -> authorizationRequest(defaultClientRegistrationId, principal))
64+
.flatMap(this::authorizedClient).map(OAuth2AuthorizedClient::getAccessToken)
65+
.map(token -> withBearerAuth(exchange, token))
5466
// TODO: adjustable behavior if empty
5567
.defaultIfEmpty(exchange).flatMap(chain::filter);
5668
}
5769

58-
private Mono<OAuth2AuthorizedClient> authorizedClient(ServerWebExchange exchange,
59-
OAuth2AuthenticationToken oauth2Authentication) {
60-
String clientRegistrationId = oauth2Authentication.getAuthorizedClientRegistrationId();
61-
OAuth2AuthorizeRequest request = OAuth2AuthorizeRequest.withClientRegistrationId(clientRegistrationId)
62-
.principal(oauth2Authentication).build();
70+
private Mono<OAuth2AuthorizeRequest> authorizationRequest(String defaultClientRegistrationId,
71+
Authentication principal) {
72+
String clientRegistrationId = defaultClientRegistrationId;
73+
if (clientRegistrationId == null && principal instanceof OAuth2AuthenticationToken) {
74+
clientRegistrationId = ((OAuth2AuthenticationToken) principal).getAuthorizedClientRegistrationId();
75+
}
76+
return Mono.justOrEmpty(clientRegistrationId).map(OAuth2AuthorizeRequest::withClientRegistrationId)
77+
.map(builder -> builder.principal(principal).build());
78+
}
79+
80+
private Mono<OAuth2AuthorizedClient> authorizedClient(OAuth2AuthorizeRequest request) {
6381
ReactiveOAuth2AuthorizedClientManager clientManager = clientManagerProvider.getIfAvailable();
6482
if (clientManager == null) {
6583
return Mono.error(new IllegalStateException(

spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/route/builder/GatewayFilterSpec.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -796,12 +796,13 @@ public GatewayFilterSpec setRequestHeaderSize(DataSize size) {
796796
}
797797

798798
/**
799-
* A filter that enables token relay.
799+
* A filter that enables token relay by extracting the access token from the currently
800+
* authenticated user and puts it in a request header for downstream requests.
800801
* @return a {@link GatewayFilterSpec} that can be used to apply additional filters
801802
*/
802803
public GatewayFilterSpec tokenRelay() {
803804
try {
804-
return filter(getBean(TokenRelayGatewayFilterFactory.class).apply(o -> {
805+
return filter(getBean(TokenRelayGatewayFilterFactory.class).apply(c -> {
805806
}));
806807
}
807808
catch (NoSuchBeanDefinitionException e) {
@@ -810,6 +811,23 @@ public GatewayFilterSpec tokenRelay() {
810811
}
811812
}
812813

814+
/**
815+
* A filter that enables token relay by extracting the access token of a specified
816+
* {@code ClientRegistration} and puts it in a request header for downstream requests.
817+
* @param clientRegistrationId the client registration id to use for building the
818+
* authorization request
819+
* @return a {@link GatewayFilterSpec} that can be used to apply additional filters
820+
*/
821+
public GatewayFilterSpec tokenRelay(String clientRegistrationId) {
822+
try {
823+
return filter(getBean(TokenRelayGatewayFilterFactory.class).apply(c -> c.setName(clientRegistrationId)));
824+
}
825+
catch (NoSuchBeanDefinitionException e) {
826+
throw new IllegalStateException("No TokenRelayGatewayFilterFactory bean was found. Did you include the "
827+
+ "org.springframework.boot:spring-boot-starter-oauth2-client dependency?");
828+
}
829+
}
830+
813831
/**
814832
* Adds hystrix execution exception headers to fallback request. Depends on @{code
815833
* org.springframework.cloud::spring-cloud-starter-netflix-hystrix} being on the

spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/factory/TokenRelayGatewayFilterFactoryTests.java

Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,18 @@
2222
import org.junit.jupiter.api.AfterEach;
2323
import org.junit.jupiter.api.BeforeEach;
2424
import org.junit.jupiter.api.Test;
25+
import org.mockito.ArgumentCaptor;
2526
import reactor.core.publisher.Mono;
2627

2728
import org.springframework.beans.factory.ObjectProvider;
2829
import org.springframework.cloud.gateway.filter.GatewayFilter;
2930
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
31+
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory.NameConfig;
3032
import org.springframework.http.HttpHeaders;
3133
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
3234
import org.springframework.mock.web.server.MockServerWebExchange;
3335
import org.springframework.security.authentication.TestingAuthenticationToken;
36+
import org.springframework.security.core.Authentication;
3437
import org.springframework.security.core.context.SecurityContextImpl;
3538
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
3639
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
@@ -46,6 +49,7 @@
4649
import static org.assertj.core.api.Assertions.assertThat;
4750
import static org.mockito.ArgumentMatchers.any;
4851
import static org.mockito.Mockito.mock;
52+
import static org.mockito.Mockito.verify;
4953
import static org.mockito.Mockito.when;
5054

5155
/**
@@ -64,7 +68,7 @@ public class TokenRelayGatewayFilterFactoryTests {
6468

6569
private GatewayFilterChain filterChain;
6670

67-
private GatewayFilter filter;
71+
private ObjectProvider<ReactiveOAuth2AuthorizedClientManager> objectProvider;
6872

6973
public TokenRelayGatewayFilterFactoryTests() {
7074
}
@@ -78,9 +82,8 @@ public void init() {
7882
when(filterChain.filter(any(ServerWebExchange.class))).thenReturn(Mono.empty());
7983

8084
authorizedClientManager = mock(ReactiveOAuth2AuthorizedClientManager.class);
81-
ObjectProvider<ReactiveOAuth2AuthorizedClientManager> objectProvider = mock(ObjectProvider.class);
85+
objectProvider = mock(ObjectProvider.class);
8286
when(objectProvider.getIfAvailable()).thenReturn(authorizedClientManager);
83-
filter = new TokenRelayGatewayFilterFactory(objectProvider).apply();
8487
}
8588

8689
@AfterEach
@@ -89,6 +92,7 @@ public void after() {
8992

9093
@Test
9194
public void emptyPrincipal() {
95+
GatewayFilter filter = new TokenRelayGatewayFilterFactory(objectProvider).apply();
9296
filter.filter(mockExchange, filterChain).block(TIMEOUT);
9397
assertThat(request.getHeaders()).doesNotContainKeys(HttpHeaders.AUTHORIZATION);
9498
}
@@ -112,10 +116,58 @@ public void whenPrincipalExistsAuthorizationHeaderAdded() {
112116
SecurityContextServerWebExchange exchange = new SecurityContextServerWebExchange(mockExchange,
113117
Mono.just(securityContext));
114118

119+
GatewayFilter filter = new TokenRelayGatewayFilterFactory(objectProvider).apply();
115120
filter.filter(exchange, filterChain).block(TIMEOUT);
116121

117122
assertThat(request.getHeaders()).containsEntry(HttpHeaders.AUTHORIZATION,
118123
Collections.singletonList("Bearer mytoken"));
124+
125+
ArgumentCaptor<OAuth2AuthorizeRequest> authorizeRequestCaptor = ArgumentCaptor
126+
.forClass(OAuth2AuthorizeRequest.class);
127+
verify(authorizedClientManager).authorize(authorizeRequestCaptor.capture());
128+
129+
OAuth2AuthorizeRequest authorizeRequest = authorizeRequestCaptor.getValue();
130+
assertThat(authorizeRequest.getClientRegistrationId())
131+
.isEqualTo(authenticationToken.getAuthorizedClientRegistrationId());
132+
assertThat(authorizeRequest.getClientRegistrationId()).isNotEqualTo(clientRegistration.getRegistrationId());
133+
}
134+
135+
@Test
136+
public void whenClientRegistrationIdConfiguredAuthorizationHeaderAdded() {
137+
OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
138+
when(accessToken.getTokenValue()).thenReturn("mytoken");
139+
140+
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("myregistrationid")
141+
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).clientId("myclientid")
142+
.tokenUri("mytokenuri").build();
143+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, "steve", accessToken);
144+
145+
when(authorizedClientManager.authorize(any(OAuth2AuthorizeRequest.class)))
146+
.thenReturn(Mono.just(authorizedClient));
147+
148+
OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(mock(OAuth2User.class),
149+
Collections.emptyList(), "myId");
150+
SecurityContextImpl securityContext = new SecurityContextImpl(authenticationToken);
151+
SecurityContextServerWebExchange exchange = new SecurityContextServerWebExchange(mockExchange,
152+
Mono.just(securityContext));
153+
154+
NameConfig config = new NameConfig();
155+
config.setName(clientRegistration.getRegistrationId());
156+
157+
GatewayFilter filter = new TokenRelayGatewayFilterFactory(objectProvider).apply(config);
158+
filter.filter(exchange, filterChain).block(TIMEOUT);
159+
160+
assertThat(request.getHeaders()).containsEntry(HttpHeaders.AUTHORIZATION,
161+
Collections.singletonList("Bearer mytoken"));
162+
163+
ArgumentCaptor<OAuth2AuthorizeRequest> authorizeRequestCaptor = ArgumentCaptor
164+
.forClass(OAuth2AuthorizeRequest.class);
165+
verify(authorizedClientManager).authorize(authorizeRequestCaptor.capture());
166+
167+
OAuth2AuthorizeRequest authorizeRequest = authorizeRequestCaptor.getValue();
168+
assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(clientRegistration.getRegistrationId());
169+
assertThat(authorizeRequest.getClientRegistrationId())
170+
.isNotEqualTo(authenticationToken.getAuthorizedClientRegistrationId());
119171
}
120172

121173
@Test
@@ -124,9 +176,45 @@ public void principalIsNotOAuth2AuthenticationToken() {
124176
SecurityContextServerWebExchange exchange = new SecurityContextServerWebExchange(mockExchange,
125177
Mono.just(securityContext));
126178

179+
GatewayFilter filter = new TokenRelayGatewayFilterFactory(objectProvider).apply();
127180
filter.filter(exchange, filterChain).block(TIMEOUT);
128181

129182
assertThat(request.getHeaders()).doesNotContainKeys(HttpHeaders.AUTHORIZATION);
130183
}
131184

185+
@Test
186+
public void whenPrincipalIsNotOAuth2AuthenticationTokenAndClientRegistrationIdConfiguredAuthorizationHeaderAdded() {
187+
OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
188+
when(accessToken.getTokenValue()).thenReturn("mytoken");
189+
190+
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("myregistrationid")
191+
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).clientId("myclientid")
192+
.tokenUri("mytokenuri").build();
193+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, "steve", accessToken);
194+
195+
when(authorizedClientManager.authorize(any(OAuth2AuthorizeRequest.class)))
196+
.thenReturn(Mono.just(authorizedClient));
197+
198+
Authentication authenticationToken = new TestingAuthenticationToken("my", null);
199+
SecurityContextImpl securityContext = new SecurityContextImpl(authenticationToken);
200+
SecurityContextServerWebExchange exchange = new SecurityContextServerWebExchange(mockExchange,
201+
Mono.just(securityContext));
202+
203+
NameConfig config = new NameConfig();
204+
config.setName(clientRegistration.getRegistrationId());
205+
206+
GatewayFilter filter = new TokenRelayGatewayFilterFactory(objectProvider).apply(config);
207+
filter.filter(exchange, filterChain).block(TIMEOUT);
208+
209+
assertThat(request.getHeaders()).containsEntry(HttpHeaders.AUTHORIZATION,
210+
Collections.singletonList("Bearer mytoken"));
211+
212+
ArgumentCaptor<OAuth2AuthorizeRequest> authorizeRequestCaptor = ArgumentCaptor
213+
.forClass(OAuth2AuthorizeRequest.class);
214+
verify(authorizedClientManager).authorize(authorizeRequestCaptor.capture());
215+
216+
OAuth2AuthorizeRequest authorizeRequest = authorizeRequestCaptor.getValue();
217+
assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(clientRegistration.getRegistrationId());
218+
}
219+
132220
}

0 commit comments

Comments
 (0)