Skip to content

Commit daf56d5

Browse files
async_conn: optimize
1 parent cf1b3e3 commit daf56d5

File tree

3 files changed

+52
-61
lines changed

3 files changed

+52
-61
lines changed

redis/asyncio/connection.py

Lines changed: 50 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
Type,
2525
TypeVar,
2626
Union,
27-
cast,
2827
)
2928
from urllib.parse import ParseResult, parse_qs, unquote, urlparse
3029

@@ -110,32 +109,32 @@ def __init__(self, encoding: str, encoding_errors: str, decode_responses: bool):
110109

111110
def encode(self, value: EncodableT) -> EncodedT:
112111
"""Return a bytestring or bytes-like representation of the value"""
112+
if isinstance(value, str):
113+
return value.encode(self.encoding, self.encoding_errors)
113114
if isinstance(value, (bytes, memoryview)):
114115
return value
115-
if isinstance(value, bool):
116-
# special case bool since it is a subclass of int
117-
raise DataError(
118-
"Invalid input of type: 'bool'. "
119-
"Convert to a bytes, string, int or float first."
120-
)
121116
if isinstance(value, (int, float)):
117+
if isinstance(value, bool):
118+
# special case bool since it is a subclass of int
119+
raise DataError(
120+
"Invalid input of type: 'bool'. "
121+
"Convert to a bytes, string, int or float first."
122+
)
122123
return repr(value).encode()
123-
if not isinstance(value, str):
124-
# a value we don't know how to deal with. throw an error
125-
typename = value.__class__.__name__ # type: ignore[unreachable]
126-
raise DataError(
127-
f"Invalid input of type: {typename!r}. "
128-
"Convert to a bytes, string, int or float first."
129-
)
130-
return value.encode(self.encoding, self.encoding_errors)
124+
# a value we don't know how to deal with. throw an error
125+
typename = value.__class__.__name__
126+
raise DataError(
127+
f"Invalid input of type: {typename!r}. "
128+
"Convert to a bytes, string, int or float first."
129+
)
131130

132131
def decode(self, value: EncodableT, force=False) -> EncodableT:
133132
"""Return a unicode string from the bytes-like representation"""
134133
if self.decode_responses or force:
135-
if isinstance(value, memoryview):
136-
return value.tobytes().decode(self.encoding, self.encoding_errors)
137134
if isinstance(value, bytes):
138135
return value.decode(self.encoding, self.encoding_errors)
136+
if isinstance(value, memoryview):
137+
return value.tobytes().decode(self.encoding, self.encoding_errors)
139138
return value
140139

141140

@@ -336,7 +335,7 @@ def purge(self):
336335
def close(self):
337336
try:
338337
self.purge()
339-
self._buffer.close() # type: ignore[union-attr]
338+
self._buffer.close()
340339
except Exception:
341340
# issue #633 suggests the purge/close somehow raised a
342341
# BadFileDescriptor error. Perhaps the client ran out of
@@ -466,7 +465,7 @@ def on_disconnect(self):
466465
self._next_response = False
467466

468467
async def can_read(self, timeout: float):
469-
if not self._reader:
468+
if not self._stream or not self._reader:
470469
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
471470

472471
if self._next_response is False:
@@ -480,14 +479,14 @@ async def read_from_socket(
480479
timeout: Union[float, None, _Sentinel] = SENTINEL,
481480
raise_on_timeout: bool = True,
482481
):
483-
if self._stream is None or self._reader is None:
484-
raise RedisError("Parser already closed.")
485-
486482
timeout = self._socket_timeout if timeout is SENTINEL else timeout
487483
try:
488-
async with async_timeout.timeout(timeout):
484+
if timeout is None:
489485
buffer = await self._stream.read(self._read_size)
490-
if not isinstance(buffer, bytes) or len(buffer) == 0:
486+
else:
487+
async with async_timeout.timeout(timeout):
488+
buffer = await self._stream.read(self._read_size)
489+
if not isinstance(buffer, bytes) or not buffer:
491490
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
492491
self._reader.feed(buffer)
493492
# data was read from the socket and added to the buffer.
@@ -516,9 +515,6 @@ async def read_response(
516515
self.on_disconnect()
517516
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
518517

519-
response: Union[
520-
EncodableT, ConnectionError, List[Union[EncodableT, ConnectionError]]
521-
]
522518
# _next_response might be cached from a can_read() call
523519
if self._next_response is not False:
524520
response = self._next_response
@@ -541,8 +537,7 @@ async def read_response(
541537
and isinstance(response[0], ConnectionError)
542538
):
543539
raise response[0]
544-
# cast as there won't be a ConnectionError here.
545-
return cast(Union[EncodableT, List[EncodableT]], response)
540+
return response
546541

547542

548543
DefaultParser: Type[Union[PythonParser, HiredisParser]]
@@ -637,7 +632,7 @@ def __init__(
637632
self.socket_type = socket_type
638633
self.retry_on_timeout = retry_on_timeout
639634
if retry_on_timeout:
640-
if retry is None:
635+
if not retry:
641636
self.retry = Retry(NoBackoff(), 1)
642637
else:
643638
# deep-copy the Retry object as it is mutable
@@ -681,7 +676,7 @@ def __del__(self):
681676

682677
@property
683678
def is_connected(self):
684-
return bool(self._reader and self._writer)
679+
return self._reader and self._writer
685680

686681
def register_connect_callback(self, callback):
687682
self._connect_callbacks.append(weakref.WeakMethod(callback))
@@ -713,7 +708,7 @@ async def connect(self):
713708
raise ConnectionError(exc) from exc
714709

715710
try:
716-
if self.redis_connect_func is None:
711+
if not self.redis_connect_func:
717712
# Use the default on_connect function
718713
await self.on_connect()
719714
else:
@@ -745,7 +740,7 @@ async def _connect(self):
745740
self._reader = reader
746741
self._writer = writer
747742
sock = writer.transport.get_extra_info("socket")
748-
if sock is not None:
743+
if sock:
749744
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
750745
try:
751746
# TCP_KEEPALIVE
@@ -856,32 +851,29 @@ async def check_health(self):
856851
await self.retry.call_with_retry(self._send_ping, self._ping_failed)
857852

858853
async def _send_packed_command(self, command: Iterable[bytes]) -> None:
859-
if self._writer is None:
860-
raise RedisError("Connection already closed.")
861-
862854
self._writer.writelines(command)
863855
await self._writer.drain()
864856

865857
async def send_packed_command(
866-
self,
867-
command: Union[bytes, str, Iterable[bytes]],
868-
check_health: bool = True,
869-
):
870-
"""Send an already packed command to the Redis server"""
871-
if not self._writer:
858+
self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
859+
) -> None:
860+
if not self.is_connected:
872861
await self.connect()
873-
# guard against health check recursion
874-
if check_health:
862+
elif check_health:
875863
await self.check_health()
864+
876865
try:
877866
if isinstance(command, str):
878867
command = command.encode()
879868
if isinstance(command, bytes):
880869
command = [command]
881-
await asyncio.wait_for(
882-
self._send_packed_command(command),
883-
self.socket_timeout,
884-
)
870+
if self.socket_timeout:
871+
await asyncio.wait_for(
872+
self._send_packed_command(command), self.socket_timeout
873+
)
874+
else:
875+
self._writer.writelines(command)
876+
await self._writer.drain()
885877
except asyncio.TimeoutError:
886878
await self.disconnect()
887879
raise TimeoutError("Timeout writing to socket") from None
@@ -901,8 +893,6 @@ async def send_packed_command(
901893

902894
async def send_command(self, *args, **kwargs):
903895
"""Pack and send a command to the Redis server"""
904-
if not self.is_connected:
905-
await self.connect()
906896
await self.send_packed_command(
907897
self.pack_command(*args), check_health=kwargs.get("check_health", True)
908898
)
@@ -917,7 +907,12 @@ async def read_response(self, disable_decoding: bool = False):
917907
"""Read the response from a previously sent command"""
918908
try:
919909
async with self._lock:
920-
async with async_timeout.timeout(self.socket_timeout):
910+
if self.socket_timeout:
911+
async with async_timeout.timeout(self.socket_timeout):
912+
response = await self._parser.read_response(
913+
disable_decoding=disable_decoding
914+
)
915+
else:
921916
response = await self._parser.read_response(
922917
disable_decoding=disable_decoding
923918
)
@@ -1176,10 +1171,7 @@ def __init__(
11761171
self._lock = asyncio.Lock()
11771172

11781173
def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
1179-
pieces = [
1180-
("path", self.path),
1181-
("db", self.db),
1182-
]
1174+
pieces = [("path", self.path), ("db", self.db)]
11831175
if self.client_name:
11841176
pieces.append(("client_name", self.client_name))
11851177
return pieces
@@ -1248,12 +1240,11 @@ def parse_url(url: str) -> ConnectKwargs:
12481240
parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
12491241
if parser:
12501242
try:
1251-
# We can't type this.
1252-
kwargs[name] = parser(value) # type: ignore[misc]
1243+
kwargs[name] = parser(value)
12531244
except (TypeError, ValueError):
12541245
raise ValueError(f"Invalid value for `{name}` in connection URL.")
12551246
else:
1256-
kwargs[name] = value # type: ignore[misc]
1247+
kwargs[name] = value
12571248

12581249
if parsed.username:
12591250
kwargs["username"] = unquote(parsed.username)
@@ -1358,7 +1349,7 @@ def __init__(
13581349
max_connections: Optional[int] = None,
13591350
**connection_kwargs,
13601351
):
1361-
max_connections = max_connections or 2 ** 31
1352+
max_connections = max_connections or 2**31
13621353
if not isinstance(max_connections, int) or max_connections < 0:
13631354
raise ValueError('"max_connections" must be a positive integer')
13641355

tests/test_asyncio/test_connection_pool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,7 @@ async def r(self, create_redis):
664664

665665
def assert_interval_advanced(self, connection):
666666
diff = connection.next_health_check - asyncio.get_event_loop().time()
667-
assert self.interval > diff > (self.interval - 1)
667+
assert self.interval >= diff > (self.interval - 1)
668668

669669
async def test_health_check_runs(self, r):
670670
if r.connection:

tests/test_asyncio/test_lock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ async def test_blocking_timeout(self, r, event_loop):
114114
start = event_loop.time()
115115
assert not await lock2.acquire()
116116
# The elapsed duration should be less than the total blocking_timeout
117-
assert bt > (event_loop.time() - start) > bt - sleep
117+
assert bt >= (event_loop.time() - start) > bt - sleep
118118
await lock1.release()
119119

120120
async def test_context_manager(self, r):

0 commit comments

Comments
 (0)