Skip to content

Commit

Permalink
Proposed fix for missing WWW-Authenticate header
Browse files Browse the repository at this point in the history
Current implementation does not include the WWW-Authenticate
header when returning a 401 for missing/invalid credentials when
attempting to access the token endpoints.

Fixes-468

Signed-off-by: Lucian Holland <[email protected]>
  • Loading branch information
symposion committed Feb 3, 2025
1 parent b76300b commit 987a8c8
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ public final class OAuth2ClientAuthenticationConfigurer extends AbstractOAuth2Co

private AuthenticationFailureHandler errorResponseHandler;

private String realmName = "oauth";

/**
* Restrict for internal use only.
* @param objectPostProcessor an {@code ObjectPostProcessor}
Expand All @@ -102,6 +104,18 @@ public OAuth2ClientAuthenticationConfigurer authenticationConverter(
return this;
}

/**
* Sets the realm name for Http Basic when returning a WWW-Authenticate header on
* client authentication failure.
* @param realmName the Http Basic realm name
* @return the {@link OAuth2ClientAuthenticationConfigurer} for further configuration
*/
public OAuth2ClientAuthenticationConfigurer realmName(String realmName) {
Assert.hasText(realmName, "realmName cannot be empty");
this.realmName = realmName;
return this;
}

/**
* Sets the {@code Consumer} providing access to the {@code List} of default and
* (optionally) added {@link #authenticationConverter(AuthenticationConverter)
Expand Down Expand Up @@ -213,7 +227,7 @@ void init(HttpSecurity httpSecurity) {
void configure(HttpSecurity httpSecurity) {
AuthenticationManager authenticationManager = httpSecurity.getSharedObject(AuthenticationManager.class);
OAuth2ClientAuthenticationFilter clientAuthenticationFilter = new OAuth2ClientAuthenticationFilter(
authenticationManager, this.requestMatcher);
authenticationManager, this.requestMatcher, this.realmName);
List<AuthenticationConverter> authenticationConverters = createDefaultAuthenticationConverters();
if (!this.authenticationConverters.isEmpty()) {
authenticationConverters.addAll(0, this.authenticationConverters);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter

private final AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();

private final String realmName;

private AuthenticationConverter authenticationConverter;

private AuthenticationSuccessHandler authenticationSuccessHandler = this::onAuthenticationSuccess;
Expand All @@ -104,10 +106,12 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
* @param requestMatcher the {@link RequestMatcher} used for matching against the
* {@code HttpServletRequest}
*/
public OAuth2ClientAuthenticationFilter(AuthenticationManager authenticationManager,
RequestMatcher requestMatcher) {
public OAuth2ClientAuthenticationFilter(AuthenticationManager authenticationManager, RequestMatcher requestMatcher,
String realmName) {
this.realmName = realmName;
Assert.notNull(authenticationManager, "authenticationManager cannot be null");
Assert.notNull(requestMatcher, "requestMatcher cannot be null");
Assert.notNull(realmName, "realmName cannot be null");
this.authenticationManager = authenticationManager;
this.requestMatcher = requestMatcher;
// @formatter:off
Expand Down Expand Up @@ -140,9 +144,12 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
validateClientIdentifier(authenticationRequest);
Authentication authenticationResult = this.authenticationManager.authenticate(authenticationRequest);
this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authenticationResult);
filterChain.doFilter(request, response);
}
else {
this.authenticationFailureHandler.onAuthenticationFailure(request, response,
new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT));
}
filterChain.doFilter(request, response);

}
catch (OAuth2AuthenticationException ex) {
if (this.logger.isTraceEnabled()) {
Expand Down Expand Up @@ -204,20 +211,11 @@ private void onAuthenticationFailure(HttpServletRequest request, HttpServletResp

SecurityContextHolder.clearContext();

// TODO
// The authorization server MAY return an HTTP 401 (Unauthorized) status code
// to indicate which HTTP authentication schemes are supported.
// If the client attempted to authenticate via the "Authorization" request header
// field,
// the authorization server MUST respond with an HTTP 401 (Unauthorized) status
// code and
// include the "WWW-Authenticate" response header field
// matching the authentication scheme used by the client.

OAuth2Error error = ((OAuth2AuthenticationException) exception).getError();
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) {
httpResponse.setStatusCode(HttpStatus.UNAUTHORIZED);
httpResponse.getHeaders().set("WWW-Authenticate", "Basic realm=\"" + this.realmName + "\"");
}
else {
httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public class OAuth2ClientAuthenticationFilterTests {
public void setUp() {
this.authenticationManager = mock(AuthenticationManager.class);
this.requestMatcher = new AntPathRequestMatcher(this.filterProcessesUrl, HttpMethod.POST.name());
this.filter = new OAuth2ClientAuthenticationFilter(this.authenticationManager, this.requestMatcher);
this.filter = new OAuth2ClientAuthenticationFilter(this.authenticationManager, this.requestMatcher, "realm");
this.authenticationConverter = mock(AuthenticationConverter.class);
this.filter.setAuthenticationConverter(this.authenticationConverter);
}
Expand All @@ -93,14 +93,14 @@ public void cleanup() {

@Test
public void constructorWhenAuthenticationManagerNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2ClientAuthenticationFilter(null, this.requestMatcher))
assertThatThrownBy(() -> new OAuth2ClientAuthenticationFilter(null, this.requestMatcher, "realm"))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authenticationManager cannot be null");
}

@Test
public void constructorWhenRequestMatcherNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2ClientAuthenticationFilter(this.authenticationManager, null))
assertThatThrownBy(() -> new OAuth2ClientAuthenticationFilter(this.authenticationManager, null, "realm"))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("requestMatcher cannot be null");
}
Expand Down

0 comments on commit 987a8c8

Please sign in to comment.