Skip to content

Commit 188f708

Browse files
Simplify ProtocolSwitchStrategy by Leveraging ProtocolVersionParser (#627)
Unify HTTP and TLS token parsing in the Upgrade header by replacing custom version parsing with ProtocolVersionParser. This change removes redundant code and ensures that only supported protocols (HTTP/ and TLS tokens) are accepted, while all other upgrade protocols are rejected as unsupported.
1 parent ffc12f1 commit 188f708

File tree

2 files changed

+219
-20
lines changed

2 files changed

+219
-20
lines changed

httpclient5/src/main/java/org/apache/hc/client5/http/impl/ProtocolSwitchStrategy.java

+101-19
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,22 @@
2727
package org.apache.hc.client5.http.impl;
2828

2929
import java.util.Iterator;
30+
import java.util.concurrent.atomic.AtomicReference;
3031

3132
import org.apache.hc.core5.annotation.Internal;
33+
import org.apache.hc.core5.http.FormattedHeader;
34+
import org.apache.hc.core5.http.Header;
3235
import org.apache.hc.core5.http.HttpHeaders;
3336
import org.apache.hc.core5.http.HttpMessage;
37+
import org.apache.hc.core5.http.HttpVersion;
3438
import org.apache.hc.core5.http.ParseException;
3539
import org.apache.hc.core5.http.ProtocolException;
3640
import org.apache.hc.core5.http.ProtocolVersion;
37-
import org.apache.hc.core5.http.message.MessageSupport;
41+
import org.apache.hc.core5.http.ProtocolVersionParser;
3842
import org.apache.hc.core5.http.ssl.TLS;
43+
import org.apache.hc.core5.util.Args;
44+
import org.apache.hc.core5.util.CharArrayBuffer;
45+
import org.apache.hc.core5.util.Tokenizer;
3946

4047
/**
4148
* Protocol switch handler.
@@ -45,31 +52,106 @@
4552
@Internal
4653
public final class ProtocolSwitchStrategy {
4754

48-
enum ProtocolSwitch { FAILURE, TLS }
55+
private static final ProtocolVersionParser PROTOCOL_VERSION_PARSER = ProtocolVersionParser.INSTANCE;
56+
57+
private static final Tokenizer TOKENIZER = Tokenizer.INSTANCE;
58+
59+
private static final Tokenizer.Delimiter UPGRADE_TOKEN_DELIMITER = Tokenizer.delimiters(',');
60+
61+
@FunctionalInterface
62+
private interface HeaderConsumer {
63+
void accept(CharSequence buffer, Tokenizer.Cursor cursor) throws ProtocolException;
64+
}
4965

5066
public ProtocolVersion switchProtocol(final HttpMessage response) throws ProtocolException {
51-
final Iterator<String> it = MessageSupport.iterateTokens(response, HttpHeaders.UPGRADE);
67+
final AtomicReference<ProtocolVersion> tlsUpgrade = new AtomicReference<>();
5268

53-
ProtocolVersion tlsUpgrade = null;
54-
while (it.hasNext()) {
55-
final String token = it.next();
56-
if (token.startsWith("TLS")) {
57-
// TODO: Improve handling of HTTP protocol token once HttpVersion has a #parse method
58-
try {
59-
tlsUpgrade = token.length() == 3 ? TLS.V_1_2.getVersion() : TLS.parse(token.replace("TLS/", "TLSv"));
60-
} catch (final ParseException ex) {
61-
throw new ProtocolException("Invalid protocol: " + token);
69+
parseHeaders(response, HttpHeaders.UPGRADE, (buffer, cursor) -> {
70+
while (!cursor.atEnd()) {
71+
TOKENIZER.skipWhiteSpace(buffer, cursor);
72+
if (cursor.atEnd()) {
73+
break;
74+
}
75+
final int tokenStart = cursor.getPos();
76+
TOKENIZER.parseToken(buffer, cursor, UPGRADE_TOKEN_DELIMITER);
77+
final int tokenEnd = cursor.getPos();
78+
if (tokenStart < tokenEnd) {
79+
final ProtocolVersion version = parseProtocolToken(buffer, tokenStart, tokenEnd);
80+
if (version != null && "TLS".equalsIgnoreCase(version.getProtocol())) {
81+
tlsUpgrade.set(version);
82+
}
6283
}
63-
} else if (token.equals("HTTP/1.1")) {
64-
// TODO: Improve handling of HTTP protocol token once HttpVersion has a #parse method
84+
if (!cursor.atEnd()) {
85+
cursor.updatePos(cursor.getPos() + 1);
86+
}
87+
}
88+
});
89+
90+
final ProtocolVersion result = tlsUpgrade.get();
91+
if (result != null) {
92+
return result;
93+
} else {
94+
throw new ProtocolException("Invalid protocol switch response: no TLS version found");
95+
}
96+
}
97+
98+
private ProtocolVersion parseProtocolToken(final CharSequence buffer, final int start, final int end)
99+
throws ProtocolException {
100+
if (start >= end) {
101+
return null;
102+
}
103+
104+
if (end - start == 3) {
105+
final char c0 = buffer.charAt(start);
106+
final char c1 = buffer.charAt(start + 1);
107+
final char c2 = buffer.charAt(start + 2);
108+
if ((c0 == 'T' || c0 == 't') &&
109+
(c1 == 'L' || c1 == 'l') &&
110+
(c2 == 'S' || c2 == 's')) {
111+
return TLS.V_1_2.getVersion();
112+
}
113+
}
114+
115+
try {
116+
final Tokenizer.Cursor cursor = new Tokenizer.Cursor(start, end);
117+
final ProtocolVersion version = PROTOCOL_VERSION_PARSER.parse(buffer, cursor, null);
118+
119+
if ("TLS".equalsIgnoreCase(version.getProtocol())) {
120+
return version;
121+
} else if (version.equals(HttpVersion.HTTP_1_1)) {
122+
return null;
65123
} else {
66-
throw new ProtocolException("Unsupported protocol: " + token);
124+
throw new ProtocolException("Unsupported protocol or HTTP version: " + buffer.subSequence(start, end));
67125
}
126+
} catch (final ParseException ex) {
127+
throw new ProtocolException("Invalid protocol: " + buffer.subSequence(start, end), ex);
68128
}
69-
if (tlsUpgrade == null) {
70-
throw new ProtocolException("Invalid protocol switch response");
129+
}
130+
131+
private void parseHeaders(final HttpMessage message, final String name, final HeaderConsumer consumer)
132+
throws ProtocolException {
133+
Args.notNull(message, "Message headers");
134+
Args.notBlank(name, "Header name");
135+
final Iterator<Header> it = message.headerIterator(name);
136+
while (it.hasNext()) {
137+
parseHeader(it.next(), consumer);
71138
}
72-
return tlsUpgrade;
73139
}
74140

75-
}
141+
private void parseHeader(final Header header, final HeaderConsumer consumer) throws ProtocolException {
142+
Args.notNull(header, "Header");
143+
if (header instanceof FormattedHeader) {
144+
final CharArrayBuffer buf = ((FormattedHeader) header).getBuffer();
145+
final Tokenizer.Cursor cursor = new Tokenizer.Cursor(0, buf.length());
146+
cursor.updatePos(((FormattedHeader) header).getValuePos());
147+
consumer.accept(buf, cursor);
148+
} else {
149+
final String value = header.getValue();
150+
if (value == null) {
151+
return;
152+
}
153+
final Tokenizer.Cursor cursor = new Tokenizer.Cursor(0, value.length());
154+
consumer.accept(value, cursor);
155+
}
156+
}
157+
}

httpclient5/src/test/java/org/apache/hc/client5/http/impl/TestProtocolSwitchStrategy.java

+118-1
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,15 @@
3030
import org.apache.hc.core5.http.HttpResponse;
3131
import org.apache.hc.core5.http.HttpStatus;
3232
import org.apache.hc.core5.http.ProtocolException;
33+
import org.apache.hc.core5.http.ProtocolVersion;
3334
import org.apache.hc.core5.http.message.BasicHttpResponse;
3435
import org.apache.hc.core5.http.ssl.TLS;
3536
import org.junit.jupiter.api.Assertions;
3637
import org.junit.jupiter.api.BeforeEach;
3738
import org.junit.jupiter.api.Test;
3839

3940
/**
40-
* Simple tests for {@link DefaultAuthenticationStrategy}.
41+
* Simple tests for {@link ProtocolSwitchStrategy}.
4142
*/
4243
class TestProtocolSwitchStrategy {
4344

@@ -95,4 +96,120 @@ void testSwitchInvalid() {
9596
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response3));
9697
}
9798

99+
@Test
100+
void testNullToken() throws ProtocolException {
101+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
102+
response.addHeader(HttpHeaders.UPGRADE, "TLS,");
103+
response.addHeader(HttpHeaders.UPGRADE, null);
104+
Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response));
105+
}
106+
107+
@Test
108+
void testWhitespaceOnlyToken() throws ProtocolException {
109+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
110+
response.addHeader(HttpHeaders.UPGRADE, " , TLS");
111+
Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response));
112+
}
113+
114+
@Test
115+
void testUnsupportedTlsVersion() throws Exception {
116+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
117+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.4");
118+
Assertions.assertEquals(new ProtocolVersion("TLS", 1, 4), switchStrategy.switchProtocol(response));
119+
}
120+
121+
@Test
122+
void testUnsupportedTlsMajorVersion() throws Exception {
123+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
124+
response.addHeader(HttpHeaders.UPGRADE, "TLS/2.0");
125+
Assertions.assertEquals(new ProtocolVersion("TLS", 2, 0), switchStrategy.switchProtocol(response));
126+
}
127+
128+
@Test
129+
void testUnsupportedHttpVersion() {
130+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
131+
response.addHeader(HttpHeaders.UPGRADE, "HTTP/2.0");
132+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
133+
"Unsupported HTTP version: HTTP/2.0");
134+
}
135+
136+
@Test
137+
void testInvalidTlsFormat() {
138+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
139+
response.addHeader(HttpHeaders.UPGRADE, "TLS/abc");
140+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
141+
"Invalid protocol: TLS/abc");
142+
}
143+
144+
@Test
145+
void testHttp11Only() {
146+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
147+
response.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1");
148+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
149+
"Invalid protocol switch response: no TLS version found");
150+
}
151+
152+
@Test
153+
void testSwitchToTlsValid_TLS_1_2() throws Exception {
154+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
155+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.2");
156+
final ProtocolVersion result = switchStrategy.switchProtocol(response);
157+
Assertions.assertEquals(TLS.V_1_2.getVersion(), result);
158+
}
159+
160+
@Test
161+
void testSwitchToTlsValid_TLS_1_0() throws Exception {
162+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
163+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.0");
164+
final ProtocolVersion result = switchStrategy.switchProtocol(response);
165+
Assertions.assertEquals(TLS.V_1_0.getVersion(), result);
166+
}
167+
168+
@Test
169+
void testSwitchToTlsValid_TLS_1_1() throws Exception {
170+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
171+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.1");
172+
final ProtocolVersion result = switchStrategy.switchProtocol(response);
173+
Assertions.assertEquals(TLS.V_1_1.getVersion(), result);
174+
}
175+
176+
@Test
177+
void testInvalidTlsFormat_NoSlash() {
178+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
179+
response.addHeader(HttpHeaders.UPGRADE, "TLSv1");
180+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
181+
"Invalid protocol: TLSv1");
182+
}
183+
184+
@Test
185+
void testSwitchToTlsValid_TLS_1() throws Exception {
186+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
187+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1");
188+
final ProtocolVersion result = switchStrategy.switchProtocol(response);
189+
Assertions.assertEquals(TLS.V_1_0.getVersion(), result);
190+
}
191+
192+
@Test
193+
void testInvalidTlsFormat_MissingMajor() {
194+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
195+
response.addHeader(HttpHeaders.UPGRADE, "TLS/.1");
196+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
197+
"Invalid protocol: TLS/.1");
198+
}
199+
200+
@Test
201+
void testMultipleHttp11Tokens() {
202+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
203+
response.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1, HTTP/1.1");
204+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
205+
"Invalid protocol switch response: no TLS version found");
206+
}
207+
208+
@Test
209+
void testMixedInvalidAndValidTokens() {
210+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
211+
response.addHeader(HttpHeaders.UPGRADE, "Crap, TLS/1.2, Invalid");
212+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
213+
"Invalid protocol: Crap");
214+
}
98215
}

0 commit comments

Comments
 (0)