|
17 | 17 |
|
18 | 18 | package org.openqa.selenium.netty.server;
|
19 | 19 |
|
| 20 | +import org.openqa.selenium.internal.Require; |
| 21 | +import org.openqa.selenium.remote.http.Message; |
| 22 | + |
20 | 23 | import io.netty.buffer.ByteBuf;
|
21 | 24 | import io.netty.buffer.Unpooled;
|
22 | 25 | import io.netty.channel.ChannelFuture;
|
|
38 | 41 | import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
|
39 | 42 | import io.netty.util.AttributeKey;
|
40 | 43 |
|
41 |
| -import org.openqa.selenium.internal.Require; |
42 |
| -import org.openqa.selenium.remote.http.Message; |
43 |
| - |
44 | 44 | import java.util.Optional;
|
45 | 45 | import java.util.function.BiFunction;
|
46 | 46 | import java.util.function.Consumer;
|
@@ -69,6 +69,27 @@ public WebSocketUpgradeHandler(
|
69 | 69 | this.factory = Require.nonNull("Factory", factory);
|
70 | 70 | }
|
71 | 71 |
|
| 72 | + private static void sendHttpResponse( |
| 73 | + ChannelHandlerContext ctx, HttpRequest req, FullHttpResponse res) { |
| 74 | + // Generate an error page if response status code is not OK (200). |
| 75 | + if (res.status().code() != 200) { |
| 76 | + ByteBuf buf = Unpooled.copiedBuffer(res.status().toString(), UTF_8); |
| 77 | + res.content().writeBytes(buf); |
| 78 | + buf.release(); |
| 79 | + setContentLength(res, res.content().readableBytes()); |
| 80 | + } |
| 81 | + |
| 82 | + // Send the response and close the connection if necessary. |
| 83 | + ChannelFuture f = ctx.channel().writeAndFlush(res); |
| 84 | + if (!isKeepAlive(req) || res.status().code() != 200) { |
| 85 | + f.addListener(ChannelFutureListener.CLOSE); |
| 86 | + } |
| 87 | + } |
| 88 | + |
| 89 | + private static String getWebSocketLocation(HttpRequest req) { |
| 90 | + return "ws://" + req.headers().get(HttpHeaderNames.HOST); |
| 91 | + } |
| 92 | + |
72 | 93 | @Override
|
73 | 94 | public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
74 | 95 | if (msg instanceof HttpRequest) {
|
@@ -100,26 +121,25 @@ private void handleHttpRequest(ChannelHandlerContext ctx, HttpRequest req) {
|
100 | 121 | }
|
101 | 122 |
|
102 | 123 | // Only handle the initial HTTP upgrade request
|
103 |
| - if (!(req.headers().contains("Connection", "upgrade", true) && |
104 |
| - req.headers().contains("Sec-WebSocket-Version"))) { |
| 124 | + if (!(req.headers().containsValue("Connection", "upgrade", true) && |
| 125 | + req.headers().contains("Sec-WebSocket-Version"))) { |
105 | 126 | ctx.fireChannelRead(req);
|
106 | 127 | return;
|
107 | 128 | }
|
108 | 129 |
|
109 | 130 | // Is this something we should try and handle?
|
110 | 131 | Optional<Consumer<Message>> maybeHandler = factory.apply(
|
111 | 132 | req.uri(),
|
112 |
| - msg -> { |
113 |
| - ctx.channel().writeAndFlush(Require.nonNull("Message to send", msg)); |
114 |
| - }); |
| 133 | + msg -> ctx.channel().writeAndFlush(Require.nonNull("Message to send", msg))); |
115 | 134 | if (!maybeHandler.isPresent()) {
|
116 |
| - sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST, ctx.alloc().buffer(0))); |
| 135 | + sendHttpResponse(ctx, req, |
| 136 | + new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST, ctx.alloc().buffer(0))); |
117 | 137 | return;
|
118 | 138 | }
|
119 | 139 |
|
120 | 140 | // Handshake
|
121 | 141 | WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
|
122 |
| - getWebSocketLocation(req), null, false, Integer.MAX_VALUE); |
| 142 | + getWebSocketLocation(req), null, true, Integer.MAX_VALUE); |
123 | 143 | handshaker = wsFactory.newHandshaker(req);
|
124 | 144 | if (handshaker == null) {
|
125 | 145 | WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
|
@@ -155,29 +175,8 @@ private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame fram
|
155 | 175 | }
|
156 | 176 | }
|
157 | 177 |
|
158 |
| - private static void sendHttpResponse( |
159 |
| - ChannelHandlerContext ctx, HttpRequest req, FullHttpResponse res) { |
160 |
| - // Generate an error page if response status code is not OK (200). |
161 |
| - if (res.status().code() != 200) { |
162 |
| - ByteBuf buf = Unpooled.copiedBuffer(res.status().toString(), UTF_8); |
163 |
| - res.content().writeBytes(buf); |
164 |
| - buf.release(); |
165 |
| - setContentLength(res, res.content().readableBytes()); |
166 |
| - } |
167 |
| - |
168 |
| - // Send the response and close the connection if necessary. |
169 |
| - ChannelFuture f = ctx.channel().writeAndFlush(res); |
170 |
| - if (!isKeepAlive(req) || res.status().code() != 200) { |
171 |
| - f.addListener(ChannelFutureListener.CLOSE); |
172 |
| - } |
173 |
| - } |
174 |
| - |
175 | 178 | @Override
|
176 | 179 | public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
|
177 | 180 | ctx.close();
|
178 | 181 | }
|
179 |
| - |
180 |
| - private static String getWebSocketLocation(HttpRequest req) { |
181 |
| - return "ws://" + req.headers().get(HttpHeaderNames.HOST); |
182 |
| - } |
183 | 182 | }
|
0 commit comments