(feat): Refactor to distribute s3.read_parquet #1513

Merged · 6 commits · Aug 17, 2022
82 changes: 82 additions & 0 deletions awswrangler/_arrow.py
@@ -0,0 +1,82 @@
"""Arrow Utilities Module (PRIVATE)."""

Contributor Author:
This module serves two purposes:

  1. Collect Arrow-specific methods, such as converting a Table to a DataFrame, so that these operations are standardised across the codebase.
  2. Break the circular dependency between the distributed module and the other awswrangler modules. If these methods lived in _utils or s3/_read, they would cause circular import issues.


import datetime
import json
import logging
from typing import Any, Dict, Optional, Tuple, cast

import pandas as pd
import pyarrow as pa

_logger: logging.Logger = logging.getLogger(__name__)


def _extract_partitions_from_path(path_root: str, path: str) -> Dict[str, str]:
    path_root = path_root if path_root.endswith("/") else f"{path_root}/"
    if path_root not in path:
        raise Exception(f"Object {path} is not under the root path ({path_root}).")
    path_wo_filename: str = path.rpartition("/")[0] + "/"
    path_wo_prefix: str = path_wo_filename.replace(f"{path_root}/", "")
    dirs: Tuple[str, ...] = tuple(x for x in path_wo_prefix.split("/") if (x != "") and (x.count("=") == 1))
    if not dirs:
        return {}
    values_tups = cast(Tuple[Tuple[str, str]], tuple(tuple(x.split("=")[:2]) for x in dirs))
    values_dics: Dict[str, str] = dict(values_tups)
    return values_dics


def _add_table_partitions(
    table: pa.Table,
    path: str,
    path_root: Optional[str],
) -> pa.Table:
    part = _extract_partitions_from_path(path_root, path) if path_root else None
    if part:
        for col, value in part.items():
            part_value = pa.array([value] * len(table)).dictionary_encode()
            if col not in table.schema.names:
                table = table.append_column(col, part_value)
            else:
                table = table.set_column(
                    table.schema.get_field_index(col),
                    col,
                    part_value,
                )
    return table


def _apply_timezone(df: pd.DataFrame, metadata: Dict[str, Any]) -> pd.DataFrame:
    for c in metadata["columns"]:
        if "field_name" in c and c["field_name"] is not None:
            col_name = str(c["field_name"])
        elif "name" in c and c["name"] is not None:
            col_name = str(c["name"])
        else:
            continue
        if col_name in df.columns and c["pandas_type"] == "datetimetz":
            timezone: datetime.tzinfo = pa.lib.string_to_tzinfo(c["metadata"]["timezone"])
            _logger.debug("applying timezone (%s) on column %s", timezone, col_name)
            if hasattr(df[col_name].dtype, "tz") is False:
                df[col_name] = df[col_name].dt.tz_localize(tz="UTC")
            df[col_name] = df[col_name].dt.tz_convert(tz=timezone)
    return df


def _table_to_df(
    table: pa.Table,
    kwargs: Dict[str, Any],
) -> pd.DataFrame:
    """Convert a PyArrow table to a Pandas DataFrame and apply metadata.

    This method should be used across the codebase to ensure this conversion is consistent.
    """
    metadata: Dict[str, Any] = {}
    if table.schema.metadata is not None and b"pandas" in table.schema.metadata:
        metadata = json.loads(table.schema.metadata[b"pandas"])

    df = table.to_pandas(**kwargs)

    if metadata:
        _logger.debug("metadata: %s", metadata)
        df = _apply_timezone(df=df, metadata=metadata)
    return df
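
A minimal sketch of how these two helpers are meant to be used together (the bucket, key, and data below are made up):

```python
import pandas as pd
import pyarrow as pa

from awswrangler._arrow import _add_table_partitions, _table_to_df

# Hypothetical object under a partitioned prefix.
table = pa.Table.from_pandas(pd.DataFrame({"value": [1, 2, 3]}))
table = _add_table_partitions(
    table=table,
    path="s3://bucket/root/year=2022/part-0.snappy.parquet",
    path_root="s3://bucket/root/",
)
# "year" is appended as a dictionary-encoded column holding "2022".
df = _table_to_df(table, kwargs={})  # kwargs are forwarded to Table.to_pandas
print(df)
```
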
3 changes: 2 additions & 1 deletion awswrangler/_utils.py
@@ -18,6 +18,7 @@

from awswrangler import _config, exceptions
from awswrangler.__metadata__ import __version__
from awswrangler._arrow import _table_to_df
from awswrangler._config import apply_configs, config

if TYPE_CHECKING or config.distributed:
@@ -416,7 +417,7 @@ def table_refs_to_df(
) -> pd.DataFrame:
    """Build Pandas dataframe from list of PyArrow tables."""
    if isinstance(tables[0], pa.Table):
        return ensure_df_is_mutable(pa.concat_tables(tables, promote=True).to_pandas(**kwargs))
Contributor Author:

This is a change that I would like to discuss: the column manipulations in ensure_df_is_mutable would be too slow on a large DataFrame in the distributed case, so I had to remove it.
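
For context, a rough sketch of the kind of per-column pass this refers to (illustrative only, not the actual ensure_df_is_mutable implementation):

```python
import numpy as np
import pandas as pd


# Illustrative sketch only -- not the actual ensure_df_is_mutable code.
# Touching every column like this forces a copy of each read-only block,
# which is what becomes prohibitively slow on a large (distributed) DataFrame.
def _ensure_mutable_sketch(df: pd.DataFrame) -> pd.DataFrame:
    for col in df.columns:
        values = df[col].values
        if isinstance(values, np.ndarray) and not values.flags.writeable:
            df[col] = values.copy()
    return df
```
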

Contributor:

+1, we probably shouldn't be doing that in the distributed scenario.

        return _table_to_df(pa.concat_tables(tables, promote=True), kwargs=kwargs)
Contributor Author:

This is an example of using _table_to_df, defined in the arrow module, to ensure this conversion is consistent across the codebase.

    return _arrow_refs_to_df(arrow_refs=tables, kwargs=kwargs)  # type: ignore


6 changes: 4 additions & 2 deletions awswrangler/athena/_read.py
@@ -109,13 +109,15 @@ def _fetch_parquet_result(
        df = cast_pandas_with_athena_types(df=df, dtype=dtype_dict)
        df = _apply_query_metadata(df=df, query_metadata=query_metadata)
        return df
    if not pyarrow_additional_kwargs:
        pyarrow_additional_kwargs = {}
    if categories:
        pyarrow_additional_kwargs["categories"] = categories
    ret = s3.read_parquet(
        path=paths,
        use_threads=use_threads,
        boto3_session=boto3_session,
        chunked=chunked,
        categories=categories,
        ignore_index=True,
        pyarrow_additional_kwargs=pyarrow_additional_kwargs,
    )
    if chunked is False:
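
A quick illustration of why forwarding "categories" this way works: PyArrow's Table.to_pandas accepts a categories list and returns those fields as pandas.Categorical (toy data below).

```python
import pyarrow as pa

table = pa.table({"city": ["sfo", "nyc", "sfo"], "value": [1, 2, 3]})
# Equivalent to what the block above forwards via pyarrow_additional_kwargs.
df = table.to_pandas(categories=["city"])
print(df["city"].dtype)  # category
```
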
21 changes: 12 additions & 9 deletions awswrangler/distributed/_utils.py
@@ -1,30 +1,33 @@
"""Utilities Module for Distributed methods."""

from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List

import modin.pandas as pd
import pyarrow as pa
import ray
from modin.distributed.dataframe.pandas.partitions import from_partitions
from ray.data.impl.arrow_block import ArrowBlockAccessor
from ray.data.impl.arrow_block import ArrowBlockAccessor, ArrowRow
from ray.data.impl.remote_fn import cached_remote_fn

from awswrangler._arrow import _table_to_df


def _block_to_df(
    block: Any,
    kwargs: Dict[str, Any],
    dtype: Optional[Dict[str, str]] = None,
) -> pa.Table:
    block = ArrowBlockAccessor.for_block(block)
    df = block._table.to_pandas(**kwargs)  # pylint: disable=protected-access
    return df.astype(dtype=dtype) if dtype else df
Contributor:

I think this was added to feature-match with the non-distributed version. Do you propose to handle this differently, or just remove it for now?

Contributor Author:

I thought so, but then I could not find any other reference in the library; the only one I found was here. And even if there were one, I would move it inside this new _table_to_df method in order to standardise it.

Contributor:

Yeah, I think this type conversion was done in a different way (probably using map_types when going from a PyArrow table to a DataFrame), but that wasn't available in the distributed case, so this was the only crude way to do it.
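
A guess at the non-distributed path being described, presumably PyArrow's types_mapper argument on Table.to_pandas, which maps Arrow types to pandas extension dtypes during conversion instead of casting afterwards (sketch only):

```python
import pandas as pd
import pyarrow as pa

table = pa.table({"id": pa.array([1, 2, None], type=pa.int64())})
# Map Arrow int64 to the nullable pandas Int64 dtype at conversion time,
# rather than calling .astype on the resulting DataFrame afterwards.
df = table.to_pandas(types_mapper={pa.int64(): pd.Int64Dtype()}.get)
print(df.dtypes)  # id    Int64
```
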

Contributor Author:

Right, OK. But do you agree that it's now solved, since we are using the same _table_to_df method for both the distributed and standard implementations?

Contributor:

Yep

    return _table_to_df(table=block._table, kwargs=kwargs)  # pylint: disable=protected-access


def _arrow_refs_to_df(arrow_refs: List[Callable[..., Any]], kwargs: Dict[str, Any]) -> pd.DataFrame:
    ds = ray.data.from_arrow_refs(arrow_refs)
def _to_modin(dataset: ray.data.Dataset[ArrowRow], kwargs: Dict[str, Any]) -> pd.DataFrame:
    block_to_df = cached_remote_fn(_block_to_df)
    return from_partitions(
        partitions=[block_to_df.remote(block=block, kwargs=kwargs) for block in ds.get_internal_block_refs()],
        partitions=[block_to_df.remote(block=block, kwargs=kwargs) for block in dataset.get_internal_block_refs()],
        axis=0,
        index=pd.RangeIndex(start=0, stop=ds.count()),
        index=pd.RangeIndex(start=0, stop=dataset.count()),
    )


def _arrow_refs_to_df(arrow_refs: List[Callable[..., Any]], kwargs: Dict[str, Any]) -> pd.DataFrame:
    return _to_modin(dataset=ray.data.from_arrow_refs(arrow_refs), kwargs=kwargs)
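
For reference, a rough sketch of how these helpers come together, assuming Ray and Modin are installed and a Ray runtime is available (the tables below are toy data):

```python
import pyarrow as pa
import ray

from awswrangler.distributed._utils import _arrow_refs_to_df

ray.init(ignore_reinit_error=True)
# Two toy Arrow tables living in the Ray object store.
refs = [ray.put(pa.table({"x": [1, 2]})), ray.put(pa.table({"x": [3, 4]}))]
# Converted into a Modin DataFrame partitioned over the dataset's blocks.
modin_df = _arrow_refs_to_df(arrow_refs=refs, kwargs={})
print(modin_df.shape)  # (4, 1)
```
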
7 changes: 7 additions & 0 deletions awswrangler/distributed/datasources/__init__.py
@@ -0,0 +1,7 @@
"""Distributed Datasources Module."""

from awswrangler.distributed.datasources.parquet_datasource import ParquetDatasource

__all__ = [
    "ParquetDatasource",
]
137 changes: 137 additions & 0 deletions awswrangler/distributed/datasources/parquet_datasource.py
@@ -0,0 +1,137 @@
"""Distributed ParquetDatasource Module."""

import logging
from typing import Any, Callable, Iterator, List, Optional, Union

import numpy as np
import pyarrow as pa

# fs required to implicitly trigger S3 subsystem initialization
import pyarrow.fs # noqa: F401 pylint: disable=unused-import
import pyarrow.parquet as pq
from ray import cloudpickle
from ray.data.context import DatasetContext
from ray.data.datasource.datasource import ReadTask
from ray.data.datasource.file_based_datasource import _resolve_paths_and_filesystem
from ray.data.datasource.file_meta_provider import DefaultParquetMetadataProvider, ParquetMetadataProvider
from ray.data.datasource.parquet_datasource import (
    _deregister_parquet_file_fragment_serialization,
    _register_parquet_file_fragment_serialization,
)
from ray.data.impl.output_buffer import BlockOutputBuffer

from awswrangler._arrow import _add_table_partitions

_logger: logging.Logger = logging.getLogger(__name__)

# The number of rows to read per batch. This is sized to generate 10MiB batches
# for rows about 1KiB in size.
PARQUET_READER_ROW_BATCH_SIZE = 100000


class ParquetDatasource:
    """Parquet datasource, for reading and writing Parquet files."""

    # Original: https://github.com/ray-project/ray/blob/releases/1.13.0/python/ray/data/datasource/parquet_datasource.py
    def prepare_read(
        self,
        parallelism: int,
        use_threads: Union[bool, int],
        paths: Union[str, List[str]],
        schema: "pyarrow.lib.Schema",
        columns: Optional[List[str]] = None,
        coerce_int96_timestamp_unit: Optional[str] = None,
        path_root: Optional[str] = None,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        meta_provider: ParquetMetadataProvider = DefaultParquetMetadataProvider(),
        _block_udf: Optional[Callable[..., Any]] = None,
    ) -> List[ReadTask]:
        """Create and return read tasks for a Parquet file-based datasource."""
        paths, filesystem = _resolve_paths_and_filesystem(paths, filesystem)

        parquet_dataset = pq.ParquetDataset(
            path_or_paths=paths,
            filesystem=filesystem,
            partitioning=None,
            coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
            use_legacy_dataset=False,
        )

        def read_pieces(serialized_pieces: str) -> Iterator[pa.Table]:
            # Deserialize after loading the filesystem class.
            try:
                _register_parquet_file_fragment_serialization()  # type: ignore
                pieces = cloudpickle.loads(serialized_pieces)
            finally:
                _deregister_parquet_file_fragment_serialization()  # type: ignore

            # Ensure that we're reading at least one dataset fragment.
            assert len(pieces) > 0

            ctx = DatasetContext.get_current()
            output_buffer = BlockOutputBuffer(block_udf=_block_udf, target_max_block_size=ctx.target_max_block_size)

            _logger.debug("Reading %s parquet pieces", len(pieces))
            for piece in pieces:
                batches = piece.to_batches(
                    use_threads=use_threads,
                    columns=columns,
                    schema=schema,
                    batch_size=PARQUET_READER_ROW_BATCH_SIZE,
                )
                for batch in batches:
                    # Table creation is wrapped inside _add_table_partitions
                    # to add columns with partition values when dataset=True
                    table = _add_table_partitions(
                        table=pa.Table.from_batches([batch], schema=schema),
                        path=f"s3://{piece.path}",
                        path_root=path_root,
                    )
                    # If the table is empty, drop it.
                    if table.num_rows > 0:
                        output_buffer.add_block(table)
                        if output_buffer.has_next():
                            yield output_buffer.next()

            output_buffer.finalize()
            if output_buffer.has_next():
                yield output_buffer.next()

        if _block_udf is not None:
            # Try to infer dataset schema by passing dummy table through UDF.
            dummy_table = schema.empty_table()
            try:
                inferred_schema = _block_udf(dummy_table).schema
                inferred_schema = inferred_schema.with_metadata(schema.metadata)
            except Exception:  # pylint: disable=broad-except
                _logger.debug(
                    "Failed to infer schema of dataset by passing dummy table "
                    "through UDF due to the following exception:",
                    exc_info=True,
                )
                inferred_schema = schema
        else:
            inferred_schema = schema
        read_tasks = []
        metadata = meta_provider.prefetch_file_metadata(parquet_dataset.pieces) or []
        try:
            _register_parquet_file_fragment_serialization()  # type: ignore
            for pieces, metadata in zip(  # type: ignore
                np.array_split(parquet_dataset.pieces, parallelism),
                np.array_split(metadata, parallelism),
            ):
                if len(pieces) <= 0:
                    continue
                serialized_pieces = cloudpickle.dumps(pieces)  # type: ignore
                input_files = [p.path for p in pieces]
                meta = meta_provider(
                    input_files,
                    inferred_schema,
                    pieces=pieces,
                    prefetched_metadata=metadata,
                )
                read_tasks.append(ReadTask(lambda p=serialized_pieces: read_pieces(p), meta))  # type: ignore
        finally:
            _deregister_parquet_file_fragment_serialization()  # type: ignore

        return read_tasks
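
A sketch of how this datasource might be wired into Ray's read path. The keyword arguments below mirror the prepare_read signature above; the bucket and key are made up, and the actual call site in the distributed s3.read_parquet may pass additional arguments:

```python
import ray

from awswrangler.distributed.datasources import ParquetDatasource

# read_datasource forwards the read arguments to ParquetDatasource.prepare_read.
dataset = ray.data.read_datasource(
    ParquetDatasource(),
    parallelism=200,
    use_threads=False,
    paths=["s3://bucket/root/year=2022/part-0.snappy.parquet"],  # hypothetical
    schema=None,
    columns=None,
    coerce_int96_timestamp_unit=None,
    path_root="s3://bucket/root/",
)
```
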
16 changes: 8 additions & 8 deletions awswrangler/lakeformation/_read.py
@@ -83,7 +83,7 @@ def read_sql_query(
    use_threads: bool = True,
    boto3_session: Optional[boto3.Session] = None,
    params: Optional[Dict[str, Any]] = None,
    arrow_additional_kwargs: Optional[Dict[str, Any]] = None,
    pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
) -> pd.DataFrame:
    """Execute PartiQL query on AWS Glue Table (Transaction ID or time travel timestamp). Return Pandas DataFrame.

@@ -126,10 +126,10 @@ def read_sql_query(
        Dict of parameters used to format the partiQL query. Only named parameters are supported.
        The dict must contain the information in the form {"name": "value"} and the SQL query must contain
        `:name`.
    arrow_additional_kwargs : Dict[str, Any], optional
    pyarrow_additional_kwargs : Dict[str, Any], optional
        Forwarded to `to_pandas` method converting from PyArrow tables to Pandas dataframe.
        Valid values include "split_blocks", "self_destruct", "ignore_metadata".
        e.g. arrow_additional_kwargs={'split_blocks': True}.
        e.g. pyarrow_additional_kwargs={'split_blocks': True}.

    Returns
    -------
@@ -178,7 +178,7 @@ def read_sql_query(
        **_transaction_id(transaction_id=transaction_id, query_as_of_time=query_as_of_time, DatabaseName=database),
    )
    query_id: str = client_lakeformation.start_query_planning(QueryString=sql, QueryPlanningContext=args)["QueryId"]
    arrow_kwargs = _data_types.pyarrow2pandas_defaults(use_threads=use_threads, kwargs=arrow_additional_kwargs)
    arrow_kwargs = _data_types.pyarrow2pandas_defaults(use_threads=use_threads, kwargs=pyarrow_additional_kwargs)
    df = _resolve_sql_query(
        query_id=query_id,
        use_threads=use_threads,
@@ -199,7 +199,7 @@ def read_sql_table(
    catalog_id: Optional[str] = None,
    use_threads: bool = True,
    boto3_session: Optional[boto3.Session] = None,
    arrow_additional_kwargs: Optional[Dict[str, Any]] = None,
    pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
) -> pd.DataFrame:
    """Extract all rows from AWS Glue Table (Transaction ID or time travel timestamp). Return Pandas DataFrame.

@@ -232,10 +232,10 @@ def read_sql_table(
        When enabled, os.cpu_count() is used as the max number of threads.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session is used if boto3_session receives None.
    arrow_additional_kwargs : Dict[str, Any], optional
    pyarrow_additional_kwargs : Dict[str, Any], optional
        Forwarded to `to_pandas` method converting from PyArrow tables to Pandas dataframe.
        Valid values include "split_blocks", "self_destruct", "ignore_metadata".
        e.g. arrow_additional_kwargs={'split_blocks': True}.
        e.g. pyarrow_additional_kwargs={'split_blocks': True}.

    Returns
    -------
@@ -276,5 +276,5 @@ def read_sql_table(
        catalog_id=catalog_id,
        use_threads=use_threads,
        boto3_session=boto3_session,
        arrow_additional_kwargs=arrow_additional_kwargs,
        pyarrow_additional_kwargs=pyarrow_additional_kwargs,
    )
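
With the rename, a call would look roughly like this (table, database, and timestamp are made up; other arguments as documented above):

```python
import awswrangler as wr

# Hypothetical query; note the renamed parameter (was arrow_additional_kwargs).
df = wr.lakeformation.read_sql_query(
    sql="SELECT * FROM my_table WHERE name=:name",
    database="my_db",
    query_as_of_time="1659000000",
    params={"name": "filtered_name"},
    pyarrow_additional_kwargs={"split_blocks": True},
)
```
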