24
24
Type ,
25
25
TypeVar ,
26
26
Union ,
27
- cast ,
28
27
)
29
28
from urllib .parse import ParseResult , parse_qs , unquote , urlparse
30
29
@@ -110,32 +109,32 @@ def __init__(self, encoding: str, encoding_errors: str, decode_responses: bool):
110
109
111
110
def encode (self , value : EncodableT ) -> EncodedT :
112
111
"""Return a bytestring or bytes-like representation of the value"""
112
+ if isinstance (value , str ):
113
+ return value .encode (self .encoding , self .encoding_errors )
113
114
if isinstance (value , (bytes , memoryview )):
114
115
return value
115
- if isinstance (value , bool ):
116
- # special case bool since it is a subclass of int
117
- raise DataError (
118
- "Invalid input of type: 'bool'. "
119
- "Convert to a bytes, string, int or float first."
120
- )
121
116
if isinstance (value , (int , float )):
117
+ if isinstance (value , bool ):
118
+ # special case bool since it is a subclass of int
119
+ raise DataError (
120
+ "Invalid input of type: 'bool'. "
121
+ "Convert to a bytes, string, int or float first."
122
+ )
122
123
return repr (value ).encode ()
123
- if not isinstance (value , str ):
124
- # a value we don't know how to deal with. throw an error
125
- typename = value .__class__ .__name__ # type: ignore[unreachable]
126
- raise DataError (
127
- f"Invalid input of type: { typename !r} . "
128
- "Convert to a bytes, string, int or float first."
129
- )
130
- return value .encode (self .encoding , self .encoding_errors )
124
+ # a value we don't know how to deal with. throw an error
125
+ typename = value .__class__ .__name__
126
+ raise DataError (
127
+ f"Invalid input of type: { typename !r} . "
128
+ "Convert to a bytes, string, int or float first."
129
+ )
131
130
132
131
def decode (self , value : EncodableT , force = False ) -> EncodableT :
133
132
"""Return a unicode string from the bytes-like representation"""
134
133
if self .decode_responses or force :
135
- if isinstance (value , memoryview ):
136
- return value .tobytes ().decode (self .encoding , self .encoding_errors )
137
134
if isinstance (value , bytes ):
138
135
return value .decode (self .encoding , self .encoding_errors )
136
+ if isinstance (value , memoryview ):
137
+ return value .tobytes ().decode (self .encoding , self .encoding_errors )
139
138
return value
140
139
141
140
@@ -336,7 +335,7 @@ def purge(self):
336
335
def close (self ):
337
336
try :
338
337
self .purge ()
339
- self ._buffer .close () # type: ignore[union-attr]
338
+ self ._buffer .close ()
340
339
except Exception :
341
340
# issue #633 suggests the purge/close somehow raised a
342
341
# BadFileDescriptor error. Perhaps the client ran out of
@@ -466,7 +465,7 @@ def on_disconnect(self):
466
465
self ._next_response = False
467
466
468
467
async def can_read (self , timeout : float ):
469
- if not self ._reader :
468
+ if not self ._stream or not self . _reader :
470
469
raise ConnectionError (SERVER_CLOSED_CONNECTION_ERROR )
471
470
472
471
if self ._next_response is False :
@@ -480,14 +479,14 @@ async def read_from_socket(
480
479
timeout : Union [float , None , _Sentinel ] = SENTINEL ,
481
480
raise_on_timeout : bool = True ,
482
481
):
483
- if self ._stream is None or self ._reader is None :
484
- raise RedisError ("Parser already closed." )
485
-
486
482
timeout = self ._socket_timeout if timeout is SENTINEL else timeout
487
483
try :
488
- async with async_timeout . timeout ( timeout ) :
484
+ if timeout is None :
489
485
buffer = await self ._stream .read (self ._read_size )
490
- if not isinstance (buffer , bytes ) or len (buffer ) == 0 :
486
+ else :
487
+ async with async_timeout .timeout (timeout ):
488
+ buffer = await self ._stream .read (self ._read_size )
489
+ if not isinstance (buffer , bytes ) or not buffer :
491
490
raise ConnectionError (SERVER_CLOSED_CONNECTION_ERROR ) from None
492
491
self ._reader .feed (buffer )
493
492
# data was read from the socket and added to the buffer.
@@ -516,9 +515,6 @@ async def read_response(
516
515
self .on_disconnect ()
517
516
raise ConnectionError (SERVER_CLOSED_CONNECTION_ERROR ) from None
518
517
519
- response : Union [
520
- EncodableT , ConnectionError , List [Union [EncodableT , ConnectionError ]]
521
- ]
522
518
# _next_response might be cached from a can_read() call
523
519
if self ._next_response is not False :
524
520
response = self ._next_response
@@ -541,8 +537,7 @@ async def read_response(
541
537
and isinstance (response [0 ], ConnectionError )
542
538
):
543
539
raise response [0 ]
544
- # cast as there won't be a ConnectionError here.
545
- return cast (Union [EncodableT , List [EncodableT ]], response )
540
+ return response
546
541
547
542
548
543
DefaultParser : Type [Union [PythonParser , HiredisParser ]]
@@ -637,7 +632,7 @@ def __init__(
637
632
self .socket_type = socket_type
638
633
self .retry_on_timeout = retry_on_timeout
639
634
if retry_on_timeout :
640
- if retry is None :
635
+ if not retry :
641
636
self .retry = Retry (NoBackoff (), 1 )
642
637
else :
643
638
# deep-copy the Retry object as it is mutable
@@ -681,7 +676,7 @@ def __del__(self):
681
676
682
677
@property
683
678
def is_connected (self ):
684
- return bool ( self ._reader and self ._writer )
679
+ return self ._reader and self ._writer
685
680
686
681
def register_connect_callback (self , callback ):
687
682
self ._connect_callbacks .append (weakref .WeakMethod (callback ))
@@ -713,7 +708,7 @@ async def connect(self):
713
708
raise ConnectionError (exc ) from exc
714
709
715
710
try :
716
- if self .redis_connect_func is None :
711
+ if not self .redis_connect_func :
717
712
# Use the default on_connect function
718
713
await self .on_connect ()
719
714
else :
@@ -745,7 +740,7 @@ async def _connect(self):
745
740
self ._reader = reader
746
741
self ._writer = writer
747
742
sock = writer .transport .get_extra_info ("socket" )
748
- if sock is not None :
743
+ if sock :
749
744
sock .setsockopt (socket .IPPROTO_TCP , socket .TCP_NODELAY , 1 )
750
745
try :
751
746
# TCP_KEEPALIVE
@@ -856,32 +851,29 @@ async def check_health(self):
856
851
await self .retry .call_with_retry (self ._send_ping , self ._ping_failed )
857
852
858
853
async def _send_packed_command (self , command : Iterable [bytes ]) -> None :
859
- if self ._writer is None :
860
- raise RedisError ("Connection already closed." )
861
-
862
854
self ._writer .writelines (command )
863
855
await self ._writer .drain ()
864
856
865
857
async def send_packed_command (
866
- self ,
867
- command : Union [bytes , str , Iterable [bytes ]],
868
- check_health : bool = True ,
869
- ):
870
- """Send an already packed command to the Redis server"""
871
- if not self ._writer :
858
+ self , command : Union [bytes , str , Iterable [bytes ]], check_health : bool = True
859
+ ) -> None :
860
+ if not self .is_connected :
872
861
await self .connect ()
873
- # guard against health check recursion
874
- if check_health :
862
+ elif check_health :
875
863
await self .check_health ()
864
+
876
865
try :
877
866
if isinstance (command , str ):
878
867
command = command .encode ()
879
868
if isinstance (command , bytes ):
880
869
command = [command ]
881
- await asyncio .wait_for (
882
- self ._send_packed_command (command ),
883
- self .socket_timeout ,
884
- )
870
+ if self .socket_timeout :
871
+ await asyncio .wait_for (
872
+ self ._send_packed_command (command ), self .socket_timeout
873
+ )
874
+ else :
875
+ self ._writer .writelines (command )
876
+ await self ._writer .drain ()
885
877
except asyncio .TimeoutError :
886
878
await self .disconnect ()
887
879
raise TimeoutError ("Timeout writing to socket" ) from None
@@ -901,8 +893,6 @@ async def send_packed_command(
901
893
902
894
async def send_command (self , * args , ** kwargs ):
903
895
"""Pack and send a command to the Redis server"""
904
- if not self .is_connected :
905
- await self .connect ()
906
896
await self .send_packed_command (
907
897
self .pack_command (* args ), check_health = kwargs .get ("check_health" , True )
908
898
)
@@ -917,7 +907,12 @@ async def read_response(self, disable_decoding: bool = False):
917
907
"""Read the response from a previously sent command"""
918
908
try :
919
909
async with self ._lock :
920
- async with async_timeout .timeout (self .socket_timeout ):
910
+ if self .socket_timeout :
911
+ async with async_timeout .timeout (self .socket_timeout ):
912
+ response = await self ._parser .read_response (
913
+ disable_decoding = disable_decoding
914
+ )
915
+ else :
921
916
response = await self ._parser .read_response (
922
917
disable_decoding = disable_decoding
923
918
)
@@ -1176,10 +1171,7 @@ def __init__(
1176
1171
self ._lock = asyncio .Lock ()
1177
1172
1178
1173
def repr_pieces (self ) -> Iterable [Tuple [str , Union [str , int ]]]:
1179
- pieces = [
1180
- ("path" , self .path ),
1181
- ("db" , self .db ),
1182
- ]
1174
+ pieces = [("path" , self .path ), ("db" , self .db )]
1183
1175
if self .client_name :
1184
1176
pieces .append (("client_name" , self .client_name ))
1185
1177
return pieces
@@ -1248,12 +1240,11 @@ def parse_url(url: str) -> ConnectKwargs:
1248
1240
parser = URL_QUERY_ARGUMENT_PARSERS .get (name )
1249
1241
if parser :
1250
1242
try :
1251
- # We can't type this.
1252
- kwargs [name ] = parser (value ) # type: ignore[misc]
1243
+ kwargs [name ] = parser (value )
1253
1244
except (TypeError , ValueError ):
1254
1245
raise ValueError (f"Invalid value for `{ name } ` in connection URL." )
1255
1246
else :
1256
- kwargs [name ] = value # type: ignore[misc]
1247
+ kwargs [name ] = value
1257
1248
1258
1249
if parsed .username :
1259
1250
kwargs ["username" ] = unquote (parsed .username )
@@ -1358,7 +1349,7 @@ def __init__(
1358
1349
max_connections : Optional [int ] = None ,
1359
1350
** connection_kwargs ,
1360
1351
):
1361
- max_connections = max_connections or 2 ** 31
1352
+ max_connections = max_connections or 2 ** 31
1362
1353
if not isinstance (max_connections , int ) or max_connections < 0 :
1363
1354
raise ValueError ('"max_connections" must be a positive integer' )
1364
1355
0 commit comments