Skip to content

Commit fbfedd6

Browse files
authored
[PR #8736/1b88af2 backport][3.10] Improve performance of WebSocketReader (#8743)
1 parent dba2605 commit fbfedd6

File tree

2 files changed

+105
-96
lines changed

2 files changed

+105
-96
lines changed

CHANGES/8736.misc.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improved performance of the WebSocket reader -- by :user:`bdraco`.

aiohttp/http_websocket.py

Lines changed: 104 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,12 @@ class WSMsgType(IntEnum):
9494
error = ERROR
9595

9696

97+
MESSAGE_TYPES_WITH_CONTENT: Final = (
98+
WSMsgType.BINARY,
99+
WSMsgType.TEXT,
100+
WSMsgType.CONTINUATION,
101+
)
102+
97103
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
98104

99105

@@ -313,17 +319,101 @@ def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
313319
return True, data
314320

315321
try:
316-
return self._feed_data(data)
322+
self._feed_data(data)
317323
except Exception as exc:
318324
self._exc = exc
319325
set_exception(self.queue, exc)
320326
return True, b""
321327

322-
def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
328+
return False, b""
329+
330+
def _feed_data(self, data: bytes) -> None:
323331
for fin, opcode, payload, compressed in self.parse_frame(data):
324-
if compressed and not self._decompressobj:
325-
self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
326-
if opcode == WSMsgType.CLOSE:
332+
if opcode in MESSAGE_TYPES_WITH_CONTENT:
333+
# load text/binary
334+
is_continuation = opcode == WSMsgType.CONTINUATION
335+
if not fin:
336+
# got partial frame payload
337+
if not is_continuation:
338+
self._opcode = opcode
339+
self._partial += payload
340+
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
341+
raise WebSocketError(
342+
WSCloseCode.MESSAGE_TOO_BIG,
343+
"Message size {} exceeds limit {}".format(
344+
len(self._partial), self._max_msg_size
345+
),
346+
)
347+
continue
348+
349+
has_partial = bool(self._partial)
350+
if is_continuation:
351+
if self._opcode is None:
352+
raise WebSocketError(
353+
WSCloseCode.PROTOCOL_ERROR,
354+
"Continuation frame for non started message",
355+
)
356+
opcode = self._opcode
357+
self._opcode = None
358+
# previous frame was non finished
359+
# we should get continuation opcode
360+
elif has_partial:
361+
raise WebSocketError(
362+
WSCloseCode.PROTOCOL_ERROR,
363+
"The opcode in non-fin frame is expected "
364+
"to be zero, got {!r}".format(opcode),
365+
)
366+
367+
if has_partial:
368+
assembled_payload = self._partial + payload
369+
self._partial.clear()
370+
else:
371+
assembled_payload = payload
372+
373+
if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
374+
raise WebSocketError(
375+
WSCloseCode.MESSAGE_TOO_BIG,
376+
"Message size {} exceeds limit {}".format(
377+
len(assembled_payload), self._max_msg_size
378+
),
379+
)
380+
381+
# Decompression must be done after all packets
382+
# received.
383+
if compressed:
384+
if not self._decompressobj:
385+
self._decompressobj = ZLibDecompressor(
386+
suppress_deflate_header=True
387+
)
388+
payload_merged = self._decompressobj.decompress_sync(
389+
assembled_payload + _WS_DEFLATE_TRAILING, self._max_msg_size
390+
)
391+
if self._decompressobj.unconsumed_tail:
392+
left = len(self._decompressobj.unconsumed_tail)
393+
raise WebSocketError(
394+
WSCloseCode.MESSAGE_TOO_BIG,
395+
"Decompressed message size {} exceeds limit {}".format(
396+
self._max_msg_size + left, self._max_msg_size
397+
),
398+
)
399+
else:
400+
payload_merged = bytes(assembled_payload)
401+
402+
if opcode == WSMsgType.TEXT:
403+
try:
404+
text = payload_merged.decode("utf-8")
405+
except UnicodeDecodeError as exc:
406+
raise WebSocketError(
407+
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
408+
) from exc
409+
410+
self.queue.feed_data(WSMessage(WSMsgType.TEXT, text, ""), len(text))
411+
continue
412+
413+
self.queue.feed_data(
414+
WSMessage(WSMsgType.BINARY, payload_merged, ""), len(payload_merged)
415+
)
416+
elif opcode == WSMsgType.CLOSE:
327417
if len(payload) >= 2:
328418
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
329419
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
@@ -358,90 +448,10 @@ def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
358448
WSMessage(WSMsgType.PONG, payload, ""), len(payload)
359449
)
360450

361-
elif (
362-
opcode not in (WSMsgType.TEXT, WSMsgType.BINARY)
363-
and self._opcode is None
364-
):
451+
else:
365452
raise WebSocketError(
366453
WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
367454
)
368-
else:
369-
# load text/binary
370-
if not fin:
371-
# got partial frame payload
372-
if opcode != WSMsgType.CONTINUATION:
373-
self._opcode = opcode
374-
self._partial.extend(payload)
375-
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
376-
raise WebSocketError(
377-
WSCloseCode.MESSAGE_TOO_BIG,
378-
"Message size {} exceeds limit {}".format(
379-
len(self._partial), self._max_msg_size
380-
),
381-
)
382-
else:
383-
# previous frame was non finished
384-
# we should get continuation opcode
385-
if self._partial:
386-
if opcode != WSMsgType.CONTINUATION:
387-
raise WebSocketError(
388-
WSCloseCode.PROTOCOL_ERROR,
389-
"The opcode in non-fin frame is expected "
390-
"to be zero, got {!r}".format(opcode),
391-
)
392-
393-
if opcode == WSMsgType.CONTINUATION:
394-
assert self._opcode is not None
395-
opcode = self._opcode
396-
self._opcode = None
397-
398-
self._partial.extend(payload)
399-
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
400-
raise WebSocketError(
401-
WSCloseCode.MESSAGE_TOO_BIG,
402-
"Message size {} exceeds limit {}".format(
403-
len(self._partial), self._max_msg_size
404-
),
405-
)
406-
407-
# Decompression must be done after all packets
408-
# received.
409-
if compressed:
410-
assert self._decompressobj is not None
411-
self._partial.extend(_WS_DEFLATE_TRAILING)
412-
payload_merged = self._decompressobj.decompress_sync(
413-
self._partial, self._max_msg_size
414-
)
415-
if self._decompressobj.unconsumed_tail:
416-
left = len(self._decompressobj.unconsumed_tail)
417-
raise WebSocketError(
418-
WSCloseCode.MESSAGE_TOO_BIG,
419-
"Decompressed message size {} exceeds limit {}".format(
420-
self._max_msg_size + left, self._max_msg_size
421-
),
422-
)
423-
else:
424-
payload_merged = bytes(self._partial)
425-
426-
self._partial.clear()
427-
428-
if opcode == WSMsgType.TEXT:
429-
try:
430-
text = payload_merged.decode("utf-8")
431-
self.queue.feed_data(
432-
WSMessage(WSMsgType.TEXT, text, ""), len(text)
433-
)
434-
except UnicodeDecodeError as exc:
435-
raise WebSocketError(
436-
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
437-
) from exc
438-
else:
439-
self.queue.feed_data(
440-
WSMessage(WSMsgType.BINARY, payload_merged, ""),
441-
len(payload_merged),
442-
)
443-
444-
return False, b""
445455

446456
def parse_frame(
447457
self, buf: bytes
@@ -521,23 +531,21 @@ def parse_frame(
521531

522532
# read payload length
523533
if self._state is WSParserState.READ_PAYLOAD_LENGTH:
524-
length = self._payload_length_flag
525-
if length == 126:
534+
length_flag = self._payload_length_flag
535+
if length_flag == 126:
526536
if buf_length - start_pos < 2:
527537
break
528538
data = buf[start_pos : start_pos + 2]
529539
start_pos += 2
530-
length = UNPACK_LEN2(data)[0]
531-
self._payload_length = length
532-
elif length > 126:
540+
self._payload_length = UNPACK_LEN2(data)[0]
541+
elif length_flag > 126:
533542
if buf_length - start_pos < 8:
534543
break
535544
data = buf[start_pos : start_pos + 8]
536545
start_pos += 8
537-
length = UNPACK_LEN3(data)[0]
538-
self._payload_length = length
546+
self._payload_length = UNPACK_LEN3(data)[0]
539547
else:
540-
self._payload_length = length
548+
self._payload_length = length_flag
541549

542550
self._state = (
543551
WSParserState.READ_PAYLOAD_MASK
@@ -560,11 +568,11 @@ def parse_frame(
560568
chunk_len = buf_length - start_pos
561569
if length >= chunk_len:
562570
self._payload_length = length - chunk_len
563-
payload.extend(buf[start_pos:])
571+
payload += buf[start_pos:]
564572
start_pos = buf_length
565573
else:
566574
self._payload_length = 0
567-
payload.extend(buf[start_pos : start_pos + length])
575+
payload += buf[start_pos : start_pos + length]
568576
start_pos = start_pos + length
569577

570578
if self._payload_length != 0:

0 commit comments

Comments
 (0)