Skip to content

Commit 2036a41

Browse files
committed
[#47] Adding filesystem support for save_df
... Signed-off-by: Todd Gaugler <[email protected]> ... ... . ...
1 parent 8868835 commit 2036a41

File tree

2 files changed

+27
-13
lines changed

2 files changed

+27
-13
lines changed

raydar/task_tracker/task_tracker.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
import logging
44
import os
55
from collections.abc import Iterable
6-
from typing import Dict, List, Optional
6+
from typing import Dict, List, Optional, Type
77

88
import coolname
99
import pandas as pd
1010
import polars as pl
11+
import pyarrow.fs as fs
12+
import pyarrow.parquet as pq
1113
import ray
1214
from packaging.version import Version
1315
from ray.serve import shutdown
@@ -88,8 +90,9 @@ def __init__(
8890
self,
8991
name: str,
9092
namespace: str,
91-
path: Optional[str] = None,
9293
enable_perspective_dashboard: bool = False,
94+
filesystem: Type[fs.FileSystem] = fs.LocalFileSystem,
95+
filesystem_kwargs: Optional[dict] = None,
9396
):
9497
"""An async Ray Actor Class to track task level metadata.
9598
@@ -114,13 +117,13 @@ def __init__(
114117
lifetime="detached",
115118
get_if_exists=True,
116119
).remote(name, namespace)
117-
self.path = path
118120
self.df = None
119121
self.finished_tasks = {}
120122
self.user_defined_metadata = {}
121123
self.perspective_dashboard_enabled = enable_perspective_dashboard
122124
self.pending_tasks = []
123125
self.perspective_table_name = f"{name}_data"
126+
self.filesystem = filesystem(**(filesystem_kwargs or dict()))
124127

125128
# WARNING: Do not move this import. Importing these modules elsewhere can cause
126129
# difficult to diagnose, "There is no current event loop in thread 'ray_client_server_" errors.
@@ -306,14 +309,10 @@ def get_proxy_server(self) -> ray.serve.handle.DeploymentHandle:
306309
return self.proxy_server
307310
raise Exception("This task_tracker has no active proxy_server.")
308311

309-
def save_df(self) -> None:
310-
"""Saves the internally maintained dataframe of task related information from the ray GCS"""
311-
self.get_df()
312-
if self.path is not None and self.df is not None:
313-
logger.info(f"Writing DataFrame to {self.path}")
314-
self.df.write_parquet(self.path)
315-
return True
316-
return False
312+
def save_df(self, path: str) -> None:
313+
"""Saves the internally maintained dataframe of task related information from the ray GCS to a provided path, using the filesystem attribute"""
314+
logger.info(f"Writing DataFrame to {path}")
315+
pq.write_table(self.get_df().to_arrow(), path, filesystem=self.filesystem)
317316

318317
def clear_df(self) -> None:
319318
"""Clears the internally maintained dataframe of task related information from the ray GCS"""
@@ -363,9 +362,9 @@ def get_df(self, process_user_metadata_column=False) -> pl.DataFrame:
363362
return df_with_user_metadata
364363
return df
365364

366-
def save_df(self) -> None:
365+
def save_df(self, path: str) -> None:
367366
"""Save the dataframe used by this object's AsyncMetadataTracker actor"""
368-
return ray.get(self.tracker.save_df.remote())
367+
return ray.get(self.tracker.save_df.remote(path))
369368

370369
def clear(self) -> None:
371370
"""Clear the dataframe used by this object's AsyncMetadataTracker actor"""

raydar/tests/test_task_tracker.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import os
2+
import tempfile
13
import time
24

5+
import pandas as pd
36
import pytest
47
import ray
58
import requests
@@ -39,3 +42,15 @@ def test_get_proxy_server(self):
3942
time.sleep(2)
4043
response = requests.get("http://localhost:8000/tables")
4144
assert eval(response.text) == ["test_table"]
45+
46+
def test_save_df(self):
47+
task_tracker = RayTaskTracker()
48+
refs = [do_some_work.remote() for _ in range(100)]
49+
task_tracker.process(refs)
50+
_ = ray.get(refs)
51+
df = task_tracker.get_df()
52+
with tempfile.TemporaryDirectory() as tempdir:
53+
path = os.path.join(tempdir, "output_dir")
54+
task_tracker.save_df(path)
55+
loaded_df = pd.read_parquet(path)
56+
assert loaded_df.equals(df.to_pandas())

0 commit comments

Comments
 (0)