Add get_query_results to Athena #1496

Closed
wants to merge 17 commits into from
3 changes: 2 additions & 1 deletion awswrangler/athena/__init__.py
@@ -1,6 +1,6 @@
 """Amazon Athena Module."""

-from awswrangler.athena._read import read_sql_query, read_sql_table, unload  # noqa
+from awswrangler.athena._read import get_query_results, read_sql_query, read_sql_table, unload  # noqa
 from awswrangler.athena._utils import (  # noqa
     create_athena_bucket,
     create_ctas_table,
@@ -23,6 +23,7 @@
     "describe_table",
     "get_query_columns_types",
     "get_query_execution",
+    "get_query_results",
     "get_named_query_statement",
     "get_work_group",
     "repair_table",
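With get_query_results re-exported from the package root, the function becomes part of the public wr.athena namespace. A minimal usage sketch, assuming an Athena query that has already finished (the execution ID below is a placeholder):

import awswrangler as wr

# Placeholder execution ID of a query that has already completed on Athena.
df = wr.athena.get_query_results(query_execution_id="00000000-0000-0000-0000-000000000000")
print(df.head())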
90 changes: 90 additions & 0 deletions awswrangler/athena/_read.py
@@ -559,6 +559,96 @@ def _unload(
    return query_metadata


@apply_configs
def get_query_results(
    query_execution_id: str,
    use_threads: Union[bool, int] = True,
    boto3_session: Optional[boto3.Session] = None,
    categories: Optional[List[str]] = None,
    chunksize: Optional[Union[int, bool]] = None,
    s3_additional_kwargs: Optional[Dict[str, Any]] = None,
    pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
    """Get AWS Athena SQL query results as a Pandas DataFrame.

    Parameters
    ----------
    query_execution_id : str
        SQL query's execution_id on AWS Athena.
    use_threads : bool, int
        True to enable concurrent requests, False to disable multiple threads.
        If enabled, os.cpu_count() will be used as the max number of threads.
        If an integer is provided, the specified number is used.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receives None.
    categories : List[str], optional
        List of column names that should be returned as pandas.Categorical.
        Recommended for memory restricted environments.
    chunksize : Union[int, bool], optional
        If passed, the data will be split into an Iterable of DataFrames (memory friendly).
        If `True`, wrangler will iterate on the data by files in the most efficient way without guarantee of chunksize.
        If an `INTEGER` is passed, wrangler will iterate on the data by number of rows equal to the received INTEGER.
    s3_additional_kwargs : Optional[Dict[str, Any]]
        Forwarded to botocore requests.
        e.g. s3_additional_kwargs={'RequestPayer': 'requester'}
    pyarrow_additional_kwargs : Optional[Dict[str, Any]]
        Forwarded to the ParquetFile class or when converting an Arrow table to Pandas. Currently only the
        "coerce_int96_timestamp_unit" and "timestamp_as_object" arguments will be considered. If reading parquet
        files where you cannot convert a timestamp to pandas Timestamp[ns], consider setting timestamp_as_object=True
        to allow for timestamp units larger than "ns". If reading parquet data that still uses INT96 (like Athena
        outputs), you can use coerce_int96_timestamp_unit to specify which timestamp unit to encode INT96 to (by default
        this is "ns"; if you know the output parquet came from a system that encodes timestamps to a particular unit,
        then set this to that same unit, e.g. coerce_int96_timestamp_unit="ms").

    Returns
    -------
    Union[pd.DataFrame, Iterator[pd.DataFrame]]
        Pandas DataFrame or Generator of Pandas DataFrames if chunksize is passed.

    Examples
    --------
    >>> import awswrangler as wr
    >>> res = wr.athena.get_query_results(
    ...     query_execution_id="cbae5b41-8103-4709-95bb-887f88edd4f2"
    ... )

    """
    query_metadata: _QueryMetadata = _get_query_metadata(
        query_execution_id=query_execution_id,
        boto3_session=boto3_session,
        categories=categories,
        metadata_cache_manager=_cache_manager,
    )
    client_athena: boto3.client = _utils.client(service_name="athena", session=boto3_session)
    query_info: Dict[str, Any] = client_athena.get_query_execution(QueryExecutionId=query_execution_id)[
        "QueryExecution"
    ]
    statement_type: Optional[str] = query_info.get("StatementType")
    if (statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE")) or (
        statement_type == "DML" and query_info["Query"].startswith("UNLOAD")
    ):
        return _fetch_parquet_result(
            query_metadata=query_metadata,
            keep_files=True,
            categories=categories,
            chunksize=chunksize,
            use_threads=use_threads,
            boto3_session=boto3_session,
            s3_additional_kwargs=s3_additional_kwargs,
            pyarrow_additional_kwargs=pyarrow_additional_kwargs,
        )
    if statement_type == "DML" and not query_info["Query"].startswith("INSERT"):
        return _fetch_csv_result(
            query_metadata=query_metadata,
            keep_files=True,
            chunksize=chunksize,
            use_threads=use_threads,
            boto3_session=boto3_session,
            s3_additional_kwargs=s3_additional_kwargs,
        )
    raise exceptions.UndetectedType(f"""Unable to get results for: {query_info["Query"]}.""")


@apply_configs
def read_sql_query(
    sql: str,
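A rough sketch of the chunked read path described in the docstring above, assuming the query has already been executed (the execution ID is the same placeholder used in the docstring). With chunksize set, the function returns an iterator of DataFrames instead of one large frame:

import awswrangler as wr

# Placeholder execution ID of a finished SELECT (regular, CTAS, or UNLOAD query).
query_execution_id = "cbae5b41-8103-4709-95bb-887f88edd4f2"

# Results are yielded incrementally, which keeps memory bounded for large
# result sets; each item is a regular pandas DataFrame.
for chunk in wr.athena.get_query_results(query_execution_id=query_execution_id, chunksize=True):
    print(len(chunk))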
46 changes: 46 additions & 0 deletions tests/test_athena.py
@@ -1097,3 +1097,49 @@ def test_start_query_execution_wait(path, glue_database, glue_table):
assert query_execution_result["Query"] == sql
assert query_execution_result["StatementType"] == "DML"
assert query_execution_result["QueryExecutionContext"]["Database"] == glue_database

def test_get_query_results(path, glue_table, glue_database):

    sql = (
        "SELECT CAST("
        " ROW(1, ROW(2, ROW(3, '4'))) AS"
        " ROW(field0 BIGINT, field1 ROW(field2 BIGINT, field3 ROW(field4 BIGINT, field5 VARCHAR)))"
        ") AS col0"
    )

    df_ctas: pd.DataFrame = wr.athena.read_sql_query(
        sql=sql, database=glue_database, ctas_approach=True, unload_approach=False
    )
    query_id_ctas = df_ctas.query_metadata["QueryExecutionId"]
    df_get_query_results_ctas = wr.athena.get_query_results(query_execution_id=query_id_ctas)
    pd.testing.assert_frame_equal(df_get_query_results_ctas, df_ctas)

    df_unload: pd.DataFrame = wr.athena.read_sql_query(
        sql=sql, database=glue_database, ctas_approach=False, unload_approach=True, s3_output=path
    )
    query_id_unload = df_unload.query_metadata["QueryExecutionId"]
    df_get_query_results_df_unload = wr.athena.get_query_results(query_execution_id=query_id_unload)
    pd.testing.assert_frame_equal(df_get_query_results_df_unload, df_unload)

    wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table)
    wr.s3.to_parquet(
        df=get_df(),
        path=path,
        index=True,
        use_threads=True,
        dataset=True,
        mode="overwrite",
        database=glue_database,
        table=glue_table,
        partition_cols=["par0", "par1"],
    )

    reg_sql = f"SELECT * FROM {glue_table}"

    df_regular: pd.DataFrame = wr.athena.read_sql_query(
        sql=reg_sql, database=glue_database, ctas_approach=False, unload_approach=False
    )
    query_id_regular = df_regular.query_metadata["QueryExecutionId"]
    df_get_query_results_df_regular = wr.athena.get_query_results(query_execution_id=query_id_regular)
    pd.testing.assert_frame_equal(df_get_query_results_df_regular, df_regular)
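For reviewers trying the branch locally, a minimal end-to-end sketch (not part of the test suite) that pairs start_query_execution with the new function; the database name is a placeholder, and the result-key access assumes the same response shape exercised in test_start_query_execution_wait above:

import awswrangler as wr

# Placeholder Glue database; wait=True blocks until the query has finished,
# returning the query execution response dict used in the test above.
res = wr.athena.start_query_execution(sql="SELECT 1 AS col0", database="my_database", wait=True)

# Fetch the finished query's output as a DataFrame via the new API.
df = wr.athena.get_query_results(query_execution_id=res["QueryExecutionId"])
print(df)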