@@ -94,6 +94,12 @@ class WSMsgType(IntEnum):
94
94
error = ERROR
95
95
96
96
97
+ MESSAGE_TYPES_WITH_CONTENT : Final = (
98
+ WSMsgType .BINARY ,
99
+ WSMsgType .TEXT ,
100
+ WSMsgType .CONTINUATION ,
101
+ )
102
+
97
103
WS_KEY : Final [bytes ] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
98
104
99
105
@@ -313,17 +319,101 @@ def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
313
319
return True , data
314
320
315
321
try :
316
- return self ._feed_data (data )
322
+ self ._feed_data (data )
317
323
except Exception as exc :
318
324
self ._exc = exc
319
325
set_exception (self .queue , exc )
320
326
return True , b""
321
327
322
- def _feed_data (self , data : bytes ) -> Tuple [bool , bytes ]:
328
+ return False , b""
329
+
330
+ def _feed_data (self , data : bytes ) -> None :
323
331
for fin , opcode , payload , compressed in self .parse_frame (data ):
324
- if compressed and not self ._decompressobj :
325
- self ._decompressobj = ZLibDecompressor (suppress_deflate_header = True )
326
- if opcode == WSMsgType .CLOSE :
332
+ if opcode in MESSAGE_TYPES_WITH_CONTENT :
333
+ # load text/binary
334
+ is_continuation = opcode == WSMsgType .CONTINUATION
335
+ if not fin :
336
+ # got partial frame payload
337
+ if not is_continuation :
338
+ self ._opcode = opcode
339
+ self ._partial += payload
340
+ if self ._max_msg_size and len (self ._partial ) >= self ._max_msg_size :
341
+ raise WebSocketError (
342
+ WSCloseCode .MESSAGE_TOO_BIG ,
343
+ "Message size {} exceeds limit {}" .format (
344
+ len (self ._partial ), self ._max_msg_size
345
+ ),
346
+ )
347
+ continue
348
+
349
+ has_partial = bool (self ._partial )
350
+ if is_continuation :
351
+ if self ._opcode is None :
352
+ raise WebSocketError (
353
+ WSCloseCode .PROTOCOL_ERROR ,
354
+ "Continuation frame for non started message" ,
355
+ )
356
+ opcode = self ._opcode
357
+ self ._opcode = None
358
+ # previous frame was non finished
359
+ # we should get continuation opcode
360
+ elif has_partial :
361
+ raise WebSocketError (
362
+ WSCloseCode .PROTOCOL_ERROR ,
363
+ "The opcode in non-fin frame is expected "
364
+ "to be zero, got {!r}" .format (opcode ),
365
+ )
366
+
367
+ if has_partial :
368
+ assembled_payload = self ._partial + payload
369
+ self ._partial .clear ()
370
+ else :
371
+ assembled_payload = payload
372
+
373
+ if self ._max_msg_size and len (assembled_payload ) >= self ._max_msg_size :
374
+ raise WebSocketError (
375
+ WSCloseCode .MESSAGE_TOO_BIG ,
376
+ "Message size {} exceeds limit {}" .format (
377
+ len (assembled_payload ), self ._max_msg_size
378
+ ),
379
+ )
380
+
381
+ # Decompress process must to be done after all packets
382
+ # received.
383
+ if compressed :
384
+ if not self ._decompressobj :
385
+ self ._decompressobj = ZLibDecompressor (
386
+ suppress_deflate_header = True
387
+ )
388
+ payload_merged = self ._decompressobj .decompress_sync (
389
+ assembled_payload + _WS_DEFLATE_TRAILING , self ._max_msg_size
390
+ )
391
+ if self ._decompressobj .unconsumed_tail :
392
+ left = len (self ._decompressobj .unconsumed_tail )
393
+ raise WebSocketError (
394
+ WSCloseCode .MESSAGE_TOO_BIG ,
395
+ "Decompressed message size {} exceeds limit {}" .format (
396
+ self ._max_msg_size + left , self ._max_msg_size
397
+ ),
398
+ )
399
+ else :
400
+ payload_merged = bytes (assembled_payload )
401
+
402
+ if opcode == WSMsgType .TEXT :
403
+ try :
404
+ text = payload_merged .decode ("utf-8" )
405
+ except UnicodeDecodeError as exc :
406
+ raise WebSocketError (
407
+ WSCloseCode .INVALID_TEXT , "Invalid UTF-8 text message"
408
+ ) from exc
409
+
410
+ self .queue .feed_data (WSMessage (WSMsgType .TEXT , text , "" ), len (text ))
411
+ continue
412
+
413
+ self .queue .feed_data (
414
+ WSMessage (WSMsgType .BINARY , payload_merged , "" ), len (payload_merged )
415
+ )
416
+ elif opcode == WSMsgType .CLOSE :
327
417
if len (payload ) >= 2 :
328
418
close_code = UNPACK_CLOSE_CODE (payload [:2 ])[0 ]
329
419
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES :
@@ -358,90 +448,10 @@ def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
358
448
WSMessage (WSMsgType .PONG , payload , "" ), len (payload )
359
449
)
360
450
361
- elif (
362
- opcode not in (WSMsgType .TEXT , WSMsgType .BINARY )
363
- and self ._opcode is None
364
- ):
451
+ else :
365
452
raise WebSocketError (
366
453
WSCloseCode .PROTOCOL_ERROR , f"Unexpected opcode={ opcode !r} "
367
454
)
368
- else :
369
- # load text/binary
370
- if not fin :
371
- # got partial frame payload
372
- if opcode != WSMsgType .CONTINUATION :
373
- self ._opcode = opcode
374
- self ._partial .extend (payload )
375
- if self ._max_msg_size and len (self ._partial ) >= self ._max_msg_size :
376
- raise WebSocketError (
377
- WSCloseCode .MESSAGE_TOO_BIG ,
378
- "Message size {} exceeds limit {}" .format (
379
- len (self ._partial ), self ._max_msg_size
380
- ),
381
- )
382
- else :
383
- # previous frame was non finished
384
- # we should get continuation opcode
385
- if self ._partial :
386
- if opcode != WSMsgType .CONTINUATION :
387
- raise WebSocketError (
388
- WSCloseCode .PROTOCOL_ERROR ,
389
- "The opcode in non-fin frame is expected "
390
- "to be zero, got {!r}" .format (opcode ),
391
- )
392
-
393
- if opcode == WSMsgType .CONTINUATION :
394
- assert self ._opcode is not None
395
- opcode = self ._opcode
396
- self ._opcode = None
397
-
398
- self ._partial .extend (payload )
399
- if self ._max_msg_size and len (self ._partial ) >= self ._max_msg_size :
400
- raise WebSocketError (
401
- WSCloseCode .MESSAGE_TOO_BIG ,
402
- "Message size {} exceeds limit {}" .format (
403
- len (self ._partial ), self ._max_msg_size
404
- ),
405
- )
406
-
407
- # Decompress process must to be done after all packets
408
- # received.
409
- if compressed :
410
- assert self ._decompressobj is not None
411
- self ._partial .extend (_WS_DEFLATE_TRAILING )
412
- payload_merged = self ._decompressobj .decompress_sync (
413
- self ._partial , self ._max_msg_size
414
- )
415
- if self ._decompressobj .unconsumed_tail :
416
- left = len (self ._decompressobj .unconsumed_tail )
417
- raise WebSocketError (
418
- WSCloseCode .MESSAGE_TOO_BIG ,
419
- "Decompressed message size {} exceeds limit {}" .format (
420
- self ._max_msg_size + left , self ._max_msg_size
421
- ),
422
- )
423
- else :
424
- payload_merged = bytes (self ._partial )
425
-
426
- self ._partial .clear ()
427
-
428
- if opcode == WSMsgType .TEXT :
429
- try :
430
- text = payload_merged .decode ("utf-8" )
431
- self .queue .feed_data (
432
- WSMessage (WSMsgType .TEXT , text , "" ), len (text )
433
- )
434
- except UnicodeDecodeError as exc :
435
- raise WebSocketError (
436
- WSCloseCode .INVALID_TEXT , "Invalid UTF-8 text message"
437
- ) from exc
438
- else :
439
- self .queue .feed_data (
440
- WSMessage (WSMsgType .BINARY , payload_merged , "" ),
441
- len (payload_merged ),
442
- )
443
-
444
- return False , b""
445
455
446
456
def parse_frame (
447
457
self , buf : bytes
@@ -521,23 +531,21 @@ def parse_frame(
521
531
522
532
# read payload length
523
533
if self ._state is WSParserState .READ_PAYLOAD_LENGTH :
524
- length = self ._payload_length_flag
525
- if length == 126 :
534
+ length_flag = self ._payload_length_flag
535
+ if length_flag == 126 :
526
536
if buf_length - start_pos < 2 :
527
537
break
528
538
data = buf [start_pos : start_pos + 2 ]
529
539
start_pos += 2
530
- length = UNPACK_LEN2 (data )[0 ]
531
- self ._payload_length = length
532
- elif length > 126 :
540
+ self ._payload_length = UNPACK_LEN2 (data )[0 ]
541
+ elif length_flag > 126 :
533
542
if buf_length - start_pos < 8 :
534
543
break
535
544
data = buf [start_pos : start_pos + 8 ]
536
545
start_pos += 8
537
- length = UNPACK_LEN3 (data )[0 ]
538
- self ._payload_length = length
546
+ self ._payload_length = UNPACK_LEN3 (data )[0 ]
539
547
else :
540
- self ._payload_length = length
548
+ self ._payload_length = length_flag
541
549
542
550
self ._state = (
543
551
WSParserState .READ_PAYLOAD_MASK
@@ -560,11 +568,11 @@ def parse_frame(
560
568
chunk_len = buf_length - start_pos
561
569
if length >= chunk_len :
562
570
self ._payload_length = length - chunk_len
563
- payload . extend ( buf [start_pos :])
571
+ payload += buf [start_pos :]
564
572
start_pos = buf_length
565
573
else :
566
574
self ._payload_length = 0
567
- payload . extend ( buf [start_pos : start_pos + length ])
575
+ payload += buf [start_pos : start_pos + length ]
568
576
start_pos = start_pos + length
569
577
570
578
if self ._payload_length != 0 :
0 commit comments