Skip to content

Commit 6793334

Browse files
committed
Polish setJwkSelector
Make so that it runs only when selection is needed. Require the provided selector be non-null. Add Tests. Issue gh-16170
1 parent e22bc11 commit 6793334

File tree

3 files changed

+83
-21
lines changed

3 files changed

+83
-21
lines changed

Diff for: oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java

+22-17
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,12 @@ public final class NimbusJwtEncoder implements JwtEncoder {
8787

8888
private final JWKSource<SecurityContext> jwkSource;
8989

90-
private Converter<List<JWK>, JWK> jwkSelector= (jwks)->{
91-
if (jwks.size() > 1) {
92-
throw new JwtEncodingException(String.format(
93-
"Failed to select a key since there are multiple for the signing algorithm [%s]; " +
94-
"please specify a selector in NimbusJwsEncoder#setJwkSelector",jwks.get(0).getAlgorithm()));
95-
}
96-
if (jwks.isEmpty()) {
97-
throw new JwtEncodingException(
98-
String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key"));
99-
}
100-
return jwks.get(0);
90+
private Converter<List<JWK>, JWK> jwkSelector = (jwks) -> {
91+
throw new JwtEncodingException(
92+
String.format(
93+
"Failed to select a key since there are multiple for the signing algorithm [%s]; "
94+
+ "please specify a selector in NimbusJwsEncoder#setJwkSelector",
95+
jwks.get(0).getAlgorithm()));
10196
};
10297

10398
/**
@@ -108,17 +103,20 @@ public NimbusJwtEncoder(JWKSource<SecurityContext> jwkSource) {
108103
Assert.notNull(jwkSource, "jwkSource cannot be null");
109104
this.jwkSource = jwkSource;
110105
}
106+
111107
/**
112-
* Use this strategy to reduce the list of matching JWKs down to a since one.
113-
* <p> For example, you can call {@code setJwkSelector(List::getFirst)} in order
114-
* to have this encoder select the first match.
108+
* Use this strategy to reduce the list of matching JWKs when there is more than one.
109+
* <p>
110+
* For example, you can call {@code setJwkSelector(List::getFirst)} in order to have
111+
* this encoder select the first match.
115112
*
116-
* <p> By default, the class with throw an exception if there is more than one result.
113+
* <p>
114+
* By default, the class with throw an exception.
117115
* @since 6.5
118116
*/
119117
public void setJwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
120-
if(null!=jwkSelector)
121-
this.jwkSelector = jwkSelector;
118+
Assert.notNull(jwkSelector, "jwkSelector cannot be null");
119+
this.jwkSelector = jwkSelector;
122120
}
123121

124122
@Override
@@ -149,6 +147,13 @@ private JWK selectJwk(JwsHeader headers) {
149147
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
150148
"Failed to select a JWK signing key -> " + ex.getMessage()), ex);
151149
}
150+
if (jwks.isEmpty()) {
151+
throw new JwtEncodingException(
152+
String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key"));
153+
}
154+
if (jwks.size() == 1) {
155+
return jwks.get(0);
156+
}
152157
return this.jwkSelector.convert(jwks);
153158
}
154159

Diff for: oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -59,6 +59,10 @@ public final class TestJwks {
5959
private TestJwks() {
6060
}
6161

62+
public static RSAKey.Builder rsa() {
63+
return jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY);
64+
}
65+
6266
public static RSAKey.Builder jwk(RSAPublicKey publicKey, RSAPrivateKey privateKey) {
6367
// @formatter:off
6468
return new RSAKey.Builder(publicKey)

Diff for: oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java

+56-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@
2323
import java.util.Collections;
2424
import java.util.List;
2525

26+
import com.nimbusds.jose.JWSAlgorithm;
2627
import com.nimbusds.jose.KeySourceException;
2728
import com.nimbusds.jose.jwk.ECKey;
2829
import com.nimbusds.jose.jwk.JWK;
@@ -39,6 +40,7 @@
3940
import org.mockito.invocation.InvocationOnMock;
4041
import org.mockito.stubbing.Answer;
4142

43+
import org.springframework.core.convert.converter.Converter;
4244
import org.springframework.security.oauth2.jose.TestJwks;
4345
import org.springframework.security.oauth2.jose.TestKeys;
4446
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
@@ -51,6 +53,8 @@
5153
import static org.mockito.BDDMockito.willAnswer;
5254
import static org.mockito.Mockito.mock;
5355
import static org.mockito.Mockito.spy;
56+
import static org.mockito.Mockito.verify;
57+
import static org.mockito.Mockito.verifyNoInteractions;
5458

5559
/**
5660
* Tests for {@link NimbusJwtEncoder}.
@@ -109,7 +113,7 @@ public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exce
109113

110114
@Test
111115
public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws Exception {
112-
RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK;
116+
RSAKey rsaJwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build();
113117
this.jwkList.add(rsaJwk);
114118
this.jwkList.add(rsaJwk);
115119

@@ -118,7 +122,7 @@ public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws
118122

119123
assertThatExceptionOfType(JwtEncodingException.class)
120124
.isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet)))
121-
.withMessageContaining("Found multiple JWK signing keys for algorithm 'RS256'");
125+
.withMessageContaining("Failed to select a key since there are multiple for the signing algorithm [RS256]");
122126
}
123127

124128
@Test
@@ -291,6 +295,55 @@ public List<JWK> get(JWKSelector jwkSelector, SecurityContext context) {
291295
assertThat(jwk1.getKeyID()).isNotEqualTo(jwk2.getKeyID());
292296
}
293297

298+
@Test
299+
public void encodeWhenMultipleKeysThenJwkSelectorUsed() throws Exception {
300+
JWK jwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build();
301+
JWKSource<SecurityContext> jwkSource = mock(JWKSource.class);
302+
given(jwkSource.get(any(), any())).willReturn(List.of(jwk, jwk));
303+
Converter<List<JWK>, JWK> selector = mock(Converter.class);
304+
given(selector.convert(any())).willReturn(TestJwks.DEFAULT_RSA_JWK);
305+
306+
NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource);
307+
jwtEncoder.setJwkSelector(selector);
308+
309+
JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build();
310+
jwtEncoder.encode(JwtEncoderParameters.from(claims));
311+
312+
verify(selector).convert(any());
313+
}
314+
315+
@Test
316+
public void encodeWhenSingleKeyThenJwkSelectorIsNotUsed() throws Exception {
317+
JWK jwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build();
318+
JWKSource<SecurityContext> jwkSource = mock(JWKSource.class);
319+
given(jwkSource.get(any(), any())).willReturn(List.of(jwk));
320+
Converter<List<JWK>, JWK> selector = mock(Converter.class);
321+
322+
NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource);
323+
jwtEncoder.setJwkSelector(selector);
324+
325+
JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build();
326+
jwtEncoder.encode(JwtEncoderParameters.from(claims));
327+
328+
verifyNoInteractions(selector);
329+
}
330+
331+
@Test
332+
public void encodeWhenNoKeysThenJwkSelectorIsNotUsed() throws Exception {
333+
JWKSource<SecurityContext> jwkSource = mock(JWKSource.class);
334+
given(jwkSource.get(any(), any())).willReturn(List.of());
335+
Converter<List<JWK>, JWK> selector = mock(Converter.class);
336+
337+
NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource);
338+
jwtEncoder.setJwkSelector(selector);
339+
340+
JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build();
341+
assertThatExceptionOfType(JwtEncodingException.class)
342+
.isThrownBy(() -> jwtEncoder.encode(JwtEncoderParameters.from(claims)));
343+
344+
verifyNoInteractions(selector);
345+
}
346+
294347
private static final class JwkListResultCaptor implements Answer<List<JWK>> {
295348

296349
private List<JWK> result;

0 commit comments

Comments
 (0)