
Basic sorting support with Dask #256


Merged · 20 commits · May 8, 2025
Changes from 12 commits
72 changes: 61 additions & 11 deletions python/rapidsmpf/rapidsmpf/examples/dask.py
@@ -7,17 +7,20 @@
from typing import TYPE_CHECKING

import dask.dataframe as dd
import numpy as np
from dask.tokenize import tokenize
from dask.utils import M

import rmm.mr
from rmm.pylibrmm.stream import DEFAULT_STREAM

from rapidsmpf.integrations.dask.shuffler import rapidsmpf_shuffle_graph
from rapidsmpf.shuffler import partition_and_pack, unpack_and_concat
from rapidsmpf.shuffler import partition_and_pack, split_and_pack, unpack_and_concat
from rapidsmpf.testing import pylibcudf_to_cudf_dataframe

if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Any

import dask_cudf

@@ -45,6 +48,8 @@ def insert_partition(
on: Sequence[str],
partition_count: int,
shuffler: Shuffler,
sort_boundaries: cudf.DataFrame | None,
options: dict[str, Any] | None,
) -> None:
"""
Add cudf DataFrame chunks to an RMPF shuffler.
@@ -59,22 +64,39 @@
Number of output partitions for the current shuffle.
shuffler
The RapidsMPF Shuffler object to extract from.
sort_boundaries
Output partition boundaries for sorting.
options
Optional key-work arguments.
Contributor:

Suggested change: "Optional key-work arguments." → "Optional key-word arguments."

Contributor:

And we use "Additional options." in a few places below. Let's pick one description and copy it through.

Member Author:

Ah yeah, missed this one - let's do the simple "Additional options." for now.

Contributor:

Would it make sense to make it DaskIntegration(on=, sort_boundaries=)? Or would that obfuscate/impede the way we build dask graphs here?

More of a question, as it would side-step the immediate need for breaking things (not that it matters) and might avoid the options catch-all.

Contributor:

I think it would be good to pass in the pid of df here for sorting. That is needed to find the right splits if you want to balance the result partition sizes in degenerate cases (such as all values being equal). And I believe we don't have another way to pass it in.

Member Author:

@seberg - I updated/generalized the protocol a bit. I didn't include the input partition id as a required argument, but we can add that now that we are changing things. Can you explain how having the input partition id would help you handle degenerate values?

Contributor @seberg (May 8, 2025):

Basically, the idea is that the split_boundary values know which partition ID they came from (and ideally their local row).

For example, we split (1, 1, 1, 1), distributed as pid0=(1, 1) and pid1=(1, 1). If you have the pid and row, then the split boundary will be (value, pid=1, row=0).

With that pid information, you can figure out here that pid=0 should send its data to 0 (split after the boundary) and pid=1 should send it all to 1 (split before the boundary here).

Without the additional information, there is no choice but for both pids to send all data to 0.

Member Author:

Okay, I see. This case definitely isn't a high priority yet (dask-dataframe still doesn't attempt to handle this at all), but it's a good-enough reason to include partition_id as a required argument to insert_partition now that we are updating the protocol anyway.
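A minimal sketch of the pid-tagged boundary idea from this thread; the tagging scheme and the function below are hypothetical, not part of this PR:

# Hypothetical sketch: boundaries tagged with their source partition id,
# so ties on the boundary value can be broken deterministically.
import numpy as np

def split_with_tagged_boundaries(
    values: np.ndarray,  # this input partition's sort column, already sorted
    pid: int,            # this input partition's id
    boundaries: list[tuple[float, int, int]],  # (value, src_pid, src_row)
) -> list[int]:
    splits = []
    for b_value, b_pid, _b_row in boundaries:
        # Rows from partitions before the boundary's source partition sort
        # to the left of it, so keep their ties on the left of the split.
        # (The row component would break ties within b_pid; ignored here.)
        side = "right" if pid < b_pid else "left"
        splits.append(int(np.searchsorted(values, b_value, side=side)))
    return splits

# The example above: pid0=(1, 1), pid1=(1, 1), boundary (value=1, pid=1, row=0).
assert split_with_tagged_boundaries(np.array([1, 1]), 0, [(1, 1, 0)]) == [2]  # all rows to output 0
assert split_with_tagged_boundaries(np.array([1, 1]), 1, [(1, 1, 0)]) == [0]  # all rows to output 1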

"""
columns_to_hash = tuple(list(df.columns).index(val) for val in on)
packed_inputs = partition_and_pack(
df.to_pylibcudf()[0],
columns_to_hash=columns_to_hash,
num_partitions=partition_count,
stream=DEFAULT_STREAM,
device_mr=rmm.mr.get_current_device_resource(),
)
if options:
raise ValueError(f"Unsupported options: {options}")
if sort_boundaries is None:
columns_to_hash = tuple(list(df.columns).index(val) for val in on)
packed_inputs = partition_and_pack(
df.to_pylibcudf()[0],
columns_to_hash=columns_to_hash,
num_partitions=partition_count,
stream=DEFAULT_STREAM,
device_mr=rmm.mr.get_current_device_resource(),
)
else:
df = df.sort_values(on)
splits = df[on[0]].searchsorted(sort_boundaries, side="right")
Contributor @seberg (May 8, 2025):

N.B. (I assume you are aware, and at most worth a code comment): good for an example, but it only works if the values in sort_boundaries are unique in df. Otherwise you need to adjust for where the boundary value came from - thus the longer function I shared.

EDIT: Sorry, this is not as bad as I first recalled, as it is only needed to avoid large imbalances in the result partition sizes.

packed_inputs = split_and_pack(
df.to_pylibcudf()[0],
splits.tolist(),
stream=DEFAULT_STREAM,
device_mr=rmm.mr.get_current_device_resource(),
)
shuffler.insert_chunks(packed_inputs)
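On the uniqueness caveat flagged above, a NumPy-only illustration (not part of the PR) of how duplicated boundary values skew the searchsorted splits:

# Illustrative only: quantile boundaries split unique data evenly, but
# heavily duplicated values push every row into the first output partition.
import numpy as np

unique_vals = np.array([1, 2, 3, 4, 5, 6, 7, 8])
print(np.searchsorted(unique_vals, [3, 6], side="right"))  # [3 6] -> balanced splits

equal_vals = np.ones(8, dtype=int)
print(np.searchsorted(equal_vals, [1, 1], side="right"))   # [8 8] -> everything lands in output 0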

@staticmethod
def extract_partition(
partition_id: int,
column_names: list[str],
shuffler: Shuffler,
options: dict[str, Any] | None,
) -> cudf.DataFrame:
"""
Extract a finished partition from the RMPF shuffler.
@@ -87,11 +109,15 @@
Sequence of output column names.
shuffler
The RapidsMPF Shuffler object to extract from.
options
Additional options.

Returns
-------
A shuffled DataFrame partition.
"""
if options:
raise ValueError(f"Unsupported options: {options}")
shuffler.wait_on(partition_id)
table = unpack_and_concat(
shuffler.extract(partition_id),
@@ -108,6 +134,7 @@ def dask_cudf_shuffle(
df: dask_cudf.DataFrame,
shuffle_on: list[str],
*,
sort: bool = False,
partition_count: int | None = None,
) -> dask_cudf.DataFrame:
"""
@@ -119,6 +146,10 @@
Input `dask_cudf.DataFrame` collection.
shuffle_on
List of column names to shuffle on.
sort
Whether the output partitioning should be in
sorted order. The first column in ``shuffle_on``
must be numerical.
partition_count
Output partition count. Default will preserve
the input partition count.
@@ -133,6 +164,13 @@
token = tokenize(df0, shuffle_on, count_out)
name_in = df0._name
name_out = f"shuffle-{token}"
if sort:
boundaries = (
df0[shuffle_on[0]].quantile(np.linspace(0.0, 1.0, count_out)[1:]).optimize()
)
sort_boundaries_name = (boundaries._name, 0)
else:
sort_boundaries_name = None
graph = rapidsmpf_shuffle_graph(
name_in,
name_out,
@@ -141,16 +179,28 @@
count_in,
count_out,
DaskCudfIntegration,
sort_boundaries_name=sort_boundaries_name,
)

# Add df0 dependencies to the task graph
graph.update(df0.dask)
if sort:
graph.update(boundaries.dask)

# Return a Dask-DataFrame collection
return dd.from_graph(
shuffled = dd.from_graph(
graph,
df0._meta,
(None,) * (count_out + 1),
[(name_out, pid) for pid in range(count_out)],
"rapidsmpf",
)

# Return a Dask-DataFrame collection
if sort:
return shuffled.map_partitions(
M.sort_values,
shuffle_on,
meta=shuffled._meta,
)
else:
return shuffled
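For orientation, a usage sketch of the new sort path; it assumes dask_cudf is installed and a RapidsMPF-enabled Dask cluster is already running (the column names are invented):

# Usage sketch; assumes dask_cudf plus a RapidsMPF-enabled Dask cluster.
import cudf
import dask_cudf
from rapidsmpf.examples.dask import dask_cudf_shuffle

ddf = dask_cudf.from_cudf(
    cudf.DataFrame({"a": [3, 1, 4, 1, 5], "b": range(5)}),
    npartitions=2,
)
# sort=True range-partitions on the quantiles of "a" and sorts each
# output partition, so the concatenated result is globally sorted on "a".
out = dask_cudf_shuffle(ddf, ["a"], sort=True, partition_count=2)
result = out.compute()
assert result["a"].is_monotonic_increasing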
58 changes: 55 additions & 3 deletions python/rapidsmpf/rapidsmpf/integrations/dask/shuffler.py
@@ -111,6 +111,8 @@ def insert_partition(
on: Sequence[str],
partition_count: int,
shuffler: Shuffler,
sort_boundaries: DataFrameT | None,
options: dict[str, Any] | None,
) -> None:
"""
Add a partition to a RapidsMPF Shuffler.
@@ -125,13 +127,19 @@
Number of output partitions for the current shuffle.
shuffler
The RapidsMPF Shuffler object to extract from.
sort_boundaries
Output partition boundaries for sorting. If None,
hashing will be used to calculate output partitions.
options
Additional options.
"""
Contributor:

I can imagine that we might eventually want more all-to-all-like patterns. Would it make more sense to change this interface such that insert_partition just takes the list[PackedData] and the shuffler, and we provide separate functions for hash- and sort-based partitioning (and the user can bring their own)?

So something like:

def insert_partition(
    shuffler: Shuffler,
    chunks: Sequence[PackedData], # Or whatever it is
) -> None:

And we provide two built-in functions:

def hash_partition(df, partition_count, *, on) -> list[PackedData]:
    ...
def sort_partition(df, partition_count, *, by) -> list[PackedData]:
    ...

Member Author:

I don't think that helps us generalize at all. We already have Shuffler.insert_chunks, which is essentially the insert_partition function you are proposing. The purpose of DaskIntegration.insert_partition is to avoid the need for various Dask shuffling applications to write their own task graph.

We want insert_partition/extract_partition to include the minimal necessary arguments to construct a "general" shuffling task graph. Since we are revising things, this may be:

    @staticmethod
    def insert_partition(
        df: DataFrameT,  # Partition to insert
        partition_count: int,   # Output partition count
        shuffler: Shuffler,   # Shuffler object
        options: dict[str, Any] | None,   # Arbitrary keyword arguments
        *other: Any,   # "Other" task-output data (e.g. sorting boundaries/quantiles)
    ) -> None:

    @staticmethod
    def extract_partition(
        partition_id: int,  # Partition ID to extract
        shuffler: Shuffler,   # Shuffler object
        options: dict[str, Any] | None,   # Arbitrary keyword arguments
    ) -> DataFrameT:

I think the options argument can be used to control most variations of a shuffle, and the *other positional argument could be used to pass in information that must be calculated dynamically at execution time.

Contributor:

Ah, ok, carry on then.


@staticmethod
def extract_partition(
partition_id: int,
column_names: list[str],
shuffler: Shuffler,
options: dict[str, Any] | None,
) -> DataFrameT:
"""
Extract a DataFrame partition from a RapidsMPF Shuffler.
@@ -144,6 +152,8 @@
Sequence of output column names.
shuffler
The RapidsMPF Shuffler object to extract from.
options
Additional options.

Returns
-------
@@ -214,11 +224,23 @@ def _stage_shuffler(


def _insert_partition(
callback: Callable[[DataFrameT, Sequence[str], int, Shuffler], None],
callback: Callable[
[
DataFrameT,
Sequence[str],
int,
Shuffler,
str | tuple[str, int] | None,
dict[str, Any],
],
None,
],
df: DataFrameT,
on: Sequence[str],
partition_count: int,
shuffle_id: int,
sort_boundaries_name: str | tuple[str, int] | None,
options: dict[str, Any],
) -> None:
"""
Add a partition to a RapidsMPF Shuffler.
@@ -237,6 +259,11 @@
Number of output partitions for the current shuffle.
shuffle_id
The RapidsMPF shuffle id.
sort_boundaries_name
The task name for sorting boundaries. Only needed
if the shuffle is in service of a sort operation.
options
Optional key-word arguments.
"""
if callback is None:
raise ValueError("callback missing in _insert_partition.")
@@ -247,15 +274,21 @@
on,
partition_count,
get_shuffler(shuffle_id),
sort_boundaries_name,
options,
)


def _extract_partition(
callback: Callable[[int, Sequence[str], Shuffler], DataFrameT],
callback: Callable[
[int, Sequence[str], Shuffler, dict[str, Any] | None],
DataFrameT,
],
shuffle_id: int,
partition_id: int,
column_names: list[str],
worker_barrier: tuple[int, ...],
options: dict[str, Any] | None,
) -> DataFrameT:
"""
Extract a partition from a RapidsMPF Shuffler.
@@ -275,6 +308,8 @@
worker_barrier
Worker-barrier task dependency. This value should
not be used for compute logic.
options
Additional options.

Returns
-------
@@ -286,6 +321,7 @@
partition_id,
column_names,
get_shuffler(shuffle_id),
options,
)


@@ -297,6 +333,9 @@
partition_count_in: int,
partition_count_out: int,
integration: DaskIntegration,
*,
sort_boundaries_name: str | tuple[str, int] | None = None,
options: dict[str, Any] | None = None,
) -> dict[Any, Any]:
"""
Return the task graph for a RapidsMPF shuffle.
@@ -310,13 +349,23 @@
column_names
Sequence of output column names.
shuffle_on
Sequence of column names to shuffle on (by hash).
Sequence of column names to shuffle on. Output
partitions will be based on the hash of these
columns, unless ``sort_boundaries_name`` is
specified. In the case of sorting, output
partitioning will be based on the first element
of ``shuffle_on`` only.
partition_count_in
Partition count of input collection.
partition_count_out
Partition count of output collection.
integration
Dask-integration specification.
sort_boundaries_name
The task name for sorting boundaries. Only needed
if the shuffle is in service of a sort operation.
options
Optional key-word arguments.

Returns
-------
@@ -422,6 +471,8 @@
shuffle_on,
partition_count_out,
shuffle_id,
sort_boundaries_name,
options,
)
for pid in range(partition_count_in)
}
@@ -463,6 +514,7 @@
part_id,
column_names,
(global_barrier_2_name, 0),
options,
)
# Assume round-robin partition assignment
restricted_keys[output_keys[-1]] = worker_ranks[rank]
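To see the revised protocol end to end, a schematic sketch of a custom integration; MyDataFrame, my_pack, and my_unpack are hypothetical placeholders, while insert_chunks, wait_on, and extract are the real Shuffler calls used in this PR:

# Schematic sketch of a custom DaskIntegration implementation.
# MyDataFrame, my_pack, and my_unpack are hypothetical placeholders.
from __future__ import annotations

from collections.abc import Sequence
from typing import Any

from rapidsmpf.shuffler import Shuffler


class MyIntegration:
    @staticmethod
    def insert_partition(
        df: MyDataFrame,
        on: Sequence[str],
        partition_count: int,
        shuffler: Shuffler,
        sort_boundaries: MyDataFrame | None,
        options: dict[str, Any] | None,
    ) -> None:
        # Hash-partition when sort_boundaries is None; range-partition otherwise.
        packed = my_pack(df, on, partition_count, sort_boundaries)
        shuffler.insert_chunks(packed)

    @staticmethod
    def extract_partition(
        partition_id: int,
        column_names: list[str],
        shuffler: Shuffler,
        options: dict[str, Any] | None,
    ) -> MyDataFrame:
        shuffler.wait_on(partition_id)  # block until the partition is ready
        return my_unpack(shuffler.extract(partition_id), column_names)

Such a class would then be passed as the integration argument to rapidsmpf_shuffle_graph, optionally together with sort_boundaries_name.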
17 changes: 13 additions & 4 deletions python/rapidsmpf/rapidsmpf/tests/test_dask.py
@@ -63,7 +63,12 @@ def get_rank(dask_worker: Worker) -> int:


@pytest.mark.parametrize("partition_count", [None, 3])
def test_dask_cudf_integration(loop: pytest.FixtureDef, partition_count: int) -> None: # noqa: F811
@pytest.mark.parametrize("sort", [True, False])
def test_dask_cudf_integration(
loop: pytest.FixtureDef, # noqa: F811
partition_count: int,
sort: bool, # noqa: FBT001
) -> None:
# Test basic Dask-cuDF integration
pytest.importorskip("dask_cudf")

@@ -83,14 +88,18 @@ def test_dask_cudf_integration(loop: pytest.FixtureDef, partition_count: int) ->
.to_backend("cudf")
)
partition_count_in = df.npartitions
expect = df.compute().sort_values(["x", "y"])
expect = df.compute().sort_values(["id", "name", "x", "y"])
shuffled = dask_cudf_shuffle(
df,
["name", "id"],
["id", "name"],
sort=sort,
partition_count=partition_count,
)
assert shuffled.npartitions == (partition_count or partition_count_in)
got = shuffled.compute().sort_values(["x", "y"])
got = shuffled.compute()
if sort:
assert got["id"].is_monotonic_increasing
got = got.sort_values(["id", "name", "x", "y"])

dd.assert_eq(expect, got, check_index=False)
