Skip to content

Commit 5b9425b

Browse files
authored
Revise DaskIntegration protocol to align with rapidsmpf (#18720)
Teeing up this "fix" for the proposed change in rapidsai/rapidsmpf#256. Once that PR is merged, we will want to get this in ASAP to keep `rapidsmpf` shuffling from breaking. We can update `Sort` in a follow-up PR. Authors: - Richard (Rick) Zamora (https://github.com/rjzamora) Approvers: - Tom Augspurger (https://github.com/TomAugspurger) URL: #18720
1 parent 874ecb4 commit 5b9425b

File tree

1 file changed

+16
-5
lines changed
  • python/cudf_polars/cudf_polars/experimental

1 file changed

+16
-5
lines changed

python/cudf_polars/cudf_polars/experimental/shuffle.py

+16-5
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
from __future__ import annotations
66

77
import operator
8-
from typing import TYPE_CHECKING, Any
8+
from typing import TYPE_CHECKING, Any, TypedDict
99

1010
import pylibcudf as plc
1111
import rmm.mr
@@ -32,20 +32,31 @@
3232
_SHUFFLE_METHODS = ("rapidsmpf", "tasks")
3333

3434

35+
class ShuffleOptions(TypedDict):
36+
"""RapidsMPF shuffling options."""
37+
38+
on: Sequence[str]
39+
column_names: Sequence[str]
40+
41+
3542
# Experimental rapidsmpf shuffler integration
3643
class RMPFIntegration: # pragma: no cover
3744
"""cuDF-Polars protocol for rapidsmpf shuffler."""
3845

3946
@staticmethod
4047
def insert_partition(
4148
df: DataFrame,
42-
on: Sequence[str],
49+
partition_id: int, # Not currently used
4350
partition_count: int,
4451
shuffler: Any,
52+
options: ShuffleOptions,
53+
*other: Any,
4554
) -> None:
4655
"""Add cudf-polars DataFrame chunks to an RMP shuffler."""
4756
from rapidsmpf.shuffler import partition_and_pack
4857

58+
on = options["on"]
59+
assert not other, f"Unexpected arguments: {other}"
4960
columns_to_hash = tuple(df.column_names.index(val) for val in on)
5061
packed_inputs = partition_and_pack(
5162
df.table,
@@ -59,13 +70,14 @@ def insert_partition(
5970
@staticmethod
6071
def extract_partition(
6172
partition_id: int,
62-
column_names: list[str],
6373
shuffler: Any,
74+
options: ShuffleOptions,
6475
) -> DataFrame:
6576
"""Extract a finished partition from the RMP shuffler."""
6677
from rapidsmpf.shuffler import unpack_and_concat
6778

6879
shuffler.wait_on(partition_id)
80+
column_names = options["column_names"]
6981
return DataFrame.from_table(
7082
unpack_and_concat(
7183
shuffler.extract(partition_id),
@@ -256,11 +268,10 @@ def _(
256268
return rapidsmpf_shuffle_graph(
257269
get_key_name(ir.children[0]),
258270
get_key_name(ir),
259-
list(ir.schema.keys()),
260-
shuffle_on,
261271
partition_info[ir.children[0]].count,
262272
partition_info[ir].count,
263273
RMPFIntegration,
274+
{"on": shuffle_on, "column_names": list(ir.schema.keys())},
264275
)
265276
except (ImportError, ValueError) as err:
266277
# ImportError: rapidsmpf is not installed

0 commit comments

Comments
 (0)