Skip to content

Expand SQL formatter to LakeFormation #1684

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 50 additions & 3 deletions awswrangler/athena/_formatter.py → awswrangler/_sql_formatter.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Formatting logic for Athena parameters."""
"""Formatting logic for SQL parameters."""
import datetime
import decimal
import re
from enum import Enum
from typing import Any, Dict, Generic, Sequence, Type, TypeVar
from typing import Any, Dict, Generic, Optional, Sequence, Type, TypeVar


class _EngineType(Enum):
PRESTO = "presto"
HIVE = "hive"
PARTIQL = "partiql"

def __str__(self) -> str:
return self.value
Expand All @@ -29,14 +31,17 @@ def __str__(self) -> str:

class _NullType(_AbstractType[_NoneType]):
    """Render a Python ``None`` as the engine's null literal."""

    def __str__(self) -> str:
        # PartiQL expects a lowercase keyword; Presto/Hive use uppercase.
        return "null" if self.engine == _EngineType.PARTIQL else "NULL"


class _StringType(_AbstractType[str]):
supported_formats = {"s", "i"}

def __str__(self) -> str:
if self.engine == _EngineType.PRESTO:
if self.engine in [_EngineType.PRESTO, _EngineType.PARTIQL]:
return f"""'{self.data.replace("'", "''")}'"""

if self.engine == _EngineType.HIVE:
Expand All @@ -53,6 +58,9 @@ def __str__(self) -> str:

class _BooleanType(_AbstractType[bool]):
    """Render a Python ``bool`` as the engine's boolean literal."""

    def __str__(self) -> str:
        if self.engine == _EngineType.PARTIQL:
            # PartiQL: booleans are passed as 1/0.
            return str(int(self.data))
        # Presto/Hive: TRUE / FALSE keywords.
        return str(self.data).upper()


Expand All @@ -68,28 +76,44 @@ def __str__(self) -> str:

class _DecimalType(_AbstractType[decimal.Decimal]):
    """Render a ``decimal.Decimal`` as the engine's decimal literal."""

    def __str__(self) -> str:
        if self.engine == _EngineType.PARTIQL:
            # PartiQL: the decimal is passed as a quoted string.
            return "'{}'".format(self.data)
        # Presto/Hive: typed literal; ``:f`` avoids scientific notation.
        return "DECIMAL '{:f}'".format(self.data)


class _TimestampType(_AbstractType[datetime.datetime]):
    """Render a naive ``datetime.datetime`` as the engine's timestamp literal."""

    def __str__(self) -> str:
        """Format the timestamp for the target engine.

        Raises
        ------
        TypeError
            If the datetime is timezone-aware; only naive values are supported.
        """
        if self.data.tzinfo is not None:
            # Bug fix: the guard rejects timezone-AWARE values, but the old
            # message claimed "Supports only timezone aware datatype" — the
            # exact opposite of what the check enforces.
            raise TypeError(f"Supports only timezone naive datatype, got {self.data}.")

        if self.engine == _EngineType.PARTIQL:
            # PartiQL: quoted ISO-8601 string ('T' separator, full precision).
            return f"'{self.data.isoformat()}'"

        # Presto/Hive: typed literal, space-separated, millisecond precision.
        return f"TIMESTAMP '{self.data.isoformat(sep=' ', timespec='milliseconds')}'"


class _DateType(_AbstractType[datetime.date]):
    """Render a ``datetime.date`` as the engine's date literal."""

    def __str__(self) -> str:
        iso = self.data.isoformat()
        # PartiQL takes a bare quoted string; Presto/Hive use a typed literal.
        if self.engine == _EngineType.PARTIQL:
            return f"'{iso}'"
        return f"DATE '{iso}'"


class _ArrayType(_AbstractType[Sequence[_PythonType]]):
    # Renders a Python sequence as an ARRAY [...] literal.
    def __str__(self) -> str:
        if self.engine == _EngineType.PARTIQL:
            # NOTE(review): the result of super().__str__() is discarded and
            # control falls through to the ARRAY [...] syntax below. This only
            # works as an "unsupported engine" guard if the base implementation
            # raises (e.g. NotImplementedError) — confirm against
            # _AbstractType.__str__; otherwise a `raise` or `return` is
            # missing here.
            super().__str__()

        return f"ARRAY [{', '.join(map(str, self.data))}]"


class _MapType(_AbstractType[Dict[_PythonType, _PythonTypeMapValue]]):
def __str__(self) -> str:
if self.engine == _EngineType.PARTIQL:
super().__str__()

if not self.data:
return "MAP()"

Expand Down Expand Up @@ -165,3 +189,26 @@ def _format_parameters(params: Dict[str, Any], engine: _EngineType) -> Dict[str,
processed_params[k] = str(abs_type)

return processed_params


# Matches ``:name`` placeholders; the negative lookahead ensures the match
# ends at an identifier boundary rather than inside a longer token.
_PATTERN = re.compile(r":([A-Za-z0-9_]+)(?![A-Za-z0-9_])")


def _process_sql_params(sql: str, params: Optional[Dict[str, Any]], engine: _EngineType = _EngineType.PRESTO) -> str:
    """Substitute ``:name`` placeholders in ``sql`` with engine-formatted values.

    Placeholders that have no matching key in ``params`` are left untouched.
    """
    formatted = _format_parameters(params if params is not None else {}, engine=engine)

    def _substitute(match: re.Match) -> str:  # type: ignore
        name = match.group(1)
        if name in formatted:
            return str(formatted[name])
        # Unknown placeholder: pass the original text through unchanged.
        return str(match.group(0))

    return _PATTERN.sub(_substitute, sql)
30 changes: 3 additions & 27 deletions awswrangler/athena/_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import csv
import logging
import re
import sys
import uuid
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
Expand All @@ -14,7 +13,7 @@
from awswrangler import _utils, catalog, exceptions, s3
from awswrangler._config import apply_configs
from awswrangler._data_types import cast_pandas_with_athena_types
from awswrangler.athena._formatter import _EngineType, _format_parameters
from awswrangler._sql_formatter import _process_sql_params
from awswrangler.athena._utils import (
_apply_query_metadata,
_empty_dataframe_response,
Expand Down Expand Up @@ -568,29 +567,6 @@ def _unload(
return query_metadata


_PATTERN = re.compile(r":([A-Za-z0-9_]+)(?![A-Za-z0-9_])")


def _process_sql_params(sql: str, params: Optional[Dict[str, Any]]) -> str:
if params is None:
params = {}

processed_params = _format_parameters(params, engine=_EngineType.PRESTO)

def replace(match: re.Match) -> str: # type: ignore
key = match.group(1)

if key not in processed_params:
# do not replace anything if the parameter is not provided
return str(match.group(0))

return str(processed_params[key])

sql = _PATTERN.sub(replace, sql)

return sql


@apply_configs
def get_query_results(
query_execution_id: str,
Expand Down Expand Up @@ -922,7 +898,7 @@ def read_sql_query(
>>> import awswrangler as wr
>>> df = wr.athena.read_sql_query(
... sql="SELECT * FROM my_table WHERE name=:name AND city=:city",
... params={"name": "'filtered_name'", "city": "'filtered_city'"}
... params={"name": "filtered_name", "city": "filtered_city"}
... )

"""
Expand Down Expand Up @@ -1303,7 +1279,7 @@ def unload(
>>> import awswrangler as wr
>>> res = wr.athena.unload(
... sql="SELECT * FROM my_table WHERE name=:name AND city=:city",
... params={"name": "'filtered_name'", "city": "'filtered_city'"}
... params={"name": "filtered_name", "city": "filtered_city"}
... )

"""
Expand Down
2 changes: 1 addition & 1 deletion awswrangler/athena/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from awswrangler import _data_types, _utils, catalog, exceptions, s3, sts
from awswrangler._config import apply_configs
from awswrangler.athena._formatter import _EngineType, _format_parameters
from awswrangler._sql_formatter import _EngineType, _format_parameters
from awswrangler.catalog._utils import _catalog_id, _transaction_id

from ._cache import _cache_manager, _CacheInfo, _check_for_cached_results, _LocalMetadataCacheManager
Expand Down
9 changes: 4 additions & 5 deletions awswrangler/lakeformation/_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from awswrangler import _data_types, _utils, catalog
from awswrangler._config import apply_configs
from awswrangler._distributed import engine
from awswrangler._sql_formatter import _EngineType, _process_sql_params
from awswrangler._threading import _get_executor
from awswrangler.catalog._utils import _catalog_id, _transaction_id
from awswrangler.lakeformation._utils import commit_transaction, start_transaction, wait_query
Expand Down Expand Up @@ -157,17 +158,15 @@ def read_sql_query(
... sql="SELECT * FROM my_table WHERE name=:name AND city=:city",
... database="my_db",
... query_as_of_time="1611142914",
... params={"name": "'filtered_name'", "city": "'filtered_city'"}
... params={"name": "filtered_name", "city": "filtered_city"}
... )

"""
session: boto3.Session = _utils.ensure_session(session=boto3_session)
client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session)
commit_trans: bool = False
if params is None:
params = {}
for key, value in params.items():
sql = sql.replace(f":{key};", str(value))

sql = _process_sql_params(sql, params, engine=_EngineType.PARTIQL)

if not any([transaction_id, query_as_of_time]):
_logger.debug("Neither `transaction_id` nor `query_as_of_time` were specified, starting transaction")
Expand Down
2 changes: 1 addition & 1 deletion tests/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def get_df_list(governed=False):
df["category"] = df["category"].astype("category")

if governed:
df = (df.drop(["iint8", "binary"], axis=1),) # tinyint & binary currently not supported
df = df.drop(["iint8", "binary"], axis=1) # tinyint & binary currently not supported
return df


Expand Down
76 changes: 73 additions & 3 deletions tests/unit/test_lakeformation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import calendar
import datetime as dt
import logging
import time
from decimal import Decimal

import pytest

import awswrangler as wr
from awswrangler._distributed import EngineEnum, MemoryFormatEnum

from .._utils import ensure_data_types, ensure_data_types_csv, get_df, get_df_csv
from .._utils import ensure_data_types, ensure_data_types_csv, get_df, get_df_csv, get_df_list

if wr.engine.get() == EngineEnum.RAY and wr.memory_format.get() == MemoryFormatEnum.MODIN:
import modin.pandas as pd
Expand Down Expand Up @@ -50,9 +52,9 @@ def test_lakeformation(path, path2, glue_database, glue_table, glue_table2, use_

# Filter query
df2 = wr.lakeformation.read_sql_query(
sql=f"SELECT * FROM {glue_table} WHERE iint16 = :iint16;",
sql=f'SELECT * FROM {glue_table} WHERE "string" = :city_name',
database=glue_database,
params={"iint16": 1},
params={"city_name": "Washington"},
)
assert len(df2.index) == 1

Expand Down Expand Up @@ -145,3 +147,71 @@ def test_lakeformation_multi_transaction(path, path2, glue_database, glue_table,

assert df2.shape == df4.shape
assert df2.c1.sum() == df4.c1.sum()


@pytest.mark.parametrize(
    "col_name,col_value",
    [
        ("date", dt.date(2020, 1, 1)),
        ("timestamp", dt.datetime(2020, 1, 1)),
        ("bool", True),
        ("decimal", Decimal("1.99")),
        ("float", 0.0),
        ("iint16", 1),
    ],
)
def test_lakeformation_partiql_formatting(path, path2, glue_database, glue_table, glue_table2, col_name, col_value):
    """Each supported Python type round-trips as a typed PartiQL parameter."""
    # Start from a clean slate so the governed table is recreated below.
    for table in (glue_table, glue_table2):
        wr.catalog.delete_table_if_exists(database=glue_database, table=table)

    wr.s3.to_parquet(
        df=get_df_list(governed=True),
        path=path,
        index=False,
        boto3_session=None,
        s3_additional_kwargs=None,
        dataset=True,
        partition_cols=["par0", "par1"],
        mode="overwrite",
        table=glue_table,
        table_type="GOVERNED",
        database=glue_database,
    )

    # Exactly one row should match the typed predicate.
    result = wr.lakeformation.read_sql_query(
        sql=f'SELECT * FROM {glue_table} WHERE "{col_name}" = :col_value',
        database=glue_database,
        params={"col_value": col_value},
    )
    assert len(result) == 1


def test_lakeformation_partiql_formatting_escape_string(path, path2, glue_database, glue_table, glue_table2):
    """Single quotes inside a string parameter are escaped, not treated as delimiters."""
    frame = pd.DataFrame(
        {
            "id": [1, 2, 3],
            "string": ["normal string", "'weird' string", "another normal string"],
        }
    )

    wr.s3.to_parquet(
        df=frame,
        path=path,
        index=False,
        boto3_session=None,
        s3_additional_kwargs=None,
        dataset=True,
        mode="overwrite",
        table=glue_table,
        table_type="GOVERNED",
        database=glue_database,
    )

    # The embedded quotes must be escaped so exactly one row matches.
    result = wr.lakeformation.read_sql_query(
        sql=f'SELECT * FROM {glue_table} WHERE "string" = :col_value',
        database=glue_database,
        params={"col_value": "'weird' string"},
    )
    assert len(result) == 1
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from awswrangler.athena._formatter import _EngineType, _format_parameters
from awswrangler._sql_formatter import _EngineType, _format_parameters


@pytest.mark.parametrize("engine", [_EngineType.HIVE, _EngineType.PRESTO])
Expand Down