(feat): Refactor to distribute s3.read_parquet #1513
Changes from all commits

@@ -0,0 +1,82 @@
"""Arrow Utilities Module (PRIVATE)."""

import datetime
import json
import logging
from typing import Any, Dict, Optional, Tuple, cast

import pandas as pd
import pyarrow as pa

_logger: logging.Logger = logging.getLogger(__name__)

def _extract_partitions_from_path(path_root: str, path: str) -> Dict[str, str]:
    path_root = path_root if path_root.endswith("/") else f"{path_root}/"
    if path_root not in path:
        raise Exception(f"Object {path} is not under the root path ({path_root}).")
    path_wo_filename: str = path.rpartition("/")[0] + "/"
    path_wo_prefix: str = path_wo_filename.replace(f"{path_root}/", "")
    dirs: Tuple[str, ...] = tuple(x for x in path_wo_prefix.split("/") if (x != "") and (x.count("=") == 1))
    if not dirs:
        return {}
    values_tups = cast(Tuple[Tuple[str, str]], tuple(tuple(x.split("=")[:2]) for x in dirs))
    values_dics: Dict[str, str] = dict(values_tups)
    return values_dics
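
For reference, a minimal usage sketch of the helper above; the bucket and key are made up, and tracing the code a Hive-style path should yield one entry per key=value directory:

from awswrangler._arrow import _extract_partitions_from_path  # private helper introduced in this PR

path_root = "s3://bucket/dataset/"  # illustrative location
path = "s3://bucket/dataset/year=2022/month=08/part-0000.snappy.parquet"  # illustrative location

partitions = _extract_partitions_from_path(path_root=path_root, path=path)
print(partitions)  # expected: {'year': '2022', 'month': '08'}
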
def _add_table_partitions(
    table: pa.Table,
    path: str,
    path_root: Optional[str],
) -> pa.Table:
    part = _extract_partitions_from_path(path_root, path) if path_root else None
    if part:
        for col, value in part.items():
            part_value = pa.array([value] * len(table)).dictionary_encode()
            if col not in table.schema.names:
                table = table.append_column(col, part_value)
            else:
                table = table.set_column(
                    table.schema.get_field_index(col),
                    col,
                    part_value,
                )
    return table
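
In the same spirit, a sketch of how _add_table_partitions materializes the partition value as a dictionary-encoded column (again with a made-up path):

import pyarrow as pa

from awswrangler._arrow import _add_table_partitions  # private helper introduced in this PR

table = pa.table({"c0": [1, 2, 3]})
table = _add_table_partitions(
    table=table,
    path="s3://bucket/dataset/year=2022/part-0000.snappy.parquet",  # illustrative location
    path_root="s3://bucket/dataset/",
)
print(table.column_names)  # expected: ['c0', 'year']
print(table["year"].type)  # expected: a dictionary-encoded string type
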
def _apply_timezone(df: pd.DataFrame, metadata: Dict[str, Any]) -> pd.DataFrame:
    for c in metadata["columns"]:
        if "field_name" in c and c["field_name"] is not None:
            col_name = str(c["field_name"])
        elif "name" in c and c["name"] is not None:
            col_name = str(c["name"])
        else:
            continue
        if col_name in df.columns and c["pandas_type"] == "datetimetz":
            timezone: datetime.tzinfo = pa.lib.string_to_tzinfo(c["metadata"]["timezone"])
            _logger.debug("applying timezone (%s) on column %s", timezone, col_name)
            if hasattr(df[col_name].dtype, "tz") is False:
                df[col_name] = df[col_name].dt.tz_localize(tz="UTC")
            df[col_name] = df[col_name].dt.tz_convert(tz=timezone)
    return df

def _table_to_df(
    table: pa.Table,
    kwargs: Dict[str, Any],
) -> pd.DataFrame:
    """Convert a PyArrow table to a Pandas DataFrame and apply metadata.

    This method should be used across the codebase to ensure this conversion is consistent.
    """
    metadata: Dict[str, Any] = {}
    if table.schema.metadata is not None and b"pandas" in table.schema.metadata:
        metadata = json.loads(table.schema.metadata[b"pandas"])

    df = table.to_pandas(**kwargs)

    if metadata:
        _logger.debug("metadata: %s", metadata)
        df = _apply_timezone(df=df, metadata=metadata)
    return df
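
A quick round-trip sketch of _table_to_df: the b"pandas" schema metadata written by from_pandas() carries the original timezone, which _apply_timezone re-applies after to_pandas():

import pandas as pd
import pyarrow as pa

from awswrangler._arrow import _table_to_df  # private helper introduced in this PR

df_in = pd.DataFrame({"ts": pd.date_range("2022-01-01", periods=3, tz="Europe/London")})
table = pa.Table.from_pandas(df_in)

df_out = _table_to_df(table, kwargs={})
print(df_out["ts"].dtype)  # expected: datetime64[ns, Europe/London]
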

@@ -18,6 +18,7 @@

 from awswrangler import _config, exceptions
 from awswrangler.__metadata__ import __version__
+from awswrangler._arrow import _table_to_df
 from awswrangler._config import apply_configs, config

 if TYPE_CHECKING or config.distributed:

@@ -416,7 +417,7 @@ def table_refs_to_df(
 ) -> pd.DataFrame:
     """Build Pandas dataframe from list of PyArrow tables."""
     if isinstance(tables[0], pa.Table):
-        return ensure_df_is_mutable(pa.concat_tables(tables, promote=True).to_pandas(**kwargs))

Review comment: This is a change that I would like to discuss. The column manipulations in the …
Reply: +1 we probably shouldn't be doing that in distributed scenario

+        return _table_to_df(pa.concat_tables(tables, promote=True), kwargs=kwargs)

Review comment: This is an example of using …

     return _arrow_refs_to_df(arrow_refs=tables, kwargs=kwargs)  # type: ignore
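
For context on the change above, the non-distributed branch now funnels through the shared _table_to_df helper; in isolation it behaves roughly like this (a sketch, not awswrangler's exact call site):

import pyarrow as pa

from awswrangler._arrow import _table_to_df

tables = [pa.table({"c0": [1, 2]}), pa.table({"c0": [3, 4]})]
df = _table_to_df(pa.concat_tables(tables, promote=True), kwargs={})
print(len(df))  # expected: 4
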

@@ -1,30 +1,33 @@
 """Utilities Module for Distributed methods."""

-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List

 import modin.pandas as pd
 import pyarrow as pa
 import ray
 from modin.distributed.dataframe.pandas.partitions import from_partitions
-from ray.data.impl.arrow_block import ArrowBlockAccessor
+from ray.data.impl.arrow_block import ArrowBlockAccessor, ArrowRow
 from ray.data.impl.remote_fn import cached_remote_fn

+from awswrangler._arrow import _table_to_df


 def _block_to_df(
     block: Any,
     kwargs: Dict[str, Any],
-    dtype: Optional[Dict[str, str]] = None,
 ) -> pa.Table:
     block = ArrowBlockAccessor.for_block(block)
-    df = block._table.to_pandas(**kwargs)  # pylint: disable=protected-access
-    return df.astype(dtype=dtype) if dtype else df

Review comment: I think this was added to feature-match with the non-distributed version. Do you propose to handle this differently or just remove it for now?
Reply: I thought so, but then I could not find any other reference in the library. The only one I found was here. And even if there was one, I would move it inside this new …
Reply: Yeah, I think this type conversion was done in a different way (probably using map_types when going from a PyArrow table to a dataframe), but it wasn't available in the distributed case, so this was the only crude way to do it.
Reply: Right ok, but do you agree that it's now solved since we are using the same …
Reply: Yep

+    return _table_to_df(table=block._table, kwargs=kwargs)  # pylint: disable=protected-access


-def _arrow_refs_to_df(arrow_refs: List[Callable[..., Any]], kwargs: Dict[str, Any]) -> pd.DataFrame:
-    ds = ray.data.from_arrow_refs(arrow_refs)
+def _to_modin(dataset: ray.data.Dataset[ArrowRow], kwargs: Dict[str, Any]) -> pd.DataFrame:
     block_to_df = cached_remote_fn(_block_to_df)
     return from_partitions(
-        partitions=[block_to_df.remote(block=block, kwargs=kwargs) for block in ds.get_internal_block_refs()],
+        partitions=[block_to_df.remote(block=block, kwargs=kwargs) for block in dataset.get_internal_block_refs()],
         axis=0,
-        index=pd.RangeIndex(start=0, stop=ds.count()),
+        index=pd.RangeIndex(start=0, stop=dataset.count()),
     )


+def _arrow_refs_to_df(arrow_refs: List[Callable[..., Any]], kwargs: Dict[str, Any]) -> pd.DataFrame:
+    return _to_modin(dataset=ray.data.from_arrow_refs(arrow_refs), kwargs=kwargs)
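
A usage sketch of the distributed path above, assuming the file is importable as awswrangler.distributed._utils (the module path is not shown in the diff): each Ray object reference holding a PyArrow table becomes one Modin partition.

import pyarrow as pa
import ray

from awswrangler.distributed._utils import _arrow_refs_to_df  # assumed module path

ray.init(ignore_reinit_error=True)

refs = [ray.put(pa.table({"c0": [1, 2]})), ray.put(pa.table({"c0": [3, 4]}))]
modin_df = _arrow_refs_to_df(arrow_refs=refs, kwargs={})
print(type(modin_df), modin_df.shape)  # expected: a modin.pandas.DataFrame of shape (4, 1)
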

@@ -0,0 +1,7 @@
"""Distributed Datasources Module."""

from awswrangler.distributed.datasources.parquet_datasource import ParquetDatasource

__all__ = [
    "ParquetDatasource",
]

@@ -0,0 +1,137 @@
"""Distributed ParquetDatasource Module."""

import logging
from typing import Any, Callable, Iterator, List, Optional, Union

import numpy as np
import pyarrow as pa

# fs required to implicitly trigger S3 subsystem initialization
import pyarrow.fs  # noqa: F401 pylint: disable=unused-import
import pyarrow.parquet as pq
from ray import cloudpickle
from ray.data.context import DatasetContext
from ray.data.datasource.datasource import ReadTask
from ray.data.datasource.file_based_datasource import _resolve_paths_and_filesystem
from ray.data.datasource.file_meta_provider import DefaultParquetMetadataProvider, ParquetMetadataProvider
from ray.data.datasource.parquet_datasource import (
    _deregister_parquet_file_fragment_serialization,
    _register_parquet_file_fragment_serialization,
)
from ray.data.impl.output_buffer import BlockOutputBuffer

from awswrangler._arrow import _add_table_partitions

_logger: logging.Logger = logging.getLogger(__name__)

# The number of rows to read per batch. This is sized to generate 10MiB batches
# for rows about 1KiB in size.
PARQUET_READER_ROW_BATCH_SIZE = 100000


class ParquetDatasource:
"""Parquet datasource, for reading and writing Parquet files.""" | ||
|
||
# Original: https://github.com/ray-project/ray/blob/releases/1.13.0/python/ray/data/datasource/parquet_datasource.py | ||
def prepare_read( | ||
self, | ||
parallelism: int, | ||
use_threads: Union[bool, int], | ||
paths: Union[str, List[str]], | ||
schema: "pyarrow.lib.Schema", | ||
columns: Optional[List[str]] = None, | ||
coerce_int96_timestamp_unit: Optional[str] = None, | ||
path_root: Optional[str] = None, | ||
filesystem: Optional["pyarrow.fs.FileSystem"] = None, | ||
meta_provider: ParquetMetadataProvider = DefaultParquetMetadataProvider(), | ||
_block_udf: Optional[Callable[..., Any]] = None, | ||
) -> List[ReadTask]: | ||
"""Create and return read tasks for a Parquet file-based datasource.""" | ||
paths, filesystem = _resolve_paths_and_filesystem(paths, filesystem) | ||
|
||
parquet_dataset = pq.ParquetDataset( | ||
path_or_paths=paths, | ||
filesystem=filesystem, | ||
partitioning=None, | ||
coerce_int96_timestamp_unit=coerce_int96_timestamp_unit, | ||
use_legacy_dataset=False, | ||
) | ||
|
||
def read_pieces(serialized_pieces: str) -> Iterator[pa.Table]: | ||
# Deserialize after loading the filesystem class. | ||
try: | ||
_register_parquet_file_fragment_serialization() # type: ignore | ||
pieces = cloudpickle.loads(serialized_pieces) | ||
finally: | ||
_deregister_parquet_file_fragment_serialization() # type: ignore | ||
|
||
# Ensure that we're reading at least one dataset fragment. | ||
assert len(pieces) > 0 | ||
|
||
ctx = DatasetContext.get_current() | ||
output_buffer = BlockOutputBuffer(block_udf=_block_udf, target_max_block_size=ctx.target_max_block_size) | ||
|
||
_logger.debug("Reading %s parquet pieces", len(pieces)) | ||
for piece in pieces: | ||
batches = piece.to_batches( | ||
use_threads=use_threads, | ||
columns=columns, | ||
schema=schema, | ||
batch_size=PARQUET_READER_ROW_BATCH_SIZE, | ||
) | ||
for batch in batches: | ||
# Table creation is wrapped inside _add_table_partitions | ||
# to add columns with partition values when dataset=True | ||
table = _add_table_partitions( | ||
                        table=pa.Table.from_batches([batch], schema=schema),
                        path=f"s3://{piece.path}",
                        path_root=path_root,
                    )
                    # If the table is empty, drop it.
                    if table.num_rows > 0:
                        output_buffer.add_block(table)
                        if output_buffer.has_next():
                            yield output_buffer.next()

            output_buffer.finalize()
            if output_buffer.has_next():
                yield output_buffer.next()

        if _block_udf is not None:
            # Try to infer dataset schema by passing dummy table through UDF.
            dummy_table = schema.empty_table()
            try:
                inferred_schema = _block_udf(dummy_table).schema
                inferred_schema = inferred_schema.with_metadata(schema.metadata)
            except Exception:  # pylint: disable=broad-except
                _logger.debug(
                    "Failed to infer schema of dataset by passing dummy table "
                    "through UDF due to the following exception:",
                    exc_info=True,
                )
                inferred_schema = schema
        else:
            inferred_schema = schema
        read_tasks = []
        metadata = meta_provider.prefetch_file_metadata(parquet_dataset.pieces) or []
        try:
            _register_parquet_file_fragment_serialization()  # type: ignore
            for pieces, metadata in zip(  # type: ignore
                np.array_split(parquet_dataset.pieces, parallelism),
                np.array_split(metadata, parallelism),
            ):
                if len(pieces) <= 0:
                    continue
                serialized_pieces = cloudpickle.dumps(pieces)  # type: ignore
                input_files = [p.path for p in pieces]
                meta = meta_provider(
                    input_files,
                    inferred_schema,
                    pieces=pieces,
                    prefetched_metadata=metadata,
                )
                read_tasks.append(ReadTask(lambda p=serialized_pieces: read_pieces(p), meta))  # type: ignore
        finally:
            _deregister_parquet_file_fragment_serialization()  # type: ignore

        return read_tasks
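
To tie it together, a hedged sketch of how this datasource could be driven; the exact call inside awswrangler's s3.read_parquet may differ, and the S3 location is illustrative. ray.data.read_datasource forwards the keyword arguments to prepare_read, and _to_modin (from the distributed utilities above) turns the resulting dataset into a Modin frame.

import ray
import ray.data

from awswrangler.distributed._utils import _to_modin  # assumed module path
from awswrangler.distributed.datasources import ParquetDatasource

ray.init(ignore_reinit_error=True)

dataset = ray.data.read_datasource(
    ParquetDatasource(),
    parallelism=200,
    use_threads=False,
    paths=["s3://bucket/dataset/"],  # illustrative location
    schema=None,
    columns=None,
    coerce_int96_timestamp_unit=None,
    path_root="s3://bucket/dataset/",
)
modin_df = _to_modin(dataset=dataset, kwargs={})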

Review comment: This module serves two purposes: it holds methods shared between the `distributed` module and other `awswrangler` modules. If these methods were to reside in the `_utils` or `s3/_read` module, they would cause circular import dependency issues.
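
A tiny sketch of the import layout that comment describes, using only module names taken from the diffs above: the new awswrangler._arrow module depends only on the standard library, pandas and pyarrow, so both sides can import it without importing each other.

# Both imports appear in the diffs above; neither forces awswrangler._arrow to import
# the distributed package or s3/_read back, so no circular dependency can arise.
from awswrangler._arrow import _table_to_df  # used by the refactored table_refs_to_df
from awswrangler._arrow import _add_table_partitions  # used by the distributed ParquetDatasource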