Skip to content

Commit 3e0d7ec

Browse files
Add formatter for PartiQL
1 parent 23cdef4 commit 3e0d7ec

File tree

4 files changed

+98
-5
lines changed

4 files changed

+98
-5
lines changed

awswrangler/_sql_formatter.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
class _EngineType(Enum):
    """SQL engine dialects understood by the formatter."""

    PRESTO = "presto"
    HIVE = "hive"
    PARTIQL = "partiql"

    def __str__(self) -> str:
        """Render the engine as its lowercase dialect name."""
        return self.value
@@ -30,14 +31,17 @@ def __str__(self) -> str:
3031

3132
class _NullType(_AbstractType[_NoneType]):
    """Formatter for SQL NULL literals."""

    def __str__(self) -> str:
        """Render NULL — lowercase for PartiQL, uppercase for every other engine."""
        return "null" if self.engine == _EngineType.PARTIQL else "NULL"
3438

3539

3640
class _StringType(_AbstractType[str]):
3741
supported_formats = {"s", "i"}
3842

3943
def __str__(self) -> str:
40-
if self.engine == _EngineType.PRESTO:
44+
if self.engine in [_EngineType.PRESTO, _EngineType.PARTIQL]:
4145
return f"""'{self.data.replace("'", "''")}'"""
4246

4347
if self.engine == _EngineType.HIVE:
@@ -54,6 +58,9 @@ def __str__(self) -> str:
5458

5559
class _BooleanType(_AbstractType[bool]):
    """Formatter for boolean literals."""

    def __str__(self) -> str:
        """Render the boolean — TRUE/FALSE for most engines, "1"/"0" for PartiQL."""
        if self.engine != _EngineType.PARTIQL:
            return str(self.data).upper()
        # PartiQL has no boolean literal keyword here; use integer flags instead.
        return "1" if self.data else "0"
5865

5966

@@ -69,28 +76,44 @@ def __str__(self) -> str:
6976

7077
class _DecimalType(_AbstractType[decimal.Decimal]):
    """Formatter for decimal literals."""

    def __str__(self) -> str:
        """Render the decimal — a DECIMAL-prefixed literal, or a bare quoted string for PartiQL."""
        if self.engine != _EngineType.PARTIQL:
            return f"DECIMAL '{self.data:f}'"
        return f"'{self.data}'"
7383

7484

7585
class _TimestampType(_AbstractType[datetime.datetime]):
    """Formatter for timestamp literals built from ``datetime.datetime`` values."""

    def __str__(self) -> str:
        """Render the timestamp literal for the configured engine.

        Raises
        ------
        TypeError
            If the datetime is timezone-aware; only naive datetimes are supported.
        """
        # BUG FIX: the guard rejects tz-aware values, but the old message claimed
        # "Supports only timezone aware datatype" — the opposite of what is enforced
        # (the test suite passes naive datetimes through successfully).
        if self.data.tzinfo is not None:
            raise TypeError(f"Supports only timezone naive datetimes, got {self.data}.")

        if self.engine == _EngineType.PARTIQL:
            return f"'{self.data.isoformat()}'"

        return f"TIMESTAMP '{self.data.isoformat(sep=' ', timespec='milliseconds')}'"
8094

8195

8296
class _DateType(_AbstractType[datetime.date]):
    """Formatter for date literals."""

    def __str__(self) -> str:
        """Render the ISO date — PartiQL gets a bare quoted string, others a DATE prefix."""
        iso = self.data.isoformat()
        prefix = "" if self.engine == _EngineType.PARTIQL else "DATE "
        return f"{prefix}'{iso}'"
85102

86103

87104
class _ArrayType(_AbstractType[Sequence[_PythonType]]):
    """Formatter for array literals."""

    def __str__(self) -> str:
        """Render an ARRAY literal; PartiQL delegates to the base implementation.

        The base ``__str__`` presumably rejects unsupported types by raising
        (TODO: confirm against ``_AbstractType``), so the PartiQL branch must
        propagate its result or exception rather than discarding it.
        """
        if self.engine == _EngineType.PARTIQL:
            # BUG FIX: the return value was previously discarded, so if the base
            # class ever returned instead of raising, execution silently fell
            # through to the Presto/Hive ARRAY syntax.
            return super().__str__()
        return f"ARRAY [{', '.join(map(str, self.data))}]"
90110

91111

92112
class _MapType(_AbstractType[Dict[_PythonType, _PythonTypeMapValue]]):
93113
def __str__(self) -> str:
114+
if self.engine == _EngineType.PARTIQL:
115+
super().__str__()
116+
94117
if not self.data:
95118
return "MAP()"
96119

awswrangler/lakeformation/_read.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from awswrangler import _data_types, _utils, catalog
1111
from awswrangler._config import apply_configs
1212
from awswrangler._distributed import engine
13-
from awswrangler._sql_formatter import _process_sql_params
13+
from awswrangler._sql_formatter import _EngineType, _process_sql_params
1414
from awswrangler._threading import _get_executor
1515
from awswrangler.catalog._utils import _catalog_id, _transaction_id
1616
from awswrangler.distributed.ray import RayLogger
@@ -168,7 +168,7 @@ def read_sql_query(
168168
client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session)
169169
commit_trans: bool = False
170170

171-
sql = _process_sql_params(sql, params)
171+
sql = _process_sql_params(sql, params, engine=_EngineType.PARTIQL)
172172

173173
if not any([transaction_id, query_as_of_time]):
174174
_logger.debug("Neither `transaction_id` nor `query_as_of_time` were specified, starting transaction")

tests/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def get_df_list(governed=False):
105105
df["category"] = df["category"].astype("category")
106106

107107
if governed:
108-
df = (df.drop(["iint8", "binary"], axis=1),) # tinyint & binary currently not supported
108+
df = df.drop(["iint8", "binary"], axis=1) # tinyint & binary currently not supported
109109
return df
110110

111111

tests/unit/test_lakeformation.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import calendar
2+
import datetime as dt
23
import logging
34
import time
5+
from decimal import Decimal
46

57
import pytest
68

79
import awswrangler as wr
810
from awswrangler._distributed import EngineEnum, MemoryFormatEnum
911

10-
from .._utils import ensure_data_types, ensure_data_types_csv, get_df, get_df_csv
12+
from .._utils import ensure_data_types, ensure_data_types_csv, get_df, get_df_csv, get_df_list
1113

1214
if wr.engine.get() == EngineEnum.RAY and wr.memory_format.get() == MemoryFormatEnum.MODIN:
1315
import modin.pandas as pd
@@ -145,3 +147,71 @@ def test_lakeformation_multi_transaction(path, path2, glue_database, glue_table,
145147

146148
assert df2.shape == df4.shape
147149
assert df2.c1.sum() == df4.c1.sum()
150+
151+
152+
@pytest.mark.parametrize(
    "col_name,col_value",
    [
        ("date", dt.date(2020, 1, 1)),
        ("timestamp", dt.datetime(2020, 1, 1)),
        ("bool", True),
        # NOTE: redundant extra parentheses removed — Decimal(("1.99")) was just
        # Decimal("1.99") with a confusing wrapper that read like a tuple.
        ("decimal", Decimal("1.99")),
        ("float", 0.0),
        ("iint16", 1),
    ],
)
def test_lakeformation_partiql_formatting(path, path2, glue_database, glue_table, glue_table2, col_name, col_value):
    """Each parametrized Python literal must format through PartiQL and match exactly one row."""
    wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table)
    wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table2)

    wr.s3.to_parquet(
        df=get_df_list(governed=True),
        path=path,
        index=False,
        boto3_session=None,
        s3_additional_kwargs=None,
        dataset=True,
        partition_cols=["par0", "par1"],
        mode="overwrite",
        table=glue_table,
        table_type="GOVERNED",
        database=glue_database,
    )

    # Filter query: the parameter is bound via the PartiQL formatter under test.
    df = wr.lakeformation.read_sql_query(
        sql=f'SELECT * FROM {glue_table} WHERE "{col_name}" = :col_value',
        database=glue_database,
        params={"col_value": col_value},
    )
    assert len(df) == 1
188+
189+
190+
def test_lakeformation_partiql_formatting_escape_string(path, path2, glue_database, glue_table, glue_table2):
    """Embedded single quotes in a string parameter must be escaped, not break the query."""
    input_df = pd.DataFrame(
        {
            "id": [1, 2, 3],
            "string": ["normal string", "'weird' string", "another normal string"],
        }
    )

    wr.s3.to_parquet(
        df=input_df,
        path=path,
        index=False,
        boto3_session=None,
        s3_additional_kwargs=None,
        dataset=True,
        mode="overwrite",
        table=glue_table,
        table_type="GOVERNED",
        database=glue_database,
    )

    # Filter on the value containing embedded single quotes; exactly one row matches.
    result_df = wr.lakeformation.read_sql_query(
        sql=f'SELECT * FROM {glue_table} WHERE "string" = :col_value',
        database=glue_database,
        params={"col_value": "'weird' string"},
    )
    assert len(result_df) == 1

0 commit comments

Comments
 (0)