5
5
from __future__ import annotations
6
6
7
7
import operator
8
- from typing import TYPE_CHECKING , Any
8
+ from typing import TYPE_CHECKING , Any , TypedDict
9
9
10
10
import pylibcudf as plc
11
11
import rmm .mr
32
32
_SHUFFLE_METHODS = ("rapidsmpf" , "tasks" )
33
33
34
34
35
+ class ShuffleOptions (TypedDict ):
36
+ """RapidsMPF shuffling options."""
37
+
38
+ on : Sequence [str ]
39
+ column_names : Sequence [str ]
40
+
41
+
35
42
# Experimental rapidsmpf shuffler integration
36
43
class RMPFIntegration : # pragma: no cover
37
44
"""cuDF-Polars protocol for rapidsmpf shuffler."""
38
45
39
46
@staticmethod
40
47
def insert_partition (
41
48
df : DataFrame ,
42
- on : Sequence [ str ],
49
+ partition_id : int , # Not currently used
43
50
partition_count : int ,
44
51
shuffler : Any ,
52
+ options : ShuffleOptions ,
53
+ * other : Any ,
45
54
) -> None :
46
55
"""Add cudf-polars DataFrame chunks to an RMP shuffler."""
47
56
from rapidsmpf .shuffler import partition_and_pack
48
57
58
+ on = options ["on" ]
59
+ assert not other , f"Unexpected arguments: { other } "
49
60
columns_to_hash = tuple (df .column_names .index (val ) for val in on )
50
61
packed_inputs = partition_and_pack (
51
62
df .table ,
@@ -59,13 +70,14 @@ def insert_partition(
59
70
@staticmethod
60
71
def extract_partition (
61
72
partition_id : int ,
62
- column_names : list [str ],
63
73
shuffler : Any ,
74
+ options : ShuffleOptions ,
64
75
) -> DataFrame :
65
76
"""Extract a finished partition from the RMP shuffler."""
66
77
from rapidsmpf .shuffler import unpack_and_concat
67
78
68
79
shuffler .wait_on (partition_id )
80
+ column_names = options ["column_names" ]
69
81
return DataFrame .from_table (
70
82
unpack_and_concat (
71
83
shuffler .extract (partition_id ),
@@ -256,11 +268,10 @@ def _(
256
268
return rapidsmpf_shuffle_graph (
257
269
get_key_name (ir .children [0 ]),
258
270
get_key_name (ir ),
259
- list (ir .schema .keys ()),
260
- shuffle_on ,
261
271
partition_info [ir .children [0 ]].count ,
262
272
partition_info [ir ].count ,
263
273
RMPFIntegration ,
274
+ {"on" : shuffle_on , "column_names" : list (ir .schema .keys ())},
264
275
)
265
276
except (ImportError , ValueError ) as err :
266
277
# ImportError: rapidsmpf is not installed
0 commit comments