Skip to content

Commit ac13b55

Browse files
ankurpathakjzheaux
authored andcommitted
HeaderWriterFilter writes headers at beginning
Add support for HeaderWriterFilter to write headers at the beginning of the request Fixes: gh-6501
1 parent fba2561 commit ac13b55

File tree

3 files changed

+181
-1
lines changed

3 files changed

+181
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
* Copyright 2002-2019 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.config.annotation.authentication.configurers;
18+
19+
import javax.servlet.Filter;
20+
import javax.servlet.ServletException;
21+
import java.io.IOException;
22+
import java.util.LinkedHashMap;
23+
import java.util.List;
24+
import java.util.Map;
25+
26+
import org.junit.After;
27+
import org.junit.Before;
28+
import org.junit.Test;
29+
30+
import org.springframework.mock.web.MockFilterChain;
31+
import org.springframework.mock.web.MockHttpServletRequest;
32+
import org.springframework.mock.web.MockHttpServletResponse;
33+
import org.springframework.mock.web.MockServletContext;
34+
import org.springframework.security.config.annotation.ObjectPostProcessor;
35+
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
36+
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
37+
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
38+
import org.springframework.security.web.header.HeaderWriterFilter;
39+
import org.springframework.web.context.ConfigurableWebApplicationContext;
40+
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
41+
42+
import static org.assertj.core.api.Assertions.assertThat;
43+
44+
/**
45+
* Tests for {@link HeadersConfigurer}.
46+
*
47+
* @author Ankur Pathak
48+
*/
49+
public class HeadersConfigurerJavaTests {
50+
51+
private boolean allowCircularReferences = false;
52+
private MockServletContext servletContext;
53+
private MockHttpServletRequest request;
54+
private MockHttpServletResponse response;
55+
private MockFilterChain chain;
56+
private ConfigurableWebApplicationContext context;
57+
58+
59+
@Before
60+
public void setUp() {
61+
this.servletContext = new MockServletContext();
62+
this.request = new MockHttpServletRequest(this.servletContext, "GET", "");
63+
this.response = new MockHttpServletResponse();
64+
this.chain = new MockFilterChain();
65+
}
66+
67+
68+
@After
69+
public void cleanup(){
70+
if (this.context != null){
71+
this.context.close();
72+
}
73+
}
74+
75+
76+
@EnableWebSecurity
77+
public static class HeadersAtTheBeginningOfRequestConfig extends WebSecurityConfigurerAdapter {
78+
@Override
79+
protected void configure(HttpSecurity http) throws Exception {
80+
http
81+
.headers()
82+
.addObjectPostProcessor(new ObjectPostProcessor<HeaderWriterFilter>() {
83+
@Override
84+
public HeaderWriterFilter postProcess(HeaderWriterFilter filter) {
85+
filter.setShouldWriteHeadersEagerly(true);
86+
return filter;
87+
}
88+
});
89+
}
90+
}
91+
92+
@Test
93+
public void headersWrittenAtBeginningOfRequest() throws IOException, ServletException {
94+
this.context = loadConfig(HeadersAtTheBeginningOfRequestConfig.class);
95+
this.request.setSecure(true);
96+
getSpringSecurityFilterChain().doFilter(this.request, this.response, this.chain);
97+
assertThat(getResponseHeaders()).containsAllEntriesOf(new LinkedHashMap<String, String>(){{
98+
put("X-Content-Type-Options", "nosniff");
99+
put("X-Frame-Options", "DENY");
100+
put("Strict-Transport-Security", "max-age=31536000 ; includeSubDomains");
101+
put("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate");
102+
put("Expires", "0");
103+
put("Pragma", "no-cache");
104+
put("X-XSS-Protection", "1; mode=block");
105+
}});
106+
}
107+
108+
109+
@SuppressWarnings("unchecked")
110+
private Map<String, String > getResponseHeaders() {
111+
Map<String, String> headers = new LinkedHashMap<>();
112+
this.response.getHeaderNames().forEach(name -> {
113+
List values = this.response.getHeaderValues(name);
114+
headers.put(name, String.join(",", values));
115+
});
116+
return headers;
117+
}
118+
119+
private ConfigurableWebApplicationContext loadConfig(Class<?>... configs) {
120+
AnnotationConfigWebApplicationContext context = new AnnotationConfigWebApplicationContext();
121+
context.register(configs);
122+
context.setAllowCircularReferences(this.allowCircularReferences);
123+
context.setServletContext(this.servletContext);
124+
context.refresh();
125+
return context;
126+
}
127+
128+
private Filter getSpringSecurityFilterChain() {
129+
return this.context.getBean("springSecurityFilterChain", Filter.class);
130+
}
131+
}

web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ public class HeaderWriterFilter extends OncePerRequestFilter {
5151
*/
5252
private final HeaderWriter headerWriter;
5353

54+
/**
55+
* Indicates whether to write the headers at the beginning of the request.
56+
*/
57+
private boolean shouldWriteHeadersEagerly = false;
58+
5459
/**
5560
* Creates a new instance.
5661
*
@@ -67,18 +72,41 @@ protected void doFilterInternal(HttpServletRequest request,
6772
HttpServletResponse response, FilterChain filterChain)
6873
throws ServletException, IOException {
6974

75+
if (this.shouldWriteHeadersEagerly) {
76+
doHeadersBefore(request, response, filterChain);
77+
} else {
78+
doHeadersAfter(request, response, filterChain);
79+
}
80+
}
81+
82+
private void doHeadersBefore(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws IOException, ServletException {
83+
this.headerWriter.writeHeaders(request, response);
84+
filterChain.doFilter(request, response);
85+
}
86+
87+
private void doHeadersAfter(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws IOException, ServletException {
7088
HeaderWriterResponse headerWriterResponse = new HeaderWriterResponse(request,
7189
response, this.headerWriter);
7290
HeaderWriterRequest headerWriterRequest = new HeaderWriterRequest(request,
7391
headerWriterResponse);
74-
7592
try {
7693
filterChain.doFilter(headerWriterRequest, headerWriterResponse);
7794
} finally {
7895
headerWriterResponse.writeHeaders();
7996
}
8097
}
8198

99+
/**
100+
* Allow writing headers at the beginning of the request.
101+
*
102+
* @param shouldWriteHeadersEagerly boolean to allow writing headers at the beginning of the request.
103+
* @author Ankur Pathak
104+
* @since 5.2
105+
*/
106+
public void setShouldWriteHeadersEagerly(boolean shouldWriteHeadersEagerly) {
107+
this.shouldWriteHeadersEagerly = shouldWriteHeadersEagerly;
108+
}
109+
82110
static class HeaderWriterResponse extends OnCommittedResponseWrapper {
83111
private final HttpServletRequest request;
84112
private final HeaderWriter headerWriter;

web/src/test/java/org/springframework/security/web/header/HeaderWriterFilterTests.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,25 @@ public void doFilterWhenRequestContainsIncludeThenHeadersStillWritten() throws E
136136

137137
verifyNoMoreInteractions(this.writer1);
138138
}
139+
140+
@Test
141+
public void headersWrittenAtBeginningOfRequest() throws Exception {
142+
HeaderWriterFilter filter = new HeaderWriterFilter(
143+
Collections.singletonList(this.writer1));
144+
filter.setShouldWriteHeadersEagerly(true);
145+
146+
MockHttpServletRequest request = new MockHttpServletRequest();
147+
MockHttpServletResponse response = new MockHttpServletResponse();
148+
149+
filter.doFilter(request, response, new FilterChain() {
150+
@Override
151+
public void doFilter(ServletRequest request, ServletResponse response)
152+
throws IOException, ServletException {
153+
verify(HeaderWriterFilterTests.this.writer1).writeHeaders(
154+
any(HttpServletRequest.class), any(HttpServletResponse.class));
155+
}
156+
});
157+
158+
verifyNoMoreInteractions(this.writer1);
159+
}
139160
}

0 commit comments

Comments
 (0)