Skip to content

Commit ea7df8a

Browse files
timm4205bsharifi
authored andcommitted
Fix: Expending the cache cleanup triggers by adding DROP & ROLLBACK into the list of triggering commands
1 parent bd3af3a commit ea7df8a

File tree

4 files changed

+388
-12
lines changed

4 files changed

+388
-12
lines changed

redshift_connector/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def get_name(cls, i: int) -> str:
2828

2929

3030
DEFAULT_PROTOCOL_VERSION: int = ClientProtocolVersion.BINARY.value
31+
DEFAULT_MAX_PREPARED_STATEMENTS: int = 1000
3132

3233

3334
class DbApiParamstyle(Enum):

redshift_connector/core.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from redshift_connector.config import (
2121
DEFAULT_PROTOCOL_VERSION,
22+
DEFAULT_MAX_PREPARED_STATEMENTS,
2223
ClientProtocolVersion,
2324
DbApiParamstyle,
2425
_client_encoding,
@@ -421,7 +422,7 @@ def __init__(
421422
ssl: bool = True,
422423
sslmode: str = "verify-ca",
423424
timeout: typing.Optional[int] = None,
424-
max_prepared_statements: int = 1000,
425+
max_prepared_statements: int = DEFAULT_MAX_PREPARED_STATEMENTS,
425426
tcp_keepalive: typing.Optional[bool] = True,
426427
application_name: typing.Optional[str] = None,
427428
replication: typing.Optional[str] = None,
@@ -500,7 +501,7 @@ def __init__(
500501
self.notifications: deque = deque(maxlen=100)
501502
self.notices: deque = deque(maxlen=100)
502503
self.parameter_statuses: deque = deque(maxlen=100)
503-
self.max_prepared_statements: int = int(max_prepared_statements)
504+
self.max_prepared_statements: int = int(self.get_max_prepared_statement(max_prepared_statements))
504505
self._run_cursor: Cursor = Cursor(self, paramstyle=DbApiParamstyle.NAMED.value)
505506
self._client_protocol_version: int = client_protocol_version
506507
self._database = database
@@ -1845,7 +1846,8 @@ def execute(self: "Connection", cursor: Cursor, operation: str, vals) -> None:
18451846
# consist of "redshift_connector", statement, process id and statement number.
18461847
# e.g redshift_connector_statement_11432_2
18471848
statement_name: str = "_".join(("redshift_connector", "statement", str(pid), str(statement_num)))
1848-
statement_name_bin: bytes = statement_name.encode("ascii") + NULL_BYTE
1849+
statement_name_bin: bytes = self.get_statement_name_bin(statement_name)
1850+
18491851
# row_desc: list that used to store metadata of rows from DB
18501852
# param_funcs: type transform function
18511853
ps = {
@@ -1942,12 +1944,12 @@ def execute(self: "Connection", cursor: Cursor, operation: str, vals) -> None:
19421944

19431945
ps["bind_2"] = h_pack(len(output_fc)) + pack("!" + "h" * len(output_fc), *output_fc)
19441946

1945-
if len(cache["ps"]) > self.max_prepared_statements:
1947+
if len(cache["ps"]) >= self.max_prepared_statements:
19461948
for p in cache["ps"].values():
19471949
self.close_prepared_statement(p["statement_name_bin"])
19481950
cache["ps"].clear()
1949-
1950-
cache["ps"][key] = ps
1951+
if self.max_prepared_statements > 0:
1952+
cache["ps"][key] = ps
19511953

19521954
cursor._cached_rows.clear()
19531955
cursor._row_count = -1
@@ -2118,7 +2120,7 @@ def handle_COMMAND_COMPLETE(self: "Connection", data: bytes, cursor: Cursor) ->
21182120
# cursor object
21192121
cursor._redshift_row_count = len(cursor._cached_rows)
21202122

2121-
if command in (b"ALTER", b"CREATE"):
2123+
if command in (b"ALTER", b"CREATE", b"DROP", b"ROLLBACK"):
21222124
for scache in self._caches.values():
21232125
for pcache in scache.values():
21242126
for ps in pcache["ps"].values():
@@ -2638,3 +2640,14 @@ def set_idc_plugins_params(
26382640

26392641
if idc_client_display_name:
26402642
init_params["idc_client_display_name"] = idc_client_display_name
2643+
2644+
def get_statement_name_bin(self, statement_name: str) -> bytes:
2645+
# When max_prepared_statements is 0, we use an empty statement name. This creates an unnamed
2646+
# prepared statement that lasts only until the next Parse statement, avoiding "statement already exists" errors
2647+
return ("" if self.max_prepared_statements == 0 else statement_name).encode("ascii") + NULL_BYTE
2648+
2649+
def get_max_prepared_statement(self, max_prepared_statements: int) -> int:
2650+
if max_prepared_statements < 0:
2651+
_logger.error("Parameter max_prepared_statements must >= 0. Using default value %d", DEFAULT_MAX_PREPARED_STATEMENTS)
2652+
return DEFAULT_MAX_PREPARED_STATEMENTS
2653+
return max_prepared_statements

test/integration/test_query.py

Lines changed: 130 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -370,19 +370,144 @@ def test_merge_read(con) -> None:
370370

371371

372372
def test_handle_COMMAND_COMPLETE_closed_ps(con, mocker) -> None:
373+
"""
374+
Test the handling of prepared statement cache cleanup for different SQL commands.
375+
This test verifies that DDL commands trigger cache cleanup while DML commands preserve the cache.
376+
377+
The test executes the following sequence:
378+
1. DROP TABLE IF EXISTS t1 (should clear cache)
379+
2. CREATE TABLE t1 (should clear cache)
380+
3. ALTER TABLE t1 (should clear cache)
381+
4. INSERT INTO t1 (should preserve cache)
382+
5. SELECT FROM t1 (should preserve cache)
383+
6. ROLLBACK (should clear cache)
384+
7. CREATE TABLE AS SELECT (should preserve cache)
385+
8. SELECT FROM t1 (should preserve cache)
386+
9. DROP TABLE IF EXISTS t1 (should clear cache)
387+
388+
Args:
389+
con: Database connection fixture
390+
mocker: pytest-mock fixture for creating spies
391+
"""
373392
with con.cursor() as cursor:
393+
# Create spy to track calls to close_prepared_statement
394+
spy = mocker.spy(con, "close_prepared_statement")
395+
374396
cursor.execute("drop table if exists t1")
397+
assert spy.called
398+
# Two calls expected: one for BEGIN transaction, one for DROP TABLE
399+
assert spy.call_count == 2
400+
spy.reset_mock()
375401

376-
spy = mocker.spy(con, "close_prepared_statement")
377402
cursor.execute("create table t1 (a int primary key)")
403+
assert spy.called
404+
# One call expected for CREATE TABLE
405+
assert spy.call_count == 1
406+
spy.reset_mock()
378407

379-
assert len(con._caches) == 1
380-
cache_iter = next(iter(con._caches.values())) # get first transaction
381-
assert len(next(iter(cache_iter.values()))["statement"]) == 3 # should be 3 ps in this transaction
382-
# begin transaction, drop table t1, create table t1
408+
cursor.execute("alter table t1 rename column a to b;")
383409
assert spy.called
410+
# One call expected for ALTER TABLE
411+
assert spy.call_count == 1
412+
spy.reset_mock()
413+
414+
cursor.execute("insert into t1 values(1)")
415+
assert spy.call_count == 0
416+
spy.reset_mock()
417+
418+
cursor.execute("select * from t1")
419+
assert spy.call_count == 0
420+
spy.reset_mock()
421+
422+
cursor.execute("rollback")
423+
assert spy.called
424+
# Three calls expected: INSERT, SELECT, and ROLLBACK statements
384425
assert spy.call_count == 3
426+
spy.reset_mock()
427+
428+
cursor.execute("create table t1 as (select 1)")
429+
assert spy.call_count == 0
430+
spy.reset_mock()
431+
432+
cursor.execute("select * from t1")
433+
assert spy.call_count == 0
434+
spy.reset_mock()
435+
436+
cursor.execute("drop table if exists t1")
437+
assert spy.called
438+
# Four calls expected: BEGIN, CREATE TABLE AS, SELECT, and DROP
439+
assert spy.call_count == 4
440+
spy.reset_mock()
441+
442+
# Ensure there's exactly one process in the cache
443+
assert len(con._caches) == 1
444+
# get cache for current process
445+
cache_iter = next(iter(con._caches.values()))
446+
447+
# Verify the number of prepared statements in this transaction
448+
# Should be 7 statements total from all operations
449+
assert len(next(iter(cache_iter.values()))["statement"]) == 8 # should be 8 ps in this process
450+
451+
@pytest.mark.parametrize("test_case", [
452+
{
453+
"name": "max_prepared_statements_zero",
454+
"max_prepared_statements": 0,
455+
"queries": ["SELECT 1", "SELECT 2"],
456+
"expected_close_calls": 0,
457+
"expected_cache_size": 0
458+
},
459+
{
460+
"name": "max_prepared_statements_default",
461+
"max_prepared_statements": 1000,
462+
"queries": ["SELECT 1", "SELECT 2"],
463+
"expected_close_calls": 0,
464+
"expected_cache_size": 3
465+
},
466+
{
467+
"name": "max_prepared_statements_limit_1",
468+
"max_prepared_statements": 2,
469+
"queries": ["SELECT 1", "SELECT 2", "SELECT 3"],
470+
"expected_close_calls": 2,
471+
"expected_cache_size": 2
472+
},
473+
{
474+
"name": "max_prepared_statements_limit_2",
475+
"max_prepared_statements": 2,
476+
"queries": ["SELECT 1", "SELECT 2"],
477+
"expected_close_calls": 2,
478+
"expected_cache_size": 1
479+
}
480+
])
481+
def test_max_prepared_statement(con, mocker, test_case) -> None:
482+
"""
483+
Test the prepared statement cache management functionality.
484+
This test verifies the behavior of the cache cleanup mechanism when:
485+
1. max_prepared_statements = 0: No statement will be cached
486+
2. max_prepared_statements > 0: Statements are cached up to the limit
487+
488+
:param con: Connection object
489+
:param mocker: pytest mocker fixture
490+
:param test_case: Dictionary containing test parameters:
491+
:return: None
492+
"""
493+
con.max_prepared_statements = test_case["max_prepared_statements"]
494+
with con.cursor() as cursor:
495+
# Create spy to track calls to close_prepared_statement
496+
spy = mocker.spy(con, "close_prepared_statement")
497+
498+
for query in test_case["queries"]:
499+
cursor.execute(query)
500+
501+
# Ensure there's exactly one process in the cache
502+
assert len(con._caches) == 1
503+
# Get cache for current process
504+
cache_iter = next(iter(con._caches.values()))
505+
506+
# Verify close_prepared_statement was called the expected number of times
507+
assert spy.call_count == test_case["expected_close_calls"]
385508

509+
# Verify the final cache size matches expected size
510+
assert len(next(iter(cache_iter.values()))["ps"]) == test_case["expected_cache_size"]
386511

387512
@pytest.mark.parametrize("_input", ["NO_SCHEMA_UNIVERSAL_QUERY", "EXTERNAL_SCHEMA_QUERY", "LOCAL_SCHEMA_QUERY"])
388513
def test___get_table_filter_clause_return_empty_result(con, _input) -> None:

0 commit comments

Comments
 (0)