
Commit 04956d9

Add bulk_read option for reading large amounts of Parquet files quickly (#2033)
1 parent 40d9665 commit 04956d9

File tree

7 files changed: +241 additions, -45 deletions

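For context, the new option trades per-file Parquet metadata inspection for raw listing speed when a very large number of files is read. The public wiring of the flag is not among the files shown below, so the following is a hypothetical usage sketch only: wr.s3.read_parquet is the existing reader, and the ray_args={"bulk_read": True} keyword is an assumption about how the option is surfaced.

# Hypothetical usage sketch -- the bulk_read keyword and where it is passed are
# assumptions, not confirmed by the diffs in this section.
import awswrangler as wr

# Default distributed read: the full Parquet datasource inspects per-file metadata.
df = wr.s3.read_parquet(path="s3://my-bucket/my-prefix/")

# Assumed bulk-read path: skip per-file metadata so many files load quickly.
df_fast = wr.s3.read_parquet(path="s3://my-bucket/my-prefix/", ray_args={"bulk_read": True})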

awswrangler/distributed/ray/datasources/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 
 from awswrangler.distributed.ray.datasources.arrow_csv_datasource import ArrowCSVDatasource
 from awswrangler.distributed.ray.datasources.arrow_json_datasource import ArrowJSONDatasource
+from awswrangler.distributed.ray.datasources.arrow_parquet_base_datasource import ArrowParquetBaseDatasource
 from awswrangler.distributed.ray.datasources.arrow_parquet_datasource import ArrowParquetDatasource
 from awswrangler.distributed.ray.datasources.pandas_file_based_datasource import UserProvidedKeyBlockWritePathProvider
 from awswrangler.distributed.ray.datasources.pandas_text_datasource import (
@@ -14,6 +15,7 @@
 __all__ = [
     "ArrowCSVDatasource",
     "ArrowJSONDatasource",
+    "ArrowParquetBaseDatasource",
     "ArrowParquetDatasource",
     "PandasCSVDataSource",
     "PandasFWFDataSource",
awswrangler/distributed/ray/datasources/arrow_parquet_base_datasource.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@ (new file)
"""Ray ParquetBaseDatasource Module.

This module is pulled from Ray's [ParquetBaseDatasource]
(https://github.com/ray-project/ray/blob/master/python/ray/data/datasource/parquet_base_datasource.py) with a few changes
and customized to ensure compatibility with AWS SDK for pandas behavior. Changes from the original implementation,
are documented in the comments and marked with (AWS SDK for pandas) prefix.
"""

from typing import Any, Dict, List, Optional

# fs required to implicitly trigger S3 subsystem initialization
import pyarrow as pa
import pyarrow.fs
import pyarrow.parquet as pq
from ray.data.block import BlockAccessor

from awswrangler._arrow import _add_table_partitions, _df_to_table
from awswrangler.distributed.ray.datasources.pandas_file_based_datasource import PandasFileBasedDatasource


class ArrowParquetBaseDatasource(PandasFileBasedDatasource):  # pylint: disable=abstract-method
    """(AWS SDK for pandas) Parquet datasource, for reading and writing Parquet files.

    The following are the changes to the original Ray implementation:
    1. Added handling of additional parameters `dtype`, `index`, `compression` and added the ability
       to pass through additional `pyarrow_additional_kwargs` and `s3_additional_kwargs` for writes.
    3. Added `dataset` and `path_root` parameters to allow user to control loading partitions
       relative to the root S3 prefix.
    """

    _FILE_EXTENSION = "parquet"

    def _read_file(  # type: ignore[override]
        self,
        f: pa.NativeFile,
        path: str,
        path_root: str,
        **reader_args: Any,
    ) -> pa.Table:
        use_threads: bool = reader_args.get("use_threads", False)
        columns: Optional[List[str]] = reader_args.get("columns", None)

        dataset_kwargs = reader_args.get("dataset_kwargs", {})
        coerce_int96_timestamp_unit: Optional[str] = dataset_kwargs.get("coerce_int96_timestamp_unit", None)

        table = pq.read_table(
            f,
            use_threads=use_threads,
            columns=columns,
            coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
        )

        table = _add_table_partitions(
            table=table,
            path=f"s3://{path}",
            path_root=path_root,
        )

        return table

    def _open_input_source(
        self,
        filesystem: pyarrow.fs.FileSystem,
        path: str,
        **open_args: Any,
    ) -> pa.NativeFile:
        # Parquet requires `open_input_file` due to random access reads
        return filesystem.open_input_file(path, **open_args)

    def _write_block(  # type: ignore[override]
        self,
        f: pa.NativeFile,
        block: BlockAccessor[Any],
        **writer_args: Any,
    ) -> None:
        schema: Optional[pa.schema] = writer_args.get("schema", None)
        dtype: Optional[Dict[str, str]] = writer_args.get("dtype", None)
        index: bool = writer_args.get("index", False)
        compression: Optional[str] = writer_args.get("compression", None)
        pyarrow_additional_kwargs: Dict[str, Any] = writer_args.get("pyarrow_additional_kwargs", {})

        pq.write_table(
            _df_to_table(block.to_pandas(), schema=schema, index=index, dtype=dtype),
            f,
            compression=compression,
            **pyarrow_additional_kwargs,
        )
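_add_table_partitions is an awswrangler-internal helper whose implementation is not part of this commit; the sketch below only approximates its effect for a hive-style layout, recovering partition columns from the object key relative to path_root and appending them to the table:

# Approximation of the partition handling used by _read_file above; the real
# helper lives in awswrangler._arrow and may differ in details (e.g. column types).
import pyarrow as pa


def add_partitions_from_path(table: pa.Table, path: str, path_root: str) -> pa.Table:
    # Directories between the dataset root and the file name are treated as
    # "column=value" partitions, e.g. ".../year=2023/month=01/part-0.parquet".
    relative = path[len(path_root):].lstrip("/")
    for part in relative.split("/")[:-1]:  # drop the file name itself
        if "=" in part:
            column, value = part.split("=", 1)
            table = table.append_column(column, pa.array([value] * len(table)))
    return table


t = add_partitions_from_path(
    pa.table({"x": [1, 2, 3]}),
    path="s3://bucket/root/year=2023/month=01/part-0.parquet",
    path_root="s3://bucket/root/",
)
print(t.column_names)  # ['x', 'year', 'month']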

awswrangler/distributed/ray/datasources/arrow_parquet_datasource.py

Lines changed: 2 additions & 9 deletions
@@ -32,7 +32,7 @@
 
 from awswrangler._arrow import _add_table_partitions, _df_to_table
 from awswrangler.distributed.ray import ray_remote
-from awswrangler.distributed.ray.datasources.pandas_file_based_datasource import PandasFileBasedDatasource
+from awswrangler.distributed.ray.datasources.arrow_parquet_base_datasource import ArrowParquetBaseDatasource
 from awswrangler.s3._write import _COMPRESSION_2_EXT
 
 _logger: logging.Logger = logging.getLogger(__name__)
@@ -72,7 +72,7 @@
 PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS = 5
 
 
-class ArrowParquetDatasource(PandasFileBasedDatasource):  # pylint: disable=abstract-method
+class ArrowParquetDatasource(ArrowParquetBaseDatasource):  # pylint: disable=abstract-method
     """(AWS SDK for pandas) Parquet datasource, for reading and writing Parquet files.
 
     The following are the changes to the original Ray implementation:
@@ -82,8 +82,6 @@ class ArrowParquetDatasource(PandasFileBasedDatasource):  # pylint: disable=abst
     relative to the root S3 prefix.
     """
 
-    _FILE_EXTENSION = "parquet"
-
     def create_reader(self, **kwargs: Dict[str, Any]) -> Reader[Any]:
         """Return a Reader for the given read arguments."""
         return _ArrowParquetDatasourceReader(**kwargs)  # type: ignore[arg-type]
@@ -349,11 +347,6 @@ def _read_pieces(
     schema: Optional[Union[type, "pyarrow.lib.Schema"]],
     serialized_pieces: List[_SerializedPiece],
 ) -> Iterator["pyarrow.Table"]:
-    # This import is necessary to load the tensor extension type.
-    from ray.data.extensions.tensor_extension import (  # type: ignore[attr-defined] # noqa: F401, E501 # pylint: disable=import-outside-toplevel, unused-import
-        ArrowTensorType,
-    )
-
     # Deserialize after loading the filesystem class.
     pieces: List[ParquetFileFragment] = _deserialize_pieces_with_retry(serialized_pieces)
 
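The net effect here is an inheritance refactor: the pieces shared with the new base class (the parquet file extension, plain single-file reads and block writes) are inherited, while the metadata-aware reader stays in this class. A minimal sketch of the resulting hierarchy, with method bodies elided and names taken from this commit:

class PandasFileBasedDatasource:  # existing awswrangler base class
    ...


class ArrowParquetBaseDatasource(PandasFileBasedDatasource):
    _FILE_EXTENSION = "parquet"  # now defined once, in the base class


class ArrowParquetDatasource(ArrowParquetBaseDatasource):
    def create_reader(self, **kwargs):
        ...  # metadata-aware read path (e.g. encoding-ratio sampling) remains here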

awswrangler/distributed/ray/modin/s3/_read_parquet.py

Lines changed: 31 additions & 13 deletions
@@ -4,18 +4,31 @@
 import modin.pandas as pd
 import pyarrow as pa
 from ray.data import read_datasource
+from ray.data.datasource import FastFileMetadataProvider
+from ray.exceptions import RayTaskError
 
-from awswrangler.distributed.ray.datasources import ArrowParquetDatasource
+from awswrangler.distributed.ray.datasources import ArrowParquetBaseDatasource, ArrowParquetDatasource
 from awswrangler.distributed.ray.modin._utils import _to_modin
 
 if TYPE_CHECKING:
     from mypy_boto3_s3 import S3Client
 
 
+def _resolve_datasource_parameters(bulk_read: bool) -> Dict[str, Any]:
+    if bulk_read:
+        return {
+            "datasource": ArrowParquetBaseDatasource(),
+            "meta_provider": FastFileMetadataProvider(),
+        }
+    return {
+        "datasource": ArrowParquetDatasource(),
+    }
+
+
 def _read_parquet_distributed(  # pylint: disable=unused-argument
     paths: List[str],
     path_root: Optional[str],
-    schema: "pa.schema",
+    schema: Optional[pa.schema],
     columns: Optional[List[str]],
     coerce_int96_timestamp_unit: Optional[str],
     use_threads: Union[bool, int],
@@ -24,18 +37,23 @@ def _read_parquet_distributed(  # pylint: disable=unused-argument
     s3_client: Optional["S3Client"],
     s3_additional_kwargs: Optional[Dict[str, Any]],
     arrow_kwargs: Dict[str, Any],
+    bulk_read: bool,
 ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
     dataset_kwargs = {}
     if coerce_int96_timestamp_unit:
         dataset_kwargs["coerce_int96_timestamp_unit"] = coerce_int96_timestamp_unit
-    dataset = read_datasource(
-        datasource=ArrowParquetDatasource(),
-        parallelism=parallelism,
-        use_threads=use_threads,
-        paths=paths,
-        schema=schema,
-        columns=columns,
-        dataset_kwargs=dataset_kwargs,
-        path_root=path_root,
-    )
-    return _to_modin(dataset=dataset, to_pandas_kwargs=arrow_kwargs, ignore_index=bool(path_root))
+
+    try:
+        dataset = read_datasource(
+            **_resolve_datasource_parameters(bulk_read),
+            parallelism=parallelism,
+            use_threads=use_threads,
+            paths=paths,
+            schema=schema,
+            columns=columns,
+            path_root=path_root,
+            dataset_kwargs=dataset_kwargs,
+        )
+        return _to_modin(dataset=dataset, to_pandas_kwargs=arrow_kwargs, ignore_index=bool(path_root))
+    except RayTaskError as e:
+        raise e.cause
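Two details are worth noting in this hunk. Ray's FastFileMetadataProvider skips per-file metadata and size fetching, which is what makes the bulk_read path cheap for very large listings, at the cost of fewer up-front checks. And the RayTaskError wrapper is stripped so callers see the original exception raised inside the remote task. A minimal, self-contained sketch of that unwrapping pattern (assumes a local Ray runtime; the failing task is purely illustrative):

import ray
from ray.exceptions import RayTaskError


@ray.remote
def _fail() -> None:
    # Stand-in for a read task that raises inside a Ray worker.
    raise ValueError("corrupt parquet file")


ray.init(ignore_reinit_error=True)
try:
    ray.get(_fail.remote())
except RayTaskError as e:
    # e.cause is the original exception (here a ValueError) -- the same attribute
    # the code above re-raises to the caller.
    print(type(e.cause).__name__, e.cause)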
