# (perf): Distribute timestream write with executor #1715
Changes from 3 commits
**File 1: Ray/Modin distributed utilities**

```diff
@@ -8,6 +8,7 @@
 from modin.distributed.dataframe.pandas import from_partitions
 from ray.data._internal.arrow_block import ArrowBlockAccessor, ArrowRow
 from ray.data._internal.remote_fn import cached_remote_fn
+from ray.types import ObjectRef

 from awswrangler import exceptions
 from awswrangler._arrow import _table_to_df
```
```diff
@@ -43,6 +44,11 @@ def _to_modin(
     )


+def _split_modin_frame(df: modin_pd.DataFrame, splits: int) -> List[ObjectRef[Any]]:  # pylint: disable=unused-argument
+    object_refs: List[ObjectRef[Any]] = ray.data.from_modin(df).get_internal_block_refs()
+    return object_refs
+
+
 def _arrow_refs_to_df(arrow_refs: List[Callable[..., Any]], kwargs: Optional[Dict[str, Any]]) -> modin_pd.DataFrame:
     return _to_modin(dataset=ray.data.from_arrow_refs(arrow_refs), to_pandas_kwargs=kwargs)
```

> **Review comment on `_split_modin_frame`:** I am not 100% convinced that this is the best way to split a modin dataframe
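The `splits` argument is unused (hence the pylint disable): `get_internal_block_refs()` hands back one `ObjectRef` per internal Arrow block, so the chunk count is decided by Modin's partitioning rather than by the caller. A minimal sketch of that behavior, assuming a local Ray runtime and Ray/Modin versions contemporary with this PR:

```python
# Sketch: the number of block refs comes from Modin's partitioning,
# not from any requested split count. Assumes `ray` and `modin` are installed.
import ray
import modin.pandas as modin_pd

ray.init(ignore_reinit_error=True)

df = modin_pd.DataFrame({"a": range(10_000)})
block_refs = ray.data.from_modin(df).get_internal_block_refs()

# One ObjectRef per Arrow block; a caller cannot force a specific count here.
print(len(block_refs))
```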
**File 2: the Amazon Timestream module**

```diff
@@ -1,6 +1,5 @@
 """Amazon Timestream Module."""

-import concurrent.futures
 import itertools
 import logging
 from datetime import datetime
```
```diff
@@ -11,10 +10,17 @@
 from botocore.config import Config

 from awswrangler import _data_types, _utils
+from awswrangler._distributed import engine
+from awswrangler._threading import _get_executor
+from awswrangler.distributed.ray import ray_get

 _logger: logging.Logger = logging.getLogger(__name__)


+def _flatten_list(elements: List[List[Any]]) -> List[Any]:
+    return [item for sublist in elements for item in sublist]
+
+
 def _df2list(df: pd.DataFrame) -> List[List[Any]]:
     """Extract Parameters."""
     parameters: List[List[Any]] = df.values.tolist()
```
```diff
@@ -27,17 +33,17 @@ def _df2list(df: pd.DataFrame) -> List[List[Any]]:
     return parameters


+@engine.dispatch_on_engine
 def _write_batch(
+    boto3_session: Optional[boto3.Session],
     database: str,
     table: str,
     cols_names: List[str],
     measure_cols_names: List[str],
     measure_types: List[str],
     version: int,
     batch: List[Any],
-    boto3_primitives: _utils.Boto3PrimitivesType,
 ) -> List[Dict[str, str]]:
-    boto3_session: boto3.Session = _utils.boto3_from_primitives(primitives=boto3_primitives)
     client: boto3.client = _utils.client(
         service_name="timestream-write",
         session=boto3_session,
```
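`@engine.dispatch_on_engine` is the hook that lets the same function name resolve to an engine-specific implementation (for example a Ray remote variant) at call time. The real decorator lives in `awswrangler._distributed` and is not shown in this diff; the following is only a rough sketch of the general pattern, with all names hypothetical:

```python
# Hypothetical sketch of an engine-dispatch decorator; awswrangler's real
# `engine.dispatch_on_engine` differs in detail.
from typing import Any, Callable, Dict

_registry: Dict[str, Dict[str, Callable[..., Any]]] = {}
_current_engine = "python"


def dispatch_on_engine(func: Callable[..., Any]) -> Callable[..., Any]:
    """Route calls to a registered engine override, falling back to `func`."""

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        override = _registry.get(_current_engine, {}).get(func.__name__)
        return (override or func)(*args, **kwargs)

    return wrapper


def register_override(engine_name: str, func_name: str, func: Callable[..., Any]) -> None:
    """Register an engine-specific replacement for a dispatched function."""
    _registry.setdefault(engine_name, {})[func_name] = func
```

With a pattern like this, switching the active engine to `"ray"` would reroute `_write_batch` calls to a registered remote implementation without touching any call sites.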
```diff
@@ -85,6 +91,33 @@ def _write_batch(
     return []


+@engine.dispatch_on_engine
+def _write_df(
+    df: pd.DataFrame,
+    executor: Any,
+    database: str,
+    table: str,
+    cols_names: List[str],
+    measure_cols_names: List[str],
+    measure_types: List[str],
+    version: int,
+    boto3_session: Optional[boto3.Session] = None,
+) -> List[Dict[str, str]]:
+    batches: List[List[Any]] = _utils.chunkify(lst=_df2list(df=df), max_length=100)
+    _logger.debug("len(batches): %s", len(batches))
+    return executor.map(  # type: ignore
+        _write_batch,
+        boto3_session,
+        itertools.repeat(database),
+        itertools.repeat(table),
+        itertools.repeat(cols_names),
+        itertools.repeat(measure_cols_names),
+        itertools.repeat(measure_types),
+        itertools.repeat(version),
+        batches,
+    )
+
+
 def _cast_value(value: str, dtype: str) -> Any:  # pylint: disable=too-many-branches,too-many-return-statements
     if dtype == "VARCHAR":
         return value
```

> **Review thread on `_write_df`** (marked as resolved by jaidisido):
>
> 1. Here the split modin dataframe block reference id is received. I assume modin/ray is smart enough to avoid a shuffle (i.e. pulling a block from one worker to another) and would instead run the remote functions (
> 2. The blocks would be broken down into batches and sent to workers so unfortunately some shuffle or rather copy will inevitably happen. One thing I'm afraid of is
> 3. Fair point, the load test on 64,000 rows was fine but let me check with an even larger one tomorrow
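The batching itself is plain `Executor.map` fan-out: the constant arguments are wrapped in `itertools.repeat` while `batches` varies, so each batch becomes one `_write_batch` call. A self-contained illustration of the pattern with a stand-in thread pool (the function and names below are made up):

```python
# Fan-out pattern used by `_write_df`: map() zips its argument iterables,
# and itertools.repeat() supplies the constant ones; iteration stops when
# the finite `batches` iterable is exhausted.
import itertools
from concurrent.futures import ThreadPoolExecutor
from typing import Any, List


def write_one_batch(database: str, table: str, batch: List[Any]) -> List[str]:
    return []  # stand-in: pretend no records were rejected


batches = [[1, 2], [3, 4], [5, 6]]
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(
        pool.map(
            write_one_batch,
            itertools.repeat("my_db"),
            itertools.repeat("my_table"),
            batches,  # one call per batch
        )
    )

print(results)  # [[], [], []]
```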
```diff
@@ -173,14 +206,18 @@ def write(
     measure_col: Union[str, List[str]],
     dimensions_cols: List[str],
     version: int = 1,
-    num_threads: int = 32,
+    use_threads: Union[bool, int] = True,
     boto3_session: Optional[boto3.Session] = None,
 ) -> List[Dict[str, str]]:
     """Store a Pandas DataFrame into a Amazon Timestream table.

+    Note
+    ----
+    In case `use_threads=True`, the number of threads from os.cpu_count() is used.
+
     Parameters
     ----------
-    df: pandas.DataFrame
+    df : pandas.DataFrame
         Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html
     database : str
         Amazon Timestream database name.
```
```diff
@@ -195,8 +232,10 @@ def write(
     version : int
         Version number used for upserts.
         Documentation https://docs.aws.amazon.com/timestream/latest/developerguide/API_WriteRecords.html.
-    num_threads : str
-        Number of thread to be used for concurrent writing.
+    use_threads : bool, int
+        True to enable concurrent writing, False to disable multiple threads.
+        If enabled, os.cpu_count() is used as the number of threads.
+        If integer is provided, specified number is used.
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 Session will be used if boto3_session receive None.
```
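For reference, a hedged usage example of the updated signature (the database and table names are made up, and the Timestream table is assumed to already exist):

```python
# Hypothetical call showing the new `use_threads` parameter.
from datetime import datetime

import pandas as pd
import awswrangler as wr

df = pd.DataFrame(
    {
        "time": [datetime.now()] * 3,
        "measure": [10.0, 10.5, 11.0],
        "host": ["h1", "h2", "h3"],
    }
)

rejected = wr.timestream.write(
    df=df,
    database="my_database",  # assumed to exist
    table="my_table",        # assumed to exist
    time_col="time",
    measure_col="measure",
    dimensions_cols=["host"],
    use_threads=8,  # exact thread count; True would use os.cpu_count()
)
assert len(rejected) == 0
```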
```diff
@@ -232,29 +271,33 @@
     """
     measure_cols_names: List[str] = measure_col if isinstance(measure_col, list) else [measure_col]
     _logger.debug("measure_cols_names: %s", measure_cols_names)
-    measure_types: List[str] = [
-        _data_types.timestream_type_from_pandas(df[[measure_col_name]]) for measure_col_name in measure_cols_names
-    ]
+    measure_types: List[str] = _data_types.timestream_type_from_pandas(df.loc[:, measure_cols_names])
     _logger.debug("measure_types: %s", measure_types)
     cols_names: List[str] = [time_col] + measure_cols_names + dimensions_cols
     _logger.debug("cols_names: %s", cols_names)
-    batches: List[List[Any]] = _utils.chunkify(lst=_df2list(df=df[cols_names]), max_length=100)
-    _logger.debug("len(batches): %s", len(batches))
-    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
-        res: List[List[Any]] = list(
-            executor.map(
-                _write_batch,
-                itertools.repeat(database),
-                itertools.repeat(table),
-                itertools.repeat(cols_names),
-                itertools.repeat(measure_cols_names),
-                itertools.repeat(measure_types),
-                itertools.repeat(version),
-                batches,
-                itertools.repeat(_utils.boto3_to_primitives(boto3_session=boto3_session)),
-            )
-        )
-    return [item for sublist in res for item in sublist]
+    dfs = _utils.split_pandas_frame(df.loc[:, cols_names], _utils.ensure_cpu_count(use_threads=use_threads))
+    _logger.debug("len(dfs): %s", len(dfs))
+
+    executor = _get_executor(use_threads=use_threads)
+    errors = _flatten_list(
+        ray_get(
+            [
+                _write_df(
+                    df=df,
+                    executor=executor,
+                    database=database,
+                    table=table,
+                    cols_names=cols_names,
+                    measure_cols_names=measure_cols_names,
+                    measure_types=measure_types,
+                    version=version,
+                    boto3_session=boto3_session,
+                )
+                for df in dfs
+            ]
+        )
+    )
+    return _flatten_list(ray_get(errors))


 def query(
```
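`_utils.split_pandas_frame` and `_utils.ensure_cpu_count` are awswrangler internals not shown in this diff; a plausible, purely hypothetical row-wise splitter could look like this:

```python
# Hypothetical stand-in for a row-wise frame splitter; the real
# `_utils.split_pandas_frame` may differ.
from typing import List

import numpy as np
import pandas as pd


def split_pandas_frame(df: pd.DataFrame, splits: int) -> List[pd.DataFrame]:
    """Split a DataFrame row-wise into at most `splits` roughly equal chunks."""
    return [chunk for chunk in np.array_split(df, splits) if not chunk.empty]


frame = pd.DataFrame({"x": range(10)})
print([len(part) for part in split_pandas_frame(frame, 4)])  # [3, 3, 2, 2]
```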
**File 3: new load test for the distributed Timestream write**

```diff
@@ -0,0 +1,67 @@
+from datetime import datetime
+
+import pytest
+import ray
+from pyarrow import csv
+
+import awswrangler as wr
+
+from .._utils import ExecutionTimer
+
+
+@pytest.mark.parametrize("benchmark_time", [180])
+def test_real_csv_load_scenario(benchmark_time: int, timestream_database_and_table: str) -> None:
+    name = timestream_database_and_table
+    df = (
+        ray.data.read_csv(
+            "https://raw.githubusercontent.com/awslabs/amazon-timestream-tools/mainline/sample_apps/data/sample.csv",
+            **{
+                "read_options": csv.ReadOptions(
+                    column_names=[
+                        "ignore0",
+                        "region",
+                        "ignore1",
+                        "az",
+                        "ignore2",
+                        "hostname",
+                        "measure_kind",
+                        "measure",
+                        "ignore3",
+                        "ignore4",
+                        "ignore5",
+                    ]
+                )
+            },
+        )
+        .to_modin()
+        .loc[:, ["region", "az", "hostname", "measure_kind", "measure"]]
+    )
+
+    df["time"] = datetime.now()
+    df.reset_index(inplace=True, drop=False)
+    df_cpu = df[df.measure_kind == "cpu_utilization"]
+    df_memory = df[df.measure_kind == "memory_utilization"]
+
+    with ExecutionTimer("elapsed time of wr.timestream.write()") as timer:
+        rejected_records = wr.timestream.write(
+            df=df_cpu,
+            database=name,
+            table=name,
+            time_col="time",
+            measure_col="measure",
+            dimensions_cols=["index", "region", "az", "hostname"],
+        )
+        assert len(rejected_records) == 0
+        rejected_records = wr.timestream.write(
+            df=df_memory,
+            database=name,
+            table=name,
+            time_col="time",
+            measure_col="measure",
+            dimensions_cols=["index", "region", "az", "hostname"],
+        )
+        assert len(rejected_records) == 0
+    assert timer.elapsed_time < benchmark_time
+
+    df = wr.timestream.query(f'SELECT COUNT(*) AS counter FROM "{name}"."{name}"')
+    assert df["counter"].iloc[0] == 126_000
```
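`ExecutionTimer` comes from the repository's test utilities (imported as `.._utils`, not part of this diff). A minimal context-manager sketch consistent with how the test reads `timer.elapsed_time` after the `with` block; the real helper may differ:

```python
# Hypothetical reimplementation of the ExecutionTimer helper used above;
# the repository's real version may differ.
import time


class ExecutionTimer:
    def __init__(self, msg: str = "elapsed time") -> None:
        self.msg = msg
        self.elapsed_time = 0.0
        self._start = 0.0

    def __enter__(self) -> "ExecutionTimer":
        self._start = time.perf_counter()
        return self

    def __exit__(self, *exc: object) -> None:
        self.elapsed_time = time.perf_counter() - self._start
        print(f"{self.msg}: {self.elapsed_time:.2f}s")
```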