
Commit c511d30: Add source func
Parent: 76ac88e

5 files changed: +12 additions, -8 deletions


awswrangler/_distributed.py

Lines changed: 2 additions & 0 deletions
@@ -79,9 +79,11 @@ def register_func(cls, source_func: Callable[..., Any], destination_func: Callab
     @classmethod
     def dispatch_on_engine(cls, func: Callable[..., Any]) -> Callable[..., Any]:
         """Dispatch on engine function decorator."""
+
         @wraps(func)
         def wrapper(*args: Any, **kw: Dict[str, Any]) -> Any:
             return cls.dispatch_func(func)(*args, **kw)
+
         # Save the original function
         wrapper._source_func = func  # type: ignore
         return wrapper
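
The wrapper produced by dispatch_on_engine now carries a _source_func attribute pointing at the undecorated function, so other code can recover the original later. A minimal, self-contained sketch of that pattern; the decorator and example names below are hypothetical, not awswrangler's API:

from functools import wraps
from typing import Any, Callable


def remember_source(func: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical decorator showing how a wrapper can carry its original function."""

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return func(*args, **kwargs)

    # Save the original function, mirroring what dispatch_on_engine now does.
    wrapper._source_func = func  # type: ignore[attr-defined]
    return wrapper


@remember_source
def add(a: int, b: int) -> int:  # hypothetical example function
    return a + b


print(add(1, 2))               # 3, via the wrapper
print(add._source_func(1, 2))  # 3, via the undecorated original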

awswrangler/distributed/ray/_core.py

Lines changed: 2 additions & 1 deletion
@@ -60,11 +60,12 @@ def ray_remote(function: Callable[..., Any]) -> Callable[..., Any]:
     Callable[..., Any]
     """
     # Access the source function if it exists
-    function = getattr(function, '_source_func', function)
+    function = getattr(function, "_source_func", function)

     @wraps(function)
     def wrapper(*args: Any, **kwargs: Any) -> Any:
         return ray.remote(function).remote(*args, **kwargs)  # type: ignore
+
     return wrapper
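
Because dispatch_on_engine now records _source_func, ray_remote can unwrap the dispatcher before submitting the task, so Ray serializes the plain function rather than the dispatch indirection. A dependency-free sketch of that unwrapping; the helper names are hypothetical, and the real decorator submits ray.remote(function).remote(...):

from functools import wraps
from typing import Any, Callable


def ray_remote_sketch(function: Callable[..., Any]) -> Callable[..., Any]:
    # Access the source function if it exists; plain callables pass through unchanged.
    function = getattr(function, "_source_func", function)

    @wraps(function)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # The real decorator submits a Ray task here; calling directly keeps
        # the sketch runnable without a Ray cluster.
        return function(*args, **kwargs)

    return wrapper


def _work(x: int) -> int:  # hypothetical task body
    return x * 2


def _dispatcher(x: int) -> int:  # stands in for an engine-dispatched wrapper
    return _work(x)


_dispatcher._source_func = _work  # what dispatch_on_engine now records

print(ray_remote_sketch(_dispatcher)(21))  # 42, built from _work, not _dispatcher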

awswrangler/distributed/ray/_register.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,7 @@
 from awswrangler.s3._delete import _delete_objects
 from awswrangler.s3._read_parquet import _read_parquet, _read_parquet_metadata_file
 from awswrangler.s3._read_text import _read_text
-from awswrangler.s3._select import _select_object_content
+from awswrangler.s3._select import _select_object_content, _select_query
 from awswrangler.s3._wait import _batch_paths, _wait_object_batch
 from awswrangler.s3._write_dataset import _to_buckets, _to_partitions
 from awswrangler.s3._write_parquet import _to_parquet

@@ -23,6 +23,7 @@ def register_ray() -> None:
     # S3
     engine.register_func(_delete_objects, ray_remote(_delete_objects))
     engine.register_func(_read_parquet_metadata_file, ray_remote(_read_parquet_metadata_file))
+    engine.register_func(_select_query, ray_remote(_select_query))
     engine.register_func(_select_object_content, ray_remote(_select_object_content))
     engine.register_func(_wait_object_batch, ray_remote(_wait_object_batch))
     engine.register_func(_batch_paths, _batch_paths_distributed)
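
register_ray pairs each portable function with its Ray-wrapped counterpart, so the newly registered _select_query also runs as a Ray task when the Ray engine is active. A hypothetical, minimal registry illustrating the register_func / dispatch_func pairing; the names and signatures below are placeholders:

from typing import Any, Callable, Dict

_funcs: Dict[str, Callable[..., Any]] = {}  # hypothetical dispatch table


def register_func(source: Callable[..., Any], destination: Callable[..., Any]) -> None:
    """Map the portable function's name to an engine-specific implementation."""
    _funcs[source.__name__] = destination


def dispatch_func(source: Callable[..., Any]) -> Callable[..., Any]:
    """Return the registered implementation, or the original if none exists."""
    return _funcs.get(source.__name__, source)


def _select_query(sql: str) -> str:      # placeholder for the S3 Select helper
    return f"local: {sql}"


def _select_query_ray(sql: str) -> str:  # stands in for ray_remote(_select_query)
    return f"ray task: {sql}"


register_func(_select_query, _select_query_ray)
print(dispatch_func(_select_query)("SELECT * FROM s3object"))  # ray task: SELECT * FROM s3object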

awswrangler/distributed/ray/modin/s3/_write_dataset.py

Lines changed: 4 additions & 4 deletions
@@ -3,9 +3,8 @@
 
 import modin.pandas as pd
 import ray
-from modin.pandas import DataFrame as ModinDataFrame
-from pandas import DataFrame as PandasDataFrame
 
+from awswrangler._distributed import engine
 from awswrangler.distributed.ray import ray_get, ray_remote
 from awswrangler.s3._write_concurrent import _WriteProxy
 from awswrangler.s3._write_dataset import _delete_objects, _get_bucketing_series, _retrieve_paths, _to_partitions

@@ -29,7 +28,7 @@ def _to_buckets_distributed( # pylint: disable=unused-argument
     paths: List[str] = []
 
     df_paths = df_groups.apply(
-        func.dispatch(ModinDataFrame),  # type: ignore
+        engine.dispatch_func(func),  # type: ignore
         path_root=path_root,
         filename_prefix=filename_prefix,
         boto3_session=None,

@@ -126,7 +125,8 @@ def _to_partitions_distributed( # pylint: disable=unused-argument
     if not bucketing_info:
         # If only partitioning (without bucketing), avoid expensive modin groupby
         # by partitioning and writing each block as an ordinary Pandas DataFrame
-        _to_partitions_func = _to_partitions.dispatch(PandasDataFrame)
+        _to_partitions_func = _to_partitions._source_func  # type: ignore
+        func = func._source_func  # type: ignore
 
         @ray_remote
         def write_partitions(df: pd.DataFrame) -> Tuple[List[str], Dict[str, List[str]]]:
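
Two things change in this file: the bucketing path asks the engine registry for the implementation of func instead of dispatching on the DataFrame type, and the partition-only path reaches through _source_func to call the plain, undecorated writer from inside a Ray task. A small sketch of that second idea, with hypothetical writer names standing in for the real functions:

from functools import wraps
from typing import Any, Callable, List


def _engine_wrap(func: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical stand-in for engine.dispatch_on_engine."""

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return func(*args, **kwargs)

    wrapper._source_func = func  # type: ignore[attr-defined]
    return wrapper


def _write_block(rows: List[int]) -> str:  # placeholder for the plain-pandas writer
    return f"wrote {len(rows)} rows"


_to_partitions = _engine_wrap(_write_block)

# In the partition-only branch, the engine indirection is bypassed so each block
# can be written with an ordinary function call from inside a remote task:
_to_partitions_func = _to_partitions._source_func  # type: ignore[attr-defined]
print(_to_partitions_func([1, 2, 3]))  # wrote 3 rows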

awswrangler/s3/_read_parquet.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
 from awswrangler._threading import _get_executor
 from awswrangler.catalog._get import _get_partitions
 from awswrangler.catalog._utils import _catalog_id
-from awswrangler.distributed.ray import RayLogger, ray_get, ray_remote
+from awswrangler.distributed.ray import RayLogger, ray_get
 from awswrangler.s3._fs import open_s3_object
 from awswrangler.s3._list import _path2list
 from awswrangler.s3._read import (

@@ -60,7 +60,7 @@ def _pyarrow_parquet_file_wrapper(
         raise
 
 
-@ray_remote
+@engine.dispatch_on_engine
 def _read_parquet_metadata_file(
     boto3_session: boto3.Session,
     path: str,
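
With @engine.dispatch_on_engine replacing @ray_remote, _read_parquet_metadata_file stays engine-agnostic: the module no longer imports ray_remote, and the Ray behaviour is attached later by register_ray(). A hypothetical end-to-end sketch of that flow, with a placeholder signature and return values:

from functools import wraps
from typing import Any, Callable, Dict

_registry: Dict[str, Callable[..., Any]] = {}  # hypothetical engine registry


def dispatch_on_engine(func: Callable[..., Any]) -> Callable[..., Any]:
    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return _registry.get(func.__name__, func)(*args, **kwargs)

    wrapper._source_func = func  # type: ignore[attr-defined]
    return wrapper


@dispatch_on_engine
def _read_parquet_metadata_file(path: str) -> str:  # placeholder signature
    return f"schema of {path} (read locally)"


print(_read_parquet_metadata_file("s3://bucket/key.parquet"))  # local until registered


def _read_parquet_metadata_file_ray(path: str) -> str:
    # Stands in for ray_remote(_read_parquet_metadata_file), which unwraps _source_func.
    return f"schema of {path} (read in a Ray task)"


_registry["_read_parquet_metadata_file"] = _read_parquet_metadata_file_ray
print(_read_parquet_metadata_file("s3://bucket/key.parquet"))  # now dispatched to the Ray path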
