Skip to content

Commit 5c3537b

Browse files
committed
Refactor RayPoolExecutor to return futures and use ray.get in S3 Select
1 parent 013cac5 commit 5c3537b

File tree

6 files changed

+32
-18
lines changed

6 files changed

+32
-18
lines changed

awswrangler/_threading.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,7 @@
1717

1818

1919
def _get_executor(use_threads: Union[bool, int]) -> Union["_ThreadPoolExecutor", "_RayPoolExecutor"]:
20-
executor = _RayPoolExecutor if config.distributed else _ThreadPoolExecutor
21-
return executor(use_threads) # type: ignore
20+
return _RayPoolExecutor() if config.distributed else _ThreadPoolExecutor(use_threads) # type: ignore
2221

2322

2423
class _ThreadPoolExecutor:

awswrangler/_utils.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,6 @@
2222

2323
if TYPE_CHECKING or config.distributed:
2424
import ray # pylint: disable=unused-import
25-
2625
from awswrangler.distributed._utils import _arrow_refs_to_df # pylint: disable=ungrouped-imports
2726

2827
_logger: logging.Logger = logging.getLogger(__name__)

awswrangler/distributed/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
11
"""Distributed Module."""
22

3-
from awswrangler.distributed._distributed import initialize_ray, ray_remote # noqa
3+
from awswrangler.distributed._distributed import initialize_ray, ray_get, ray_remote # noqa
44

55
__all__ = [
6-
"ray_remote",
76
"initialize_ray",
7+
"ray_get",
8+
"ray_remote",
89
]

awswrangler/distributed/_distributed.py

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import os
55
import sys
66
import warnings
7-
from typing import TYPE_CHECKING, Any, Callable, Optional
7+
from typing import TYPE_CHECKING, Any, Callable, List, Optional
88

99
import psutil
1010

@@ -14,6 +14,24 @@
1414
import ray # pylint: disable=import-error
1515

1616

17+
def ray_get(futures: List[Any]) -> List[Any]:
18+
"""
19+
Run ray.get on futures if distributed.
20+
21+
Parameters
22+
----------
23+
futures : List[Any]
24+
List of Ray futures
25+
26+
Returns
27+
-------
28+
List[Any]
29+
"""
30+
if config.distributed:
31+
return ray.get(futures)
32+
return futures
33+
34+
1735
def ray_remote(function: Callable[..., Any]) -> Callable[..., Any]:
1836
"""
1937
Decorate callable to wrap within ray.remote.

awswrangler/distributed/_pool.py

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2,23 +2,19 @@
22

33
import itertools
44
import logging
5-
from typing import Any, Callable, List, Optional, Union
5+
from typing import Any, Callable, List
66

77
import boto3
8-
from ray.util.multiprocessing import Pool
98

109
_logger: logging.Logger = logging.getLogger(__name__)
1110

1211

1312
class _RayPoolExecutor:
14-
def __init__(self, processes: Optional[Union[bool, int]] = None):
15-
self._exec: Pool = Pool(processes=None if isinstance(processes, bool) else processes)
13+
def __init__(self) -> None:
14+
pass
1615

1716
def map(self, func: Callable[..., List[str]], _: boto3.Session, *args: Any) -> List[Any]:
18-
"""Map function and its args to Ray pool."""
19-
futures = []
17+
"""Map func and return ray futures."""
2018
_logger.debug("Ray map: %s", func)
21-
# Discard boto3.Session object & call the fn asynchronously
22-
for arg in zip(itertools.repeat(None), *args):
23-
futures.append(self._exec.apply_async(func, arg))
24-
return [f.get() for f in futures]
19+
# Discard boto3.Session object & return futures
20+
return list(func(*arg) for arg in zip(itertools.repeat(None), *args))

awswrangler/s3/_select.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313

1414
from awswrangler import _data_types, _utils, exceptions
1515
from awswrangler._threading import _get_executor
16-
from awswrangler.distributed import ray_remote
16+
from awswrangler.distributed import ray_get, ray_remote
1717
from awswrangler.s3._describe import size_objects
1818
from awswrangler.s3._list import _path2list
1919
from awswrangler.s3._read import _get_path_ignore_suffix
@@ -71,6 +71,7 @@ def _select_object_content(
7171
return _utils.list_to_arrow_table(mapping=payload_records)
7272

7373

74+
@ray_remote
7475
def _select_query(
7576
path: str,
7677
executor: Any,
@@ -272,5 +273,5 @@ def select_query(
272273

273274
arrow_kwargs = _data_types.pyarrow2pandas_defaults(use_threads=use_threads, kwargs=arrow_additional_kwargs)
274275
executor = _get_executor(use_threads=use_threads)
275-
tables = _flatten_list([_select_query(path=path, executor=executor, **select_kwargs) for path in paths])
276+
tables = _flatten_list(ray_get([_select_query(path=path, executor=executor, **select_kwargs) for path in paths]))
276277
return _utils.table_refs_to_df(tables=tables, kwargs=arrow_kwargs)

0 commit comments

Comments
 (0)