Skip to content

Commit c838f81

Browse files
authored
Require dtype argument to cudf_polars Column container (#19193)
Depends on #19075 Following #19091, this PR ensures that the `Column` always contains a `DataType` object so that Polars type metadata, such as struct field names, is preserved Authors: - Matthew Roeschke (https://github.com/mroeschke) Approvers: - Tom Augspurger (https://github.com/TomAugspurger) URL: #19193
1 parent 9c90453 commit c838f81

File tree

16 files changed

+189
-102
lines changed

16 files changed

+189
-102
lines changed

python/cudf_polars/cudf_polars/containers/column.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from __future__ import annotations
77

88
import functools
9+
import inspect
910
from typing import TYPE_CHECKING
1011

1112
import polars as pl
@@ -37,6 +38,19 @@
3738
__all__: list[str] = ["Column"]
3839

3940

41+
def _dtype_short_repr_to_dtype(dtype_str: str) -> pl.DataType:
42+
"""Convert a Polars dtype short repr to a Polars dtype."""
43+
# limitations of dtype_short_repr_to_dtype described in
44+
# py-polars/polars/datatypes/convert.py#L299
45+
if dtype_str.startswith("list["):
46+
stripped = dtype_str.removeprefix("list[").removesuffix("]")
47+
return pl.List(_dtype_short_repr_to_dtype(stripped))
48+
pl_type = pl.datatypes.convert.dtype_short_repr_to_dtype(dtype_str)
49+
if pl_type is None:
50+
raise ValueError(f"{dtype_str} was not able to be parsed by Polars.")
51+
return pl_type() if inspect.isclass(pl_type) else pl_type
52+
53+
4054
class Column:
4155
"""An immutable column with sortedness metadata."""
4256

@@ -48,19 +62,17 @@ class Column:
4862
# Optional name, only ever set by evaluation of NamedExpr nodes
4963
# The internal evaluation should not care about the name.
5064
name: str | None
51-
# Optional dtype, used for preserving dtype metadata like
52-
# struct fields
53-
dtype: DataType | None
65+
dtype: DataType
5466

5567
def __init__(
5668
self,
5769
column: plc.Column,
70+
dtype: DataType,
5871
*,
5972
is_sorted: plc.types.Sorted = plc.types.Sorted.NO,
6073
order: plc.types.Order = plc.types.Order.ASCENDING,
6174
null_order: plc.types.NullOrder = plc.types.NullOrder.BEFORE,
6275
name: str | None = None,
63-
dtype: DataType | None = None,
6476
):
6577
self.obj = column
6678
self.is_scalar = self.size == 1
@@ -98,12 +110,9 @@ def deserialize_ctor_kwargs(
98110
column_kwargs: ColumnOptions,
99111
) -> DeserializedColumnOptions:
100112
"""Deserialize the constructor kwargs for a Column."""
101-
if (serialized_dtype := column_kwargs.get("dtype", None)) is not None:
102-
dtype: DataType | None = DataType( # pragma: no cover
103-
pl.datatypes.convert.dtype_short_repr_to_dtype(serialized_dtype)
104-
)
105-
else: # pragma: no cover
106-
dtype = None # pragma: no cover
113+
dtype = DataType( # pragma: no cover
114+
_dtype_short_repr_to_dtype(column_kwargs["dtype"])
115+
)
107116
return {
108117
"is_sorted": column_kwargs["is_sorted"],
109118
"order": column_kwargs["order"],
@@ -142,15 +151,12 @@ def serialize(
142151

143152
def serialize_ctor_kwargs(self) -> ColumnOptions:
144153
"""Serialize the constructor kwargs for self."""
145-
serialized_dtype = (
146-
None if self.dtype is None else pl.polars.dtype_str_repr(self.dtype.polars)
147-
)
148154
return {
149155
"is_sorted": self.is_sorted,
150156
"order": self.order,
151157
"null_order": self.null_order,
152158
"name": self.name,
153-
"dtype": serialized_dtype,
159+
"dtype": pl.polars.dtype_str_repr(self.dtype.polars),
154160
}
155161

156162
@functools.cached_property
@@ -406,7 +412,7 @@ def mask_nans(self) -> Self:
406412
if plc.traits.is_floating_point(self.obj.type()):
407413
old_count = self.null_count
408414
mask, new_count = plc.transform.nans_to_nulls(self.obj)
409-
result = type(self)(self.obj.with_mask(mask, new_count))
415+
result = type(self)(self.obj.with_mask(mask, new_count), self.dtype)
410416
if old_count == new_count:
411417
return result.sorted_like(self)
412418
return result
@@ -454,4 +460,4 @@ def slice(self, zlice: Slice | None) -> Self:
454460
conversion.from_polars_slice(zlice, num_rows=self.size),
455461
)
456462
(column,) = table.columns()
457-
return type(self)(column, name=self.name).sorted_like(self)
463+
return type(self)(column, name=self.name, dtype=self.dtype).sorted_like(self)

python/cudf_polars/cudf_polars/containers/dataframe.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828

2929
def _create_polars_column_metadata(
30-
name: str | None, dtype: pl.DataType | None
30+
name: str | None, dtype: pl.DataType
3131
) -> plc.interop.ColumnMetadata:
3232
"""Create ColumnMetadata preserving pl.Struct field names."""
3333
if isinstance(dtype, pl.Struct):
@@ -72,6 +72,7 @@ def __init__(self, columns: Iterable[Column]) -> None:
7272
if any(c.name is None for c in columns):
7373
raise ValueError("All columns must have a name")
7474
self.columns = [cast(NamedColumn, c) for c in columns]
75+
self.dtypes = [c.dtype for c in self.columns]
7576
self.column_map = {c.name: c for c in self.columns}
7677
self.table = plc.Table([c.obj for c in self.columns])
7778

@@ -89,12 +90,8 @@ def to_polars(self) -> pl.DataFrame:
8990
# serialise with names we control and rename with that map.
9091
name_map = {f"column_{i}": name for i, name in enumerate(self.column_map)}
9192
metadata = [
92-
_create_polars_column_metadata(
93-
name,
94-
# Can remove the getattr if we ever consistently set Column.dtype
95-
getattr(col.dtype, "polars", None),
96-
)
97-
for name, col in zip(name_map, self.columns, strict=True)
93+
_create_polars_column_metadata(name, dtype.polars)
94+
for name, dtype in zip(name_map, self.dtypes, strict=True)
9895
]
9996
table_with_metadata = _ObjectWithArrowMetadata(self.table, metadata)
10097
df = pl.DataFrame(table_with_metadata)
@@ -148,7 +145,9 @@ def from_polars(cls, df: pl.DataFrame) -> Self:
148145
)
149146

150147
@classmethod
151-
def from_table(cls, table: plc.Table, names: Sequence[str]) -> Self:
148+
def from_table(
149+
cls, table: plc.Table, names: Sequence[str], dtypes: Sequence[DataType]
150+
) -> Self:
152151
"""
153152
Create from a pylibcudf table.
154153
@@ -158,6 +157,8 @@ def from_table(cls, table: plc.Table, names: Sequence[str]) -> Self:
158157
Pylibcudf table to obtain columns from
159158
names
160159
Names for the columns
160+
dtypes
161+
Dtypes for the columns
161162
162163
Returns
163164
-------
@@ -172,9 +173,8 @@ def from_table(cls, table: plc.Table, names: Sequence[str]) -> Self:
172173
if table.num_columns() != len(names):
173174
raise ValueError("Mismatching name and table length.")
174175
return cls(
175-
# TODO: Pass along dtypes here
176-
Column(c, name=name)
177-
for c, name in zip(table.columns(), names, strict=True)
176+
Column(c, name=name, dtype=dtype)
177+
for c, name, dtype in zip(table.columns(), names, dtypes, strict=True)
178178
)
179179

180180
@classmethod
@@ -317,7 +317,11 @@ def select_columns(self, names: Set[str]) -> list[Column]:
317317
def filter(self, mask: Column) -> Self:
318318
"""Return a filtered table given a mask."""
319319
table = plc.stream_compaction.apply_boolean_mask(self.table, mask.obj)
320-
return type(self).from_table(table, self.column_names).sorted_like(self)
320+
return (
321+
type(self)
322+
.from_table(table, self.column_names, self.dtypes)
323+
.sorted_like(self)
324+
)
321325

322326
def slice(self, zlice: Slice | None) -> Self:
323327
"""
@@ -338,4 +342,8 @@ def slice(self, zlice: Slice | None) -> Self:
338342
(table,) = plc.copying.slice(
339343
self.table, conversion.from_polars_slice(zlice, num_rows=self.num_rows)
340344
)
341-
return type(self).from_table(table, self.column_names).sorted_like(self)
345+
return (
346+
type(self)
347+
.from_table(table, self.column_names, self.dtypes)
348+
.sorted_like(self)
349+
)

python/cudf_polars/cudf_polars/dsl/expressions/boolean.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,13 @@
1010
from functools import partial, reduce
1111
from typing import TYPE_CHECKING, Any, ClassVar
1212

13-
import pyarrow as pa
14-
1513
import pylibcudf as plc
1614

17-
from cudf_polars.containers import Column
15+
from cudf_polars.containers import Column, DataType
1816
from cudf_polars.dsl.expressions.base import (
1917
ExecutionContext,
2018
Expr,
2119
)
22-
from cudf_polars.dsl.expressions.literal import LiteralColumn
2320
from cudf_polars.utils.versions import POLARS_VERSION_LT_128
2421

2522
if TYPE_CHECKING:
@@ -28,7 +25,7 @@
2825
import polars.type_aliases as pl_types
2926
from polars.polars import _expr_nodes as pl_expr
3027

31-
from cudf_polars.containers import DataFrame, DataType
28+
from cudf_polars.containers import DataFrame
3229

3330
__all__ = ["BooleanFunction"]
3431

@@ -99,15 +96,6 @@ def __init__(
9996
# TODO: If polars IR doesn't put the casts in, we need to
10097
# mimic the supertype promotion rules.
10198
raise NotImplementedError("IsIn doesn't support supertype casting")
102-
if self.name is BooleanFunction.Name.IsIn:
103-
_, haystack = self.children
104-
# TODO: Use pl.List isinstance check once we have https://github.com/rapidsai/cudf/pull/18564
105-
if isinstance(haystack, LiteralColumn) and isinstance(
106-
haystack.value, pa.ListArray
107-
):
108-
raise NotImplementedError(
109-
"IsIn does not support nested list column input"
110-
) # pragma: no cover
11199

112100
@staticmethod
113101
def _distinct(
@@ -302,10 +290,10 @@ def do_evaluate(
302290
needles, haystack = columns
303291
if haystack.obj.type().id() == plc.TypeId.LIST:
304292
# Unwrap values from the list column
305-
haystack = Column(haystack.obj.children()[1])
306-
# TODO: Remove check once Column's require dtype
307-
if needles.dtype is not None:
308-
haystack = haystack.astype(needles.dtype)
293+
haystack = Column(
294+
haystack.obj.children()[1],
295+
dtype=DataType(haystack.dtype.polars.inner),
296+
).astype(needles.dtype)
309297
if haystack.size:
310298
return Column(
311299
plc.search.contains(haystack.obj, needles.obj), dtype=self.dtype

python/cudf_polars/cudf_polars/dsl/expressions/string.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@
99
from enum import IntEnum, auto
1010
from typing import TYPE_CHECKING, Any
1111

12-
import polars as pl
1312
from polars.exceptions import InvalidOperationError
1413

1514
import pylibcudf as plc
1615

17-
from cudf_polars.containers import Column, DataType
16+
from cudf_polars.containers import Column
1817
from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
1918
from cudf_polars.dsl.expressions.literal import Literal, LiteralColumn
2019
from cudf_polars.dsl.utils.reshape import broadcast
@@ -24,7 +23,7 @@
2423

2524
from polars.polars import _expr_nodes as pl_expr
2625

27-
from cudf_polars.containers import DataFrame
26+
from cudf_polars.containers import DataFrame, DataType
2827

2928
__all__ = ["StringFunction"]
3029

@@ -212,9 +211,9 @@ def do_evaluate(
212211
"""Evaluate this expression given a dataframe for context."""
213212
if self.name is StringFunction.Name.ConcatHorizontal:
214213
columns = [
215-
Column(child.evaluate(df, context=context).obj).astype(
216-
DataType(pl.String())
217-
)
214+
Column(
215+
child.evaluate(df, context=context).obj, dtype=child.dtype
216+
).astype(self.dtype)
218217
for child in self.children
219218
]
220219

@@ -227,13 +226,12 @@ def do_evaluate(
227226
return Column(
228227
plc.strings.combine.concatenate(
229228
plc.Table([col.obj for col in broadcasted]),
230-
plc.Scalar.from_py(delimiter, plc.DataType(plc.TypeId.STRING)),
231-
None
232-
if ignore_nulls
233-
else plc.Scalar.from_py(None, plc.DataType(plc.TypeId.STRING)),
229+
plc.Scalar.from_py(delimiter, self.dtype.plc),
230+
None if ignore_nulls else plc.Scalar.from_py(None, self.dtype.plc),
234231
None,
235232
plc.strings.combine.SeparatorOnNulls.NO,
236-
)
233+
),
234+
dtype=self.dtype,
237235
)
238236
elif self.name is StringFunction.Name.ConcatVertical:
239237
(child,) = self.children
@@ -324,20 +322,21 @@ def do_evaluate(
324322
if self.children[1].value is None:
325323
return Column(
326324
plc.Column.from_scalar(
327-
plc.Scalar.from_py(None, plc.DataType(plc.TypeId.STRING)),
325+
plc.Scalar.from_py(None, self.dtype.plc),
328326
column.size,
329-
)
327+
),
328+
self.dtype,
330329
)
331330
elif self.children[1].value == 0:
332331
result = plc.Column.from_scalar(
333-
plc.Scalar.from_py("", plc.DataType(plc.TypeId.STRING)),
332+
plc.Scalar.from_py("", self.dtype.plc),
334333
column.size,
335334
)
336335
if column.obj.null_mask():
337336
result = result.with_mask(
338337
column.obj.null_mask(), column.obj.null_count()
339338
)
340-
return Column(result)
339+
return Column(result, self.dtype)
341340

342341
else:
343342
start = -(self.children[1].value)
@@ -348,7 +347,8 @@ def do_evaluate(
348347
plc.Scalar.from_py(start, plc.DataType(plc.TypeId.INT32)),
349348
plc.Scalar.from_py(end, plc.DataType(plc.TypeId.INT32)),
350349
None,
351-
)
350+
),
351+
self.dtype,
352352
)
353353
elif self.name is StringFunction.Name.Head:
354354
column = self.children[0].evaluate(df, context=context)
@@ -359,16 +359,18 @@ def do_evaluate(
359359
if end is None:
360360
return Column(
361361
plc.Column.from_scalar(
362-
plc.Scalar.from_py(None, plc.DataType(plc.TypeId.STRING)),
362+
plc.Scalar.from_py(None, self.dtype.plc),
363363
column.size,
364-
)
364+
),
365+
self.dtype,
365366
)
366367
return Column(
367368
plc.strings.slice.slice_strings(
368369
column.obj,
369370
plc.Scalar.from_py(0, plc.DataType(plc.TypeId.INT32)),
370371
plc.Scalar.from_py(end, plc.DataType(plc.TypeId.INT32)),
371-
)
372+
),
373+
self.dtype,
372374
)
373375

374376
columns = [child.evaluate(df, context=context) for child in self.children]
@@ -450,7 +452,7 @@ def do_evaluate(
450452
return Column(plc.strings.reverse.reverse(column.obj), dtype=self.dtype)
451453
elif self.name is StringFunction.Name.Titlecase:
452454
(column,) = columns
453-
return Column(plc.strings.capitalize.title(column.obj))
455+
return Column(plc.strings.capitalize.title(column.obj), dtype=self.dtype)
454456
raise NotImplementedError(
455457
f"StringFunction {self.name}"
456458
) # pragma: no cover; handled by init raising

python/cudf_polars/cudf_polars/dsl/expressions/unary.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def do_evaluate(
237237
null_order=null_order,
238238
)
239239
elif self.name == "value_counts":
240-
(sort, parallel, name, normalize) = self.options
240+
(sort, _, _, normalize) = self.options
241241
count_agg = [plc.aggregation.count(plc.types.NullPolicy.INCLUDE)]
242242
gb_requests = [
243243
plc.groupby.GroupByRequest(

0 commit comments

Comments
 (0)