Skip to content

Commit fadfdf7

Browse files
committed
PR feedback and distributed module test
1 parent 9fd4093 commit fadfdf7

17 files changed

+104
-52
lines changed

awswrangler/_distributed.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,20 @@
44

55
import importlib.util
66
from collections import defaultdict
7-
from enum import Enum
7+
from enum import Enum, unique
88
from functools import wraps
99
from typing import Any, Callable, Dict, Optional
1010

1111

12+
@unique
1213
class EngineEnum(Enum):
1314
"""Execution engine enum."""
1415

1516
RAY = "ray"
1617
PYTHON = "python"
1718

1819

20+
@unique
1921
class MemoryFormatEnum(Enum):
2022
"""Memory format enum."""
2123

@@ -26,26 +28,26 @@ class MemoryFormatEnum(Enum):
2628
class Engine:
2729
"""Execution engine configuration class."""
2830

29-
_name: Optional[str] = None
31+
_enum: Optional[Enum] = None
3032
_registry: Dict[str, Dict[str, Callable[..., Any]]] = defaultdict(dict)
3133

3234
@classmethod
33-
def get_installed(cls) -> str:
35+
def get_installed(cls) -> Enum:
3436
"""Get the installed distribution engine.
3537
3638
This is the engine that can be imported.
3739
3840
Returns
3941
-------
40-
str
42+
EngineEnum
4143
The distribution engine installed.
4244
"""
4345
if importlib.util.find_spec("ray"):
44-
return EngineEnum.RAY.value
45-
return EngineEnum.PYTHON.value
46+
return EngineEnum.RAY
47+
return EngineEnum.PYTHON
4648

4749
@classmethod
48-
def get(cls) -> str:
50+
def get(cls) -> Enum:
4951
"""Get the configured distribution engine.
5052
5153
This is the engine currently configured. If None, the installed engine is returned.
@@ -55,25 +57,25 @@ def get(cls) -> str:
5557
EngineEnum
5658
The distribution engine configured.
5759
"""
58-
return cls._name if cls._name else cls.get_installed()
60+
return cls._enum if cls._enum else cls.get_installed()
5961

6062
@classmethod
6163
def set(cls, name: str) -> None:
6264
"""Set the distribution engine."""
63-
cls._name = name
65+
cls._enum = EngineEnum._member_map_[name.upper()] # pylint: disable=protected-access,no-member
6466

6567
@classmethod
6668
def dispatch_func(cls, source_func: Callable[..., Any], value: Optional[Any] = None) -> Callable[..., Any]:
6769
"""Dispatch a func based on value or the distribution engine and the source function."""
6870
try:
69-
return cls._registry[value or cls.get()][source_func.__name__]
71+
return cls._registry[value or cls.get().value][source_func.__name__]
7072
except KeyError:
71-
return source_func
73+
return getattr(source_func, "_source_func", source_func)
7274

7375
@classmethod
7476
def register_func(cls, source_func: Callable[..., Any], destination_func: Callable[..., Any]) -> Callable[..., Any]:
7577
"""Register a func based on the distribution engine and source function."""
76-
cls._registry[cls.get()][source_func.__name__] = destination_func
78+
cls._registry[cls.get().value][source_func.__name__] = destination_func
7779
return destination_func
7880

7981
@classmethod
@@ -85,13 +87,13 @@ def wrapper(*args: Any, **kw: Dict[str, Any]) -> Any:
8587
return cls.dispatch_func(func)(*args, **kw)
8688

8789
# Save the original function
88-
wrapper._source_func = func # type: ignore # pylint: pylint: disable=protected-access
90+
wrapper._source_func = func # type: ignore # pylint: disable=protected-access
8991
return wrapper
9092

9193
@classmethod
9294
def register(cls, name: Optional[str] = None) -> None:
9395
"""Register the distribution engine dispatch methods."""
94-
engine_name = cls.get_installed() if not name else name
96+
engine_name = cls.get_installed().value if not name else name
9597
cls.set(engine_name)
9698
cls._registry.clear()
9799

@@ -103,7 +105,7 @@ def register(cls, name: Optional[str] = None) -> None:
103105
@classmethod
104106
def initialize(cls, name: Optional[str] = None) -> None:
105107
"""Initialize the distribution engine."""
106-
engine_name = cls.get_installed() if not name else name
108+
engine_name = cls.get_installed().value if not name else name
107109
if engine_name == EngineEnum.RAY.value:
108110
from awswrangler.distributed.ray import initialize_ray
109111

@@ -114,40 +116,40 @@ def initialize(cls, name: Optional[str] = None) -> None:
114116
class MemoryFormat:
115117
"""Memory format configuration class."""
116118

117-
_name: Optional[str] = None
119+
_enum: Optional[Enum] = None
118120

119121
@classmethod
120-
def get_installed(cls) -> str:
122+
def get_installed(cls) -> Enum:
121123
"""Get the installed memory format.
122124
123125
This is the format that can be imported.
124126
125127
Returns
126128
-------
127-
str
129+
Enum
128130
The memory format installed.
129131
"""
130132
if importlib.util.find_spec("modin"):
131-
return MemoryFormatEnum.MODIN.value
132-
return MemoryFormatEnum.PANDAS.value
133+
return MemoryFormatEnum.MODIN
134+
return MemoryFormatEnum.PANDAS
133135

134136
@classmethod
135-
def get(cls) -> str:
137+
def get(cls) -> Enum:
136138
"""Get the configured memory format.
137139
138140
This is the memory format currently configured. If None, the installed memory format is returned.
139141
140142
Returns
141143
-------
142-
str
144+
Enum
143145
The memory format configured.
144146
"""
145-
return cls._name if cls._name else cls.get_installed()
147+
return cls._enum if cls._enum else cls.get_installed()
146148

147149
@classmethod
148150
def set(cls, name: str) -> None:
149151
"""Set the memory format."""
150-
cls._name = name
152+
cls._enum = MemoryFormatEnum._member_map_[name.upper()] # pylint: disable=protected-access,no-member
151153

152154

153155
engine: Engine = Engine()

awswrangler/_threading.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def map(self, func: Callable[..., Any], boto3_session: boto3.Session, *iterables
3434

3535

3636
def _get_executor(use_threads: Union[bool, int]) -> _ThreadPoolExecutor:
37-
if engine.get() == EngineEnum.RAY.value:
37+
if engine.get() == EngineEnum.RAY:
3838
from awswrangler.distributed.ray._pool import _RayPoolExecutor # pylint: disable=import-outside-toplevel
3939

4040
return _RayPoolExecutor() # type: ignore

awswrangler/distributed/ray/_core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from awswrangler._config import apply_configs
88
from awswrangler._distributed import EngineEnum, engine
99

10-
if engine.get() == EngineEnum.RAY.value or TYPE_CHECKING:
10+
if engine.get() == EngineEnum.RAY or TYPE_CHECKING:
1111
import ray
1212

1313
_logger: logging.Logger = logging.getLogger(__name__)
@@ -26,7 +26,7 @@ def __init__(
2626

2727
def get_logger(self, name: Union[str, Any] = None) -> Union[logging.Logger, Any]:
2828
"""Return logger object."""
29-
return logging.getLogger(name) if engine.get() == EngineEnum.RAY.value else None
29+
return logging.getLogger(name) if engine.get() == EngineEnum.RAY else None
3030

3131

3232
def ray_get(futures: List[Any]) -> List[Any]:
@@ -42,7 +42,7 @@ def ray_get(futures: List[Any]) -> List[Any]:
4242
-------
4343
List[Any]
4444
"""
45-
if engine.get() == EngineEnum.RAY.value:
45+
if engine.get() == EngineEnum.RAY:
4646
return ray.get(futures)
4747
return futures
4848

awswrangler/distributed/ray/_register.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def register_ray() -> None:
2626
engine.register_func(_select_object_content, ray_remote(_select_object_content))
2727
engine.register_func(_wait_object_batch, ray_remote(_wait_object_batch))
2828

29-
if memory_format.get() == MemoryFormatEnum.MODIN.value:
29+
if memory_format.get() == MemoryFormatEnum.MODIN:
3030
from awswrangler.distributed.ray.modin._core import modin_repartition
3131
from awswrangler.distributed.ray.modin._utils import _arrow_refs_to_df
3232
from awswrangler.distributed.ray.modin.s3._read_parquet import _read_parquet_distributed

awswrangler/distributed/ray/modin/_core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def modin_repartition(function: Callable[..., Any]) -> Callable[..., Any]:
2424
-------
2525
Callable[..., Any]
2626
"""
27+
# Access the source function if it exists
2728
function = getattr(function, "_source_func", function)
2829

2930
@wraps(function)

awswrangler/distributed/ray/modin/s3/_write_dataset.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import boto3
55
import modin.pandas as pd
66
import ray
7+
from pandas import DataFrame as PandasDataFrame
78

89
from awswrangler._distributed import engine
910
from awswrangler.distributed.ray import ray_get, ray_remote
@@ -123,8 +124,8 @@ def _to_partitions_distributed( # pylint: disable=unused-argument
123124
if not bucketing_info:
124125
# If only partitioning (without bucketing), avoid expensive modin groupby
125126
# by partitioning and writing each block as an ordinary Pandas DataFrame
126-
_to_partitions_func = getattr(_to_partitions, "_source_func", _to_partitions)
127-
func = getattr(func, "_source_func", func)
127+
_to_partitions_func = engine.dispatch_func(_to_partitions, PandasDataFrame)
128+
func = engine.dispatch_func(func, PandasDataFrame)
128129

129130
@ray_remote
130131
def write_partitions(df: pd.DataFrame) -> Tuple[List[str], Dict[str, List[str]]]:

awswrangler/oracle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ def detect_oracle_decimal_datatype(cursor: Any) -> Dict[str, pa.DataType]:
434434
if isinstance(cursor, oracledb.Cursor):
435435
# Oracle stores DECIMAL as the NUMBER type
436436
for row in cursor.description:
437-
if row[1] == oracledb.DB_TYPE_NUMBER and row[5] > 0:
437+
if row[1] == oracledb.DB_TYPE_NUMBER and row[5] > 0: # pylint: disable=no-member
438438
dtype[row[0]] = pa.decimal128(row[4], row[5])
439439

440440
_logger.debug("decimal dtypes: %s", dtype)

awswrangler/s3/_write.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Amazon CSV S3 Write Module (PRIVATE)."""
22

33
import logging
4+
from enum import Enum
45
from typing import Any, Dict, List, Optional, Tuple
56

67
import pandas as pd
@@ -56,14 +57,14 @@ def _validate_args(
5657
description: Optional[str],
5758
parameters: Optional[Dict[str, str]],
5859
columns_comments: Optional[Dict[str, str]],
59-
execution_engine: str,
60+
execution_engine: Enum,
6061
) -> None:
6162
if df.empty is True:
6263
raise exceptions.EmptyDataFrame("DataFrame cannot be empty.")
6364
if dataset is False:
6465
if path is None:
6566
raise exceptions.InvalidArgumentValue("If dataset is False, the `path` argument must be passed.")
66-
if execution_engine == EngineEnum.PYTHON.value and path.endswith("/"):
67+
if execution_engine == EngineEnum.PYTHON and path.endswith("/"):
6768
raise exceptions.InvalidArgumentValue(
6869
"If <dataset=False>, the argument <path> should be a key, not a prefix."
6970
)

awswrangler/s3/_write_parquet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def _to_parquet(
166166
use_threads: Union[bool, int],
167167
path: Optional[str] = None,
168168
path_root: Optional[str] = None,
169-
filename_prefix: Optional[str] = uuid.uuid4().hex,
169+
filename_prefix: Optional[str] = None,
170170
max_rows_by_file: Optional[int] = 0,
171171
) -> List[str]:
172172
file_path = _get_file_path(

awswrangler/s3/_write_text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _to_text( # pylint: disable=unused-argument
3838
s3_additional_kwargs: Optional[Dict[str, str]],
3939
path: Optional[str] = None,
4040
path_root: Optional[str] = None,
41-
filename_prefix: Optional[str] = uuid.uuid4().hex,
41+
filename_prefix: Optional[str] = None,
4242
bucketing: bool = False,
4343
**pandas_kwargs: Any,
4444
) -> List[str]:

tests/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from awswrangler._distributed import EngineEnum, MemoryFormatEnum
1313
from awswrangler._utils import try_it
1414

15-
if wr.engine.get() == EngineEnum.RAY.value and wr.memory_format.get() == MemoryFormatEnum.MODIN.value:
15+
if wr.engine.get() == EngineEnum.RAY and wr.memory_format.get() == MemoryFormatEnum.MODIN:
1616
import modin.pandas as pd
1717
else:
1818
import pandas as pd

tests/unit/test_distributed.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import importlib.util
2+
import logging
3+
from enum import Enum
4+
5+
import pytest
6+
7+
import awswrangler as wr
8+
from awswrangler._distributed import EngineEnum, MemoryFormatEnum
9+
from awswrangler.s3._write_parquet import _to_parquet
10+
11+
logging.getLogger("awswrangler").setLevel(logging.DEBUG)
12+
13+
pytestmark = pytest.mark.distributed
14+
15+
16+
@pytest.mark.parametrize(
17+
"engine_enum",
18+
[
19+
pytest.param(
20+
EngineEnum.RAY,
21+
marks=pytest.mark.skip("ray not available") if not importlib.util.find_spec("ray") else [],
22+
),
23+
],
24+
)
25+
def test_engine(engine_enum: Enum) -> None:
26+
assert wr.engine.get_installed() == engine_enum
27+
assert wr.engine.get() == engine_enum
28+
assert wr.engine._registry
29+
assert wr.engine.dispatch_func(_to_parquet).__name__.endswith("distributed")
30+
assert not wr.engine.dispatch_func(_to_parquet, "python").__name__.endswith("distributed")
31+
32+
wr.engine.register("python")
33+
assert wr.engine.get_installed() == engine_enum
34+
assert wr.engine.get() == EngineEnum.PYTHON
35+
assert not wr.engine._registry
36+
assert not wr.engine.dispatch_func(_to_parquet).__name__.endswith("distributed")
37+
38+
39+
@pytest.mark.parametrize(
40+
"memory_format_enum",
41+
[
42+
pytest.param(
43+
MemoryFormatEnum.MODIN,
44+
marks=pytest.mark.skip("modin not available") if not importlib.util.find_spec("modin") else [],
45+
),
46+
],
47+
)
48+
def test_memory_format(memory_format_enum: Enum) -> None:
49+
assert wr.memory_format.get_installed() == memory_format_enum
50+
assert wr.memory_format.get() == memory_format_enum
51+
52+
wr.memory_format.set("pandas")
53+
assert wr.memory_format.get_installed() == memory_format_enum
54+
assert wr.memory_format.get() == MemoryFormatEnum.PANDAS

tests/unit/test_lakeformation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from .._utils import ensure_data_types, ensure_data_types_csv, get_df, get_df_csv
1111

12-
if wr.engine.get() == EngineEnum.RAY.value and wr.memory_format.get() == MemoryFormatEnum.MODIN.value:
12+
if wr.engine.get() == EngineEnum.RAY and wr.memory_format.get() == MemoryFormatEnum.MODIN:
1313
import modin.pandas as pd
1414
else:
1515
import pandas as pd

tests/unit/test_s3_select.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import awswrangler as wr
66
from awswrangler._distributed import EngineEnum, MemoryFormatEnum
77

8-
if wr.engine.get() == EngineEnum.RAY.value and wr.memory_format.get() == MemoryFormatEnum.MODIN.value:
8+
if wr.engine.get() == EngineEnum.RAY and wr.memory_format.get() == MemoryFormatEnum.MODIN:
99
import modin.pandas as pd
1010
else:
1111
import pandas as pd

tests/unit/test_s3_text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import awswrangler as wr
77
from awswrangler._distributed import EngineEnum, MemoryFormatEnum
88

9-
if wr.engine.get() == EngineEnum.RAY.value and wr.memory_format.get() == MemoryFormatEnum.MODIN.value:
9+
if wr.engine.get() == EngineEnum.RAY and wr.memory_format.get() == MemoryFormatEnum.MODIN:
1010
import modin.pandas as pd
1111
else:
1212
import pandas as pd

tests/unit/test_s3_text_compressed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from .._utils import get_df_csv
1515

16-
if wr.engine.get() == EngineEnum.RAY.value and wr.memory_format.get() == MemoryFormatEnum.MODIN.value:
16+
if wr.engine.get() == EngineEnum.RAY and wr.memory_format.get() == MemoryFormatEnum.MODIN:
1717
import modin.pandas as pd
1818
else:
1919
import pandas as pd

0 commit comments

Comments
 (0)