|
40 | 40 | import org.springframework.http.HttpMethod;
|
41 | 41 | import org.springframework.lang.Nullable;
|
42 | 42 | import org.springframework.security.config.ObjectPostProcessor;
|
| 43 | +import org.springframework.security.config.annotation.web.ServletRegistrationsSupport.RegistrationMapping; |
43 | 44 | import org.springframework.security.config.annotation.web.configurers.AbstractConfigAttributeRequestMatcherRegistry;
|
44 | 45 | import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher;
|
45 | 46 | import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
@@ -235,103 +236,31 @@ private boolean anyPathsDontStartWithLeadingSlash(String... patterns) {
|
235 | 236 | }
|
236 | 237 |
|
237 | 238 | private RequestMatcher resolve(AntPathRequestMatcher ant, MvcRequestMatcher mvc, ServletContext servletContext) {
|
238 |
| - Map<String, ? extends ServletRegistration> registrations = mappableServletRegistrations(servletContext); |
239 |
| - if (registrations.isEmpty()) { |
| 239 | + ServletRegistrationsSupport registrations = new ServletRegistrationsSupport(servletContext); |
| 240 | + Collection<RegistrationMapping> mappings = registrations.mappings(); |
| 241 | + if (mappings.isEmpty()) { |
240 | 242 | return new DispatcherServletDelegatingRequestMatcher(ant, mvc, new MockMvcRequestMatcher());
|
241 | 243 | }
|
242 |
| - if (!hasDispatcherServlet(registrations)) { |
| 244 | + Collection<RegistrationMapping> dispatcherServletMappings = registrations.dispatcherServletMappings(); |
| 245 | + if (dispatcherServletMappings.isEmpty()) { |
243 | 246 | return new DispatcherServletDelegatingRequestMatcher(ant, mvc, new MockMvcRequestMatcher());
|
244 | 247 | }
|
245 |
| - ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations); |
246 |
| - if (dispatcherServlet != null) { |
247 |
| - if (registrations.size() == 1) { |
248 |
| - return mvc; |
249 |
| - } |
250 |
| - return new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext); |
| 248 | + if (dispatcherServletMappings.size() > 1) { |
| 249 | + String errorMessage = computeErrorMessage(servletContext.getServletRegistrations().values()); |
| 250 | + throw new IllegalArgumentException(errorMessage); |
251 | 251 | }
|
252 |
| - dispatcherServlet = requireOnlyPathMappedDispatcherServlet(registrations); |
253 |
| - if (dispatcherServlet != null) { |
254 |
| - String mapping = dispatcherServlet.getMappings().iterator().next(); |
255 |
| - mvc.setServletPath(mapping.substring(0, mapping.length() - 2)); |
256 |
| - return mvc; |
257 |
| - } |
258 |
| - String errorMessage = computeErrorMessage(registrations.values()); |
259 |
| - throw new IllegalArgumentException(errorMessage); |
260 |
| - } |
261 |
| - |
262 |
| - private Map<String, ? extends ServletRegistration> mappableServletRegistrations(ServletContext servletContext) { |
263 |
| - Map<String, ServletRegistration> mappable = new LinkedHashMap<>(); |
264 |
| - for (Map.Entry<String, ? extends ServletRegistration> entry : servletContext.getServletRegistrations() |
265 |
| - .entrySet()) { |
266 |
| - if (!entry.getValue().getMappings().isEmpty()) { |
267 |
| - mappable.put(entry.getKey(), entry.getValue()); |
268 |
| - } |
| 252 | + RegistrationMapping dispatcherServlet = dispatcherServletMappings.iterator().next(); |
| 253 | + if (mappings.size() > 1 && !dispatcherServlet.isDefault()) { |
| 254 | + String errorMessage = computeErrorMessage(servletContext.getServletRegistrations().values()); |
| 255 | + throw new IllegalArgumentException(errorMessage); |
269 | 256 | }
|
270 |
| - return mappable; |
271 |
| - } |
272 |
| - |
273 |
| - private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration> registrations) { |
274 |
| - if (registrations == null) { |
275 |
| - return false; |
276 |
| - } |
277 |
| - for (ServletRegistration registration : registrations.values()) { |
278 |
| - if (isDispatcherServlet(registration)) { |
279 |
| - return true; |
280 |
| - } |
281 |
| - } |
282 |
| - return false; |
283 |
| - } |
284 |
| - |
285 |
| - private ServletRegistration requireOneRootDispatcherServlet( |
286 |
| - Map<String, ? extends ServletRegistration> registrations) { |
287 |
| - ServletRegistration rootDispatcherServlet = null; |
288 |
| - for (ServletRegistration registration : registrations.values()) { |
289 |
| - if (!isDispatcherServlet(registration)) { |
290 |
| - continue; |
291 |
| - } |
292 |
| - if (registration.getMappings().size() > 1) { |
293 |
| - return null; |
294 |
| - } |
295 |
| - if (!"/".equals(registration.getMappings().iterator().next())) { |
296 |
| - return null; |
297 |
| - } |
298 |
| - rootDispatcherServlet = registration; |
299 |
| - } |
300 |
| - return rootDispatcherServlet; |
301 |
| - } |
302 |
| - |
303 |
| - private ServletRegistration requireOnlyPathMappedDispatcherServlet( |
304 |
| - Map<String, ? extends ServletRegistration> registrations) { |
305 |
| - ServletRegistration pathDispatcherServlet = null; |
306 |
| - for (ServletRegistration registration : registrations.values()) { |
307 |
| - if (!isDispatcherServlet(registration)) { |
308 |
| - return null; |
309 |
| - } |
310 |
| - if (registration.getMappings().size() > 1) { |
311 |
| - return null; |
312 |
| - } |
313 |
| - String mapping = registration.getMappings().iterator().next(); |
314 |
| - if (!mapping.startsWith("/") || !mapping.endsWith("/*")) { |
315 |
| - return null; |
316 |
| - } |
317 |
| - if (pathDispatcherServlet != null) { |
318 |
| - return null; |
| 257 | + if (dispatcherServlet.isDefault()) { |
| 258 | + if (mappings.size() == 1) { |
| 259 | + return mvc; |
319 | 260 | }
|
320 |
| - pathDispatcherServlet = registration; |
321 |
| - } |
322 |
| - return pathDispatcherServlet; |
323 |
| - } |
324 |
| - |
325 |
| - private boolean isDispatcherServlet(ServletRegistration registration) { |
326 |
| - Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet", |
327 |
| - null); |
328 |
| - try { |
329 |
| - Class<?> clazz = Class.forName(registration.getClassName()); |
330 |
| - return dispatcherServlet.isAssignableFrom(clazz); |
331 |
| - } |
332 |
| - catch (ClassNotFoundException ex) { |
333 |
| - return false; |
| 261 | + return new DispatcherServletDelegatingRequestMatcher(ant, mvc); |
334 | 262 | }
|
| 263 | + return mvc; |
335 | 264 | }
|
336 | 265 |
|
337 | 266 | private static String computeErrorMessage(Collection<? extends ServletRegistration> registrations) {
|
|
0 commit comments