Commit dd1a0dc

Simplify the refactoring
1 parent 41c79fc commit dd1a0dc

9 files changed: +229 additions, -226 deletions

awswrangler/_arrow.py

Lines changed: 37 additions & 1 deletion
@@ -3,14 +3,50 @@
 import datetime
 import json
 import logging
-from typing import Any, Dict
+from typing import Any, Dict, Optional, Tuple, cast

 import pandas as pd
 import pyarrow as pa

 _logger: logging.Logger = logging.getLogger(__name__)


+def _extract_partitions_from_path(path_root: str, path: str) -> Dict[str, str]:
+    path_root = path_root if path_root.endswith("/") else f"{path_root}/"
+    if path_root not in path:
+        raise Exception(f"Object {path} is not under the root path ({path_root}).")
+    path_wo_filename: str = path.rpartition("/")[0] + "/"
+    path_wo_prefix: str = path_wo_filename.replace(f"{path_root}/", "")
+    dirs: Tuple[str, ...] = tuple(x for x in path_wo_prefix.split("/") if (x != "") and (x.count("=") == 1))
+    if not dirs:
+        return {}
+    values_tups = cast(Tuple[Tuple[str, str]], tuple(tuple(x.split("=")[:2]) for x in dirs))
+    values_dics: Dict[str, str] = dict(values_tups)
+    return values_dics
+
+
+def _add_table_partitions(
+    table: pa.Table,
+    path: str,
+    path_root: Optional[str],
+) -> pa.Table:
+    part = _extract_partitions_from_path(path_root, f"s3://{path}") if path_root else None
+    if part:
+        for col, value in part.items():
+            try:
+                table = table.set_column(
+                    table.schema.get_field_index(col),
+                    col,
+                    pa.array([value] * len(table)).dictionary_encode(),
+                )
+            except pa.ArrowInvalid:
+                table = table.append_column(
+                    col,
+                    pa.array([value] * len(table)).dictionary_encode(),
+                )
+    return table
+
+
 def _apply_timezone(df: pd.DataFrame, metadata: Dict[str, Any]) -> pd.DataFrame:
     for c in metadata["columns"]:
         if "field_name" in c and c["field_name"] is not None:

awswrangler/_threading.py

Lines changed: 4 additions & 9 deletions
@@ -1,7 +1,6 @@
 """Threading Module (PRIVATE)."""

 import concurrent.futures
-import inspect
 import itertools
 import logging
 from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
@@ -32,14 +31,10 @@ def __init__(self, use_threads: Union[bool, int]):
     def map(self, func: Callable[..., List[str]], boto3_session: boto3.Session, *iterables: Any) -> List[Any]:
         """Map iterables to multi-threaded function."""
         _logger.debug("Map: %s", func)
-        first_arg = tuple(inspect.signature(func).parameters.keys())[0]
         if self._exec is not None:
-            args = iterables
-            if first_arg == "boto3_session":
-                # Deserialize boto3 session into pickable object
-                boto3_primitives = _utils.boto3_to_primitives(boto3_session=boto3_session)
-                args = (itertools.repeat(boto3_primitives), *iterables)
+            # Deserialize boto3 session into pickable object
+            boto3_primitives = _utils.boto3_to_primitives(boto3_session=boto3_session)
+            args = (itertools.repeat(boto3_primitives), *iterables)
             return list(self._exec.map(func, *args))
         # Single-threaded
-        args = (itertools.repeat(boto3_session), *iterables) if first_arg == "boto3_session" else iterables
-        return list(map(func, *args))  # type: ignore
+        return list(map(func, *(itertools.repeat(boto3_session), *iterables)))  # type: ignore
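With the inspect-based signature check removed, every function dispatched through the map method above is now expected to take a boto3 session (or its pickled primitives) as its first parameter. Below is a small self-contained sketch of the single-threaded branch; the worker function and the string standing in for the session are hypothetical.

import itertools
from typing import Callable, List


def delete_key(boto3_session: str, key: str) -> str:
    # Hypothetical worker; a real one would receive a boto3.Session as its first argument.
    return f"deleted {key} using {boto3_session}"


def single_threaded_map(func: Callable[..., str], boto3_session: str, *iterables: List[str]) -> List[str]:
    # itertools.repeat pairs the same session with every element of the iterables,
    # mirroring map(func, *(itertools.repeat(boto3_session), *iterables)) in the diff above.
    return list(map(func, *(itertools.repeat(boto3_session), *iterables)))


print(single_threaded_map(delete_key, "session-0", ["a.parquet", "b.parquet"]))
# ['deleted a.parquet using session-0', 'deleted b.parquet using session-0']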

awswrangler/_utils.py

Lines changed: 1 addition & 86 deletions
@@ -8,14 +8,13 @@
 import random
 import time
 from concurrent.futures import FIRST_COMPLETED, Future, wait
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterator, List, Optional, Sequence, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, Union, cast

 import boto3
 import botocore.config
 import numpy as np
 import pandas as pd
 import pyarrow as pa
-import pyarrow.parquet
 from pyarrow import fs

 from awswrangler import _config, exceptions
@@ -159,15 +158,6 @@ def resolve_filesystem(session: Optional[boto3.Session] = None) -> fs.FileSystem
     )


-def resolve_filesystem_paths(paths: List[str]) -> List[str]:
-    """Resolve and normalize provided paths based on PyArrow filesystem."""
-    resolved_paths = []
-    for path in paths:
-        _, resolved_path = fs._resolve_filesystem_and_path(path=path)  # pylint: disable=protected-access
-        resolved_paths.append(resolved_path)
-    return resolved_paths
-
-
 def parse_path(path: str) -> Tuple[str, str]:
     """Split a full S3 path in bucket and key strings.

@@ -462,78 +452,3 @@ def list_to_arrow_table(
         arrays.append(v)
     # Will raise if metadata is not None
     return pa.Table.from_arrays(arrays, schema=schema, metadata=metadata)
-
-
-def _extract_partitions_from_path(path_root: str, path: str) -> Dict[str, str]:
-    path_root = path_root if path_root.endswith("/") else f"{path_root}/"
-    if path_root not in path:
-        raise exceptions.InvalidArgumentValue(f"Object {path} is not under the root path ({path_root}).")
-    path_wo_filename: str = path.rpartition("/")[0] + "/"
-    path_wo_prefix: str = path_wo_filename.replace(f"{path_root}/", "")
-    dirs: Tuple[str, ...] = tuple(x for x in path_wo_prefix.split("/") if (x != "") and (x.count("=") == 1))
-    if not dirs:
-        return {}
-    values_tups = cast(Tuple[Tuple[str, str]], tuple(tuple(x.split("=")[:2]) for x in dirs))
-    values_dics: Dict[str, str] = dict(values_tups)
-    return values_dics
-
-
-def _add_partitions_table(
-    table: pa.Table,
-    path: str,
-    path_root: Optional[str],
-) -> pa.Table:
-    part = _extract_partitions_from_path(path_root, f"s3://{path}") if path_root else None
-    if part:
-        for col, value in part.items():
-            table = table.set_column(
-                table.schema.get_field_index(col),
-                col,
-                pa.array([value] * len(table)).dictionary_encode(),
-            )
-    return table
-
-
-def piece_to_table(
-    piece: pyarrow.parquet.ParquetDataset.pieces,
-    schema: pa.schema,
-    columns: Optional[List[str]],
-    path_root: Optional[str],
-    use_threads: Union[bool, int],
-) -> pa.Table:
-    """Create PyArrow Table from list of ParquetDataset batches."""
-    return _add_partitions_table(
-        table=piece.to_table(use_threads=use_threads, schema=schema, columns=columns),
-        path=piece.path,
-        path_root=path_root,
-    )
-
-
-def batches_to_table(
-    pieces: pyarrow.parquet.ParquetDataset.pieces,
-    schema: pa.schema,
-    columns: Optional[List[str]],
-    path_root: Optional[str],
-    use_threads: Union[bool, int],
-    batch_size: Optional[int],
-) -> Iterator[pa.Table]:
-    """Yield PyArrow Tables from list of ParquetDataset pieces."""
-    batch_kwargs = {
-        "use_threads": use_threads,
-        "columns": columns,
-        "schema": schema,
-    }
-    if batch_size:
-        batch_kwargs["batch_size"] = batch_size
-
-    for piece in pieces:
-        batches = piece.to_batches(**batch_kwargs)
-        for batch in batches:
-            table = _add_partitions_table(
-                table=pa.Table.from_batches([batch], schema=schema),
-                path=piece.path,
-                path_root=path_root,
-            )
-            # If the table is empty, drop it.
-            if table.num_rows > 0:
-                yield table

awswrangler/distributed/_distributed.py

Lines changed: 0 additions & 4 deletions
@@ -1,6 +1,5 @@
 """Distributed Module (PRIVATE)."""

-import inspect
 import multiprocessing
 import os
 import sys
@@ -48,9 +47,6 @@ def ray_remote(function: Callable[..., Any]) -> Callable[..., Any]:
     if config.distributed:

         def wrapper(*args: Any, **kwargs: Any) -> Any:
-            first_arg = tuple(inspect.signature(function).parameters.keys())[0]
-            if first_arg != "boto3_session":
-                args = args[1:]
             return ray.remote(function).remote(*args, **kwargs)

         return wrapper
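With the argument-slicing logic gone, ray_remote reduces to the usual conditional-decorator pattern: when distributed mode is on, the call is submitted through Ray unchanged; otherwise the plain function runs inline. A rough standalone sketch follows, where the DISTRIBUTED flag stands in for the library's config.distributed lookup and the undecorated fallback is an assumption about surrounding code not shown in this diff.

from typing import Any, Callable

DISTRIBUTED = False  # stand-in for awswrangler's config.distributed flag (assumption)


def ray_remote(function: Callable[..., Any]) -> Callable[..., Any]:
    if DISTRIBUTED:
        import ray  # requires the optional Ray dependency

        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # All positional arguments are forwarded as-is; no signature inspection needed.
            return ray.remote(function).remote(*args, **kwargs)

        return wrapper
    return function


@ray_remote
def square(x: int) -> int:
    return x * x


print(square(4))  # 16 when DISTRIBUTED is False; an ObjectRef when running under Ray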

awswrangler/distributed/datasources/parquet_datasource.py

Lines changed: 35 additions & 15 deletions
@@ -5,20 +5,22 @@

 import numpy as np
 import pyarrow as pa
+import pyarrow.parquet as pq

 # fs required to implicitly trigger S3 subsystem initialization
 import pyarrow.fs  # noqa: F401 pylint: disable=unused-import
 from ray import cloudpickle
 from ray.data.context import DatasetContext
 from ray.data.datasource.datasource import ReadTask
+from ray.data.datasource.file_based_datasource import _resolve_paths_and_filesystem
 from ray.data.datasource.file_meta_provider import DefaultParquetMetadataProvider, ParquetMetadataProvider
 from ray.data.datasource.parquet_datasource import (
     _deregister_parquet_file_fragment_serialization,
     _register_parquet_file_fragment_serialization,
 )
 from ray.data.impl.output_buffer import BlockOutputBuffer

-from awswrangler._utils import batches_to_table
+from awswrangler._arrow import _add_table_partitions

 _logger: logging.Logger = logging.getLogger(__name__)

@@ -34,15 +36,27 @@ def prepare_read(
         self,
         parallelism: int,
         use_threads: Union[bool, int],
-        parquet_dataset: pa.parquet.ParquetDataset,
-        schema: pa.Schema,
+        filesystem: "pyarrow.fs.FileSystem",
+        paths: Union[str, List[str]],
+        schema: "pyarrow.lib.Schema",
         columns: Optional[List[str]] = None,
+        coerce_int96_timestamp_unit: Optional[str] = None,
        path_root: Optional[str] = None,
         meta_provider: ParquetMetadataProvider = DefaultParquetMetadataProvider(),
         _block_udf: Optional[Callable[..., Any]] = None,
     ) -> List[ReadTask]:
         """Create and return read tasks for a Parquet file-based datasource."""

+        paths, filesystem = _resolve_paths_and_filesystem(paths, filesystem)
+
+        parquet_dataset = pq.ParquetDataset(
+            path_or_paths=paths,
+            filesystem=filesystem,
+            partitioning=None,
+            coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
+            use_legacy_dataset=False,
+        )
+
         def read_pieces(serialized_pieces: str) -> Iterator[pa.Table]:
             # Deserialize after loading the filesystem class.
             try:
@@ -58,18 +72,24 @@ def read_pieces(serialized_pieces: str) -> Iterator[pa.Table]:
             output_buffer = BlockOutputBuffer(block_udf=_block_udf, target_max_block_size=ctx.target_max_block_size)

             _logger.debug("Reading %s parquet pieces", len(pieces))
-            tables = batches_to_table(
-                pieces=pieces,
-                schema=schema,
-                columns=columns,
-                path_root=path_root,
-                use_threads=use_threads,
-                batch_size=PARQUET_READER_ROW_BATCH_SIZE,
-            )
-            for table in tables:
-                output_buffer.add_block(table)
-                if output_buffer.has_next():
-                    yield output_buffer.next()
+            for piece in pieces:
+                batches = piece.to_batches(
+                    use_threads=use_threads,
+                    columns=columns,
+                    schema=schema,
+                    batch_size=PARQUET_READER_ROW_BATCH_SIZE,
+                )
+                for batch in batches:
+                    table = _add_table_partitions(
+                        table=pa.Table.from_batches([batch], schema=schema),
+                        path=piece.path,
+                        path_root=path_root,
+                    )
+                    # If the table is empty, drop it.
+                    if table.num_rows > 0:
+                        output_buffer.add_block(table)
+                        if output_buffer.has_next():
+                            yield output_buffer.next()

             output_buffer.finalize()
             if output_buffer.has_next():
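The rewritten read path follows the standard PyArrow dataset flow: build the ParquetDataset with partitioning=None, stream each fragment as record batches, and re-attach partition values recovered from the fragment path via _add_table_partitions. Below is a rough, Ray-free sketch of that loop against a hypothetical local directory (data/), assuming a recent PyArrow; the commit additionally passes filesystem, coerce_int96_timestamp_unit, and use_legacy_dataset=False when building the dataset.

import pyarrow as pa
import pyarrow.parquet as pq

# Hypothetical local layout mirroring the S3 dataset: data/year=2022/month=01/part-0.parquet, ...
dataset = pq.ParquetDataset("data/", partitioning=None)

for piece in dataset.fragments:
    # Stream the fragment in bounded record batches instead of materializing it in one piece.
    for batch in piece.to_batches(batch_size=128 * 1024):
        table = pa.Table.from_batches([batch])
        # In the datasource above, _add_table_partitions(table, piece.path, path_root)
        # would append dictionary-encoded partition columns recovered from piece.path here.
        if table.num_rows > 0:  # empty tables are dropped, as in read_pieces
            print(piece.path, table.num_rows)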

awswrangler/s3/_read.py

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,8 @@
 from pandas.api.types import union_categoricals

 from awswrangler import exceptions
-from awswrangler._utils import _extract_partitions_from_path, boto3_to_primitives, ensure_cpu_count
+from awswrangler._arrow import _extract_partitions_from_path
+from awswrangler._utils import boto3_to_primitives, ensure_cpu_count
 from awswrangler.s3._list import _prefix_cleanup

 _logger: logging.Logger = logging.getLogger(__name__)

0 commit comments