16
16
17
17
package org .springframework .graphql .server .webflux ;
18
18
19
- import java .util .Arrays ;
20
19
import java .util .Collections ;
21
20
import java .util .List ;
22
21
40
39
* {@link RequestPredicate} implementations tailored for GraphQL reactive endpoints.
41
40
*
42
41
* @author Brian Clozel
42
+ * @author Rossen Stoyanchev
43
43
* @since 1.3.0
44
44
*/
45
45
public final class GraphQlRequestPredicates {
@@ -56,7 +56,8 @@ private GraphQlRequestPredicates() {
56
56
* @see GraphQlHttpHandler
57
57
*/
58
58
public static RequestPredicate graphQlHttp (String path ) {
59
- return new GraphQlHttpRequestPredicate (path , MediaType .APPLICATION_JSON , MediaType .APPLICATION_GRAPHQL_RESPONSE );
59
+ return new GraphQlHttpRequestPredicate (
60
+ path , List .of (MediaType .APPLICATION_JSON , MediaType .APPLICATION_GRAPHQL_RESPONSE ));
60
61
}
61
62
62
63
/**
@@ -65,59 +66,67 @@ public static RequestPredicate graphQlHttp(String path) {
65
66
* @see GraphQlSseHandler
66
67
*/
67
68
public static RequestPredicate graphQlSse (String path ) {
68
- return new GraphQlHttpRequestPredicate (path , MediaType .TEXT_EVENT_STREAM );
69
+ return new GraphQlHttpRequestPredicate (path , List . of ( MediaType .TEXT_EVENT_STREAM ) );
69
70
}
70
71
71
72
private static class GraphQlHttpRequestPredicate implements RequestPredicate {
72
73
73
74
private final PathPattern pattern ;
74
75
76
+ private final List <MediaType > contentTypes ;
77
+
75
78
private final List <MediaType > acceptedMediaTypes ;
76
79
77
80
78
- GraphQlHttpRequestPredicate (String path , MediaType ... accepted ) {
81
+ GraphQlHttpRequestPredicate (String path , List < MediaType > accepted ) {
79
82
Assert .notNull (path , "'path' must not be null" );
80
83
Assert .notEmpty (accepted , "'accepted' must not be empty" );
81
84
PathPatternParser parser = PathPatternParser .defaultInstance ;
82
85
path = parser .initFullPathPattern (path );
83
86
this .pattern = parser .parse (path );
84
- this .acceptedMediaTypes = Arrays .asList (accepted );
87
+ this .contentTypes = List .of (MediaType .APPLICATION_JSON , MediaType .parseMediaType ("application/graphql" ));
88
+ this .acceptedMediaTypes = accepted ;
85
89
}
86
90
87
91
@ Override
88
92
public boolean test (ServerRequest request ) {
89
- return methodMatch (request , HttpMethod .POST )
90
- && contentTypeMatch (request , MediaType . APPLICATION_JSON )
93
+ return httpMethodMatch (request , HttpMethod .POST )
94
+ && contentTypeMatch (request , this . contentTypes )
91
95
&& acceptMatch (request , this .acceptedMediaTypes )
92
96
&& pathMatch (request , this .pattern );
93
97
}
94
98
95
- private static boolean methodMatch (ServerRequest request , HttpMethod expected ) {
96
- HttpMethod actual = resolveMethod (request );
99
+ private static boolean httpMethodMatch (ServerRequest request , HttpMethod expected ) {
100
+ HttpMethod actual = resolveHttpMethod (request );
97
101
boolean methodMatch = expected .equals (actual );
98
102
traceMatch ("Method" , expected , actual , methodMatch );
99
103
return methodMatch ;
100
104
}
101
105
102
- private static HttpMethod resolveMethod (ServerRequest request ) {
106
+ private static HttpMethod resolveHttpMethod (ServerRequest request ) {
103
107
if (CorsUtils .isPreFlightRequest (request .exchange ().getRequest ())) {
104
- String accessControlRequestMethod =
105
- request .headers ().firstHeader (HttpHeaders .ACCESS_CONTROL_REQUEST_METHOD );
106
- if (accessControlRequestMethod != null ) {
107
- return HttpMethod .valueOf (accessControlRequestMethod );
108
+ String httpMethod = request .headers ().firstHeader (HttpHeaders .ACCESS_CONTROL_REQUEST_METHOD );
109
+ if (httpMethod != null ) {
110
+ return HttpMethod .valueOf (httpMethod );
108
111
}
109
112
}
110
113
return request .method ();
111
114
}
112
115
113
- private static boolean contentTypeMatch (ServerRequest request , MediaType expected ) {
116
+ private static boolean contentTypeMatch (ServerRequest request , List < MediaType > contentTypes ) {
114
117
if (CorsUtils .isPreFlightRequest (request .exchange ().getRequest ())) {
115
118
return true ;
116
119
}
117
120
ServerRequest .Headers headers = request .headers ();
118
121
MediaType actual = headers .contentType ().orElse (MediaType .APPLICATION_OCTET_STREAM );
119
- boolean contentTypeMatch = expected .includes (actual );
120
- traceMatch ("Content-Type" , expected , actual , contentTypeMatch );
122
+ boolean contentTypeMatch = false ;
123
+ for (MediaType contentType : contentTypes ) {
124
+ contentTypeMatch = contentType .includes (actual );
125
+ traceMatch ("Content-Type" , contentTypes , actual , contentTypeMatch );
126
+ if (contentTypeMatch ) {
127
+ break ;
128
+ }
129
+ }
121
130
return contentTypeMatch ;
122
131
}
123
132
0 commit comments