Skip to content

Commit 7e5b8e3

Browse files
committed
Major refactoring:
- Move distributed methods out of non-distributed modules
- Refactor dispatching
- Refactor structure of distributed modules
- Add classes for execution engine and memory format
1 parent d218d79 commit 7e5b8e3

39 files changed

+929
-722
lines changed

awswrangler/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,10 @@
3131
timestream,
3232
)
3333
from awswrangler.__metadata__ import __description__, __license__, __title__, __version__ # noqa
34-
from awswrangler._config import ExecutionEngine, config # noqa
35-
from awswrangler.distributed import initialize_ray
36-
37-
if config.execution_engine == ExecutionEngine.RAY.value:
38-
initialize_ray()
34+
from awswrangler._config import config # noqa
35+
from awswrangler._distributed import engine, memory_format # noqa
3936

37+
engine.initialize()
4038

4139
__all__ = [
4240
"athena",
@@ -60,6 +58,8 @@
6058
"secretsmanager",
6159
"sqlserver",
6260
"config",
61+
"engine",
62+
"memory_format",
6363
"timestream",
6464
"__description__",
6565
"__license__",

awswrangler/_config.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
"""Configuration file for AWS SDK for pandas."""
22

3-
import importlib.util
43
import inspect
54
import logging
65
import os
7-
from enum import Enum
86
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type, Union, cast
97

108
import botocore.config
@@ -18,20 +16,6 @@
1816
_ConfigValueType = Union[str, bool, int, botocore.config.Config, None]
1917

2018

21-
class ExecutionEngine(Enum):
22-
"""Execution engine enum."""
23-
24-
RAY = "ray"
25-
PYTHON = "python"
26-
27-
28-
class MemoryFormat(Enum):
29-
"""Memory format enum."""
30-
31-
MODIN = "modin"
32-
PANDAS = "pandas"
33-
34-
3519
class _ConfigArg(NamedTuple):
3620
dtype: Type[Union[str, bool, int, botocore.config.Config]]
3721
nullable: bool
@@ -69,12 +53,6 @@ class _ConfigArg(NamedTuple):
6953
"botocore_config": _ConfigArg(dtype=botocore.config.Config, nullable=True),
7054
"verify": _ConfigArg(dtype=str, nullable=True, loaded=True),
7155
# Distributed
72-
"execution_engine": _ConfigArg(
73-
dtype=str, nullable=False, loaded=True, default="ray" if importlib.util.find_spec("ray") else "python"
74-
),
75-
"memory_format": _ConfigArg(
76-
dtype=str, nullable=False, loaded=True, default="modin" if importlib.util.find_spec("modin") else "pandas"
77-
),
7856
"address": _ConfigArg(dtype=str, nullable=True),
7957
"redis_password": _ConfigArg(dtype=str, nullable=True),
8058
"ignore_reinit_error": _ConfigArg(dtype=bool, nullable=True),
@@ -440,24 +418,6 @@ def verify(self) -> Optional[str]:
440418
def verify(self, value: Optional[str]) -> None:
441419
self._set_config_value(key="verify", value=value)
442420

443-
@property
444-
def execution_engine(self) -> str:
445-
"""Property execution_engine."""
446-
return cast(str, self["execution_engine"])
447-
448-
@execution_engine.setter
449-
def execution_engine(self, value: str) -> None:
450-
self._set_config_value(key="execution_engine", value=value)
451-
452-
@property
453-
def memory_format(self) -> str:
454-
"""Property memory_format."""
455-
return cast(str, self["memory_format"])
456-
457-
@memory_format.setter
458-
def memory_format(self, value: str) -> None:
459-
self._set_config_value(key="memory_format", value=value)
460-
461421
@property
462422
def address(self) -> Optional[str]:
463423
"""Property address."""

awswrangler/_dispatch.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Dispatch engine and memory format configuration (PRIVATE)."""
2+
from typing import Any, Callable, Dict, Optional
3+
4+
from awswrangler._distributed import engine
5+
6+
7+
def dispatch_on_engine(func: Callable[..., Any]) -> Callable[..., Any]:
    """Dispatch on engine function decorator.

    Transforms a function into a dispatch function,
    which can have different behaviors based on the value of the distribution engine.

    The decorated function is the default implementation: it is the one called
    whenever nothing has been registered for the engine name returned by
    ``engine.get()`` at call time.

    Parameters
    ----------
    func : Callable[..., Any]
        Default implementation, used when no engine-specific one is registered.

    Returns
    -------
    Callable[..., Any]
        A wrapper carrying the metadata of ``func`` (name, docstring, ``__wrapped__``)
        and exposing three extra attributes:

        - ``register(value, func=None)``: register ``func`` for engine ``value``;
          with ``func`` omitted it returns a decorator doing the same.
        - ``dispatch(value)``: return the implementation selected for engine ``value``.
        - ``registry``: the underlying ``{engine name: implementation}`` mapping.
    """
    # Local import so this block does not depend on the module's import section.
    from functools import wraps  # pylint: disable=import-outside-toplevel

    registry: Dict[str, Callable[..., Any]] = {}

    def dispatch(value: str) -> Callable[..., Any]:
        # Unknown engine names fall back to the decorated (default) function.
        return registry.get(value, func)

    # NOTE: the ``func`` parameter below intentionally keeps its original name
    # (it shadows the outer ``func`` only inside ``register``).
    def register(value: str, func: Optional[Callable[..., Any]] = None) -> Callable[..., Any]:
        if func is None:
            # Decorator form: ``@wrapper.register("ray")``.
            return lambda f: register(value, f)
        registry[value] = func
        return func

    @wraps(func)  # keep __name__/__doc__/__wrapped__ of the decorated function
    def wrapper(*args: Any, **kw: Any) -> Any:
        # The engine is resolved on every call, so a later change is honoured.
        return dispatch(engine.get())(*args, **kw)

    wrapper.register = register  # type: ignore
    wrapper.dispatch = dispatch  # type: ignore
    wrapper.registry = registry  # type: ignore

    return wrapper

awswrangler/_distributed.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""Distributed engine and memory format configuration."""
2+
3+
# pylint: disable=import-outside-toplevel
4+
5+
import importlib.util
6+
from enum import Enum
7+
from typing import Optional
8+
9+
10+
class EngineEnum(Enum):
    """Execution engine names.

    Each member's value is the plain string used to identify the engine.
    """

    RAY = "ray"
    PYTHON = "python"
15+
16+
17+
class MemoryFormatEnum(Enum):
    """Memory format names.

    Each member's value is the plain string used to identify the format.
    """

    MODIN = "modin"
    PANDAS = "pandas"
22+
23+
24+
class Engine:
    """Execution engine configuration class.

    The configured engine name is stored on the class (``_name``) and every
    method is a classmethod, so the setting is shared by the class and all of
    its instances.
    """

    # Name of the configured engine; ``None`` until ``set``/``initialize`` runs.
    _name: Optional[str] = None

    @classmethod
    def get_installed(cls) -> str:
        """Get the installed distribution engine.

        The result is ``"ray"`` when ``importlib`` can locate a ``ray`` module
        spec, and ``"python"`` otherwise.

        Returns
        -------
        str
            The distribution engine installed.
        """
        ray_spec = importlib.util.find_spec("ray")
        selected = EngineEnum.RAY if ray_spec else EngineEnum.PYTHON
        return selected.value

    @classmethod
    def get(cls) -> str:
        """Get the configured distribution engine.

        When no engine has been configured yet, the installed engine is
        returned instead.

        Returns
        -------
        str
            The distribution engine configured.
        """
        return cls._name or cls.get_installed()

    @classmethod
    def set(cls, name: str) -> None:
        """Set the distribution engine.

        Parameters
        ----------
        name : str
            Engine name to store; it is stored as given.
        """
        cls._name = name

    @classmethod
    def register(cls, name: Optional[str] = None) -> None:
        """Register the distribution engine dispatch methods.

        Parameters
        ----------
        name : Optional[str]
            Engine to register for. When omitted, the installed engine is used.
            Only the ``"ray"`` engine triggers a registration.
        """
        engine_name = name or cls.get_installed()
        if engine_name != EngineEnum.RAY.value:
            return

        # Imported lazily so the Ray sub-package is only loaded when needed.
        from awswrangler.distributed.ray._register import register_ray

        register_ray()

    @classmethod
    def initialize(cls, name: Optional[str] = None) -> None:
        """Initialize the distribution engine.

        Parameters
        ----------
        name : Optional[str]
            Engine to initialize. When omitted, the installed engine is used.
            Any name other than ``"ray"`` configures the ``"python"`` engine.
        """
        engine_name = name or cls.get_installed()
        if engine_name != EngineEnum.RAY.value:
            cls._name = EngineEnum.PYTHON.value
            return

        cls._name = EngineEnum.RAY.value

        # Imported lazily so the Ray sub-package is only loaded when needed.
        from awswrangler.distributed.ray import initialize_ray

        initialize_ray()
        cls.register(cls._name)
86+
87+
88+
class MemoryFormat:
    """Memory format configuration class.

    The configured format name is stored on the class (``_name``) and every
    method is a classmethod, so the setting is shared by the class and all of
    its instances.
    """

    # Name of the configured memory format; ``None`` until ``set`` is called.
    _name: Optional[str] = None

    @classmethod
    def get_installed(cls) -> str:
        """Get the installed memory format.

        The result is ``"modin"`` when ``importlib`` can locate a ``modin``
        module spec, and ``"pandas"`` otherwise.

        Returns
        -------
        str
            The memory format installed.
        """
        modin_spec = importlib.util.find_spec("modin")
        selected = MemoryFormatEnum.MODIN if modin_spec else MemoryFormatEnum.PANDAS
        return selected.value

    @classmethod
    def get(cls) -> str:
        """Get the configured memory format.

        When no memory format has been configured yet, the installed memory
        format is returned instead.

        Returns
        -------
        str
            The memory format configured.
        """
        return cls._name or cls.get_installed()

    @classmethod
    def set(cls, name: str) -> None:
        """Set the memory format.

        Parameters
        ----------
        name : str
            Memory format name to store; it is stored as given.
        """
        cls._name = name
125+
126+
127+
engine: Engine = Engine()
128+
memory_format: MemoryFormat = MemoryFormat()

awswrangler/_threading.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,16 @@
33
import concurrent.futures
44
import itertools
55
import logging
6-
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
6+
from typing import Any, Callable, List, Optional, Union
77

88
import boto3
99

1010
from awswrangler import _utils
11-
from awswrangler._config import ExecutionEngine, config
12-
13-
if config.execution_engine == ExecutionEngine.RAY.value or TYPE_CHECKING:
14-
from awswrangler.distributed.ray._pool import _RayPoolExecutor
11+
from awswrangler._distributed import EngineEnum, engine
1512

1613
_logger: logging.Logger = logging.getLogger(__name__)
1714

1815

19-
def _get_executor(use_threads: Union[bool, int]) -> Union["_ThreadPoolExecutor", "_RayPoolExecutor"]:
20-
return (
21-
_RayPoolExecutor()
22-
if config.execution_engine == ExecutionEngine.RAY.value
23-
else _ThreadPoolExecutor(use_threads) # type: ignore
24-
)
25-
26-
2716
class _ThreadPoolExecutor:
2817
def __init__(self, use_threads: Union[bool, int]):
2918
super().__init__()
@@ -42,3 +31,11 @@ def map(self, func: Callable[..., Any], boto3_session: boto3.Session, *iterables
4231
return list(self._exec.map(func, *args))
4332
# Single-threaded
4433
return list(map(func, *(itertools.repeat(boto3_session), *iterables))) # type: ignore
34+
35+
36+
def _get_executor(use_threads: Union[bool, int]) -> _ThreadPoolExecutor:
    """Return the executor matching the configured engine.

    A ``_RayPoolExecutor`` is returned when the configured engine is ``"ray"``
    (``use_threads`` is not passed to it); otherwise a ``_ThreadPoolExecutor``
    built from ``use_threads`` is returned.
    """
    if engine.get() != EngineEnum.RAY.value:
        return _ThreadPoolExecutor(use_threads)

    # Imported lazily: the Ray pool module is only loaded on the Ray engine.
    from awswrangler.distributed.ray._pool import _RayPoolExecutor  # pylint: disable=import-outside-toplevel

    return _RayPoolExecutor()  # type: ignore

awswrangler/_utils.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import random
99
import time
1010
from concurrent.futures import FIRST_COMPLETED, Future, wait
11-
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, Union, cast
11+
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, Union, cast
1212

1313
import boto3
1414
import botocore.config
@@ -19,13 +19,8 @@
1919
from awswrangler import _config, exceptions
2020
from awswrangler.__metadata__ import __version__
2121
from awswrangler._arrow import _table_to_df
22-
from awswrangler._config import ExecutionEngine, MemoryFormat, apply_configs, config
23-
24-
if config.execution_engine == ExecutionEngine.RAY.value or TYPE_CHECKING:
25-
import ray # pylint: disable=unused-import
26-
27-
if config.memory_format == MemoryFormat.MODIN.value:
28-
from awswrangler.distributed.ray._utils import _arrow_refs_to_df # pylint: disable=ungrouped-imports
22+
from awswrangler._config import apply_configs
23+
from awswrangler._dispatch import dispatch_on_engine
2924

3025
_logger: logging.Logger = logging.getLogger(__name__)
3126

@@ -413,13 +408,10 @@ def check_schema_changes(columns_types: Dict[str, str], table_input: Optional[Di
413408
)
414409

415410

416-
def table_refs_to_df(
417-
tables: Union[List[pa.Table], List["ray.ObjectRef"]], kwargs: Dict[str, Any] # type: ignore
418-
) -> pd.DataFrame:
411+
@dispatch_on_engine
412+
def table_refs_to_df(tables: List[pa.Table], kwargs: Dict[str, Any]) -> pd.DataFrame: # type: ignore
419413
"""Build Pandas dataframe from list of PyArrow tables."""
420-
if isinstance(tables[0], pa.Table):
421-
return _table_to_df(pa.concat_tables(tables, promote=True), kwargs=kwargs)
422-
return _arrow_refs_to_df(arrow_refs=tables, kwargs=kwargs) # type: ignore
414+
return _table_to_df(pa.concat_tables(tables, promote=True), kwargs=kwargs)
423415

424416

425417
def list_to_arrow_table(

awswrangler/distributed/__init__.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1 @@
11
"""Distributed Module."""
2-
3-
from awswrangler.distributed._distributed import ( # noqa
4-
RayLogger,
5-
initialize_ray,
6-
modin_repartition,
7-
ray_get,
8-
ray_remote,
9-
)
10-
11-
__all__ = [
12-
"RayLogger",
13-
"initialize_ray",
14-
"modin_repartition",
15-
"ray_get",
16-
"ray_remote",
17-
]

awswrangler/distributed/ray/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,10 @@
11
"""Ray Module."""
2+
3+
from awswrangler.distributed.ray._core import RayLogger, initialize_ray, ray_get, ray_remote # noqa
4+
5+
__all__ = [
6+
"RayLogger",
7+
"initialize_ray",
8+
"ray_get",
9+
"ray_remote",
10+
]

0 commit comments

Comments
 (0)