Skip to content

Commit 60bae8e

Browse files
committed
dask inplace predict.
1 parent 5d0e1df commit 60bae8e

File tree

4 files changed

+82
-15
lines changed

4 files changed

+82
-15
lines changed

python-package/xgboost/core.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def ctypes2numpy(cptr, length, dtype):
210210
def ctypes2cupy(cptr, length, dtype):
211211
"""Convert a ctypes pointer array to a cupy array."""
212212
import cupy
213-
mem = cupy.zeros((length.value, ), dtype=dtype, order='C')
213+
mem = cupy.zeros(length.value, dtype=dtype, order='C')
214214
addr = ctypes.cast(cptr, ctypes.c_void_p).value
215215
# pylint: disable=c-extension-no-member,no-member
216216
cupy.cuda.runtime.memcpy(
@@ -487,6 +487,7 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
487487
data, feature_names, feature_types = _convert_dataframes(
488488
data, feature_names, feature_types
489489
)
490+
missing = np.nan if missing is None else missing
490491

491492
if isinstance(data, (STRING_TYPES, os_PathLike)):
492493
handle = ctypes.c_void_p()
@@ -622,6 +623,7 @@ def _init_from_dt(self, data, nthread):
622623
def _init_from_array_interface_columns(self, df, missing, nthread):
623624
"""Initialize DMatrix from columnar memory format."""
624625
interfaces_str = _cudf_array_interfaces(df)
626+
nthread = nthread if nthread is not None else 1
625627
handle = ctypes.c_void_p()
626628
_check_call(
627629
_LIB.XGDMatrixCreateFromArrayInterfaceColumns(
@@ -1560,7 +1562,8 @@ def reshape_output(predt, rows):
15601562
preds = ctypes.POINTER(ctypes.c_float)()
15611563

15621564
# once caching is supported, we can pass id(data) as cache id.
1563-
1565+
if isinstance(data, DataFrame):
1566+
data = data.values
15641567
if isinstance(data, np.ndarray):
15651568
assert data.flags.c_contiguous
15661569
arr = np.array(data.reshape(data.size), copy=False,

python-package/xgboost/dask.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .compat import sparse, scipy_sparse
2727
from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
2828
from .compat import CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_concat
29+
from .compat import lazy_isinstance
2930

3031
from .core import DMatrix, Booster, _expect
3132
from .training import train as worker_train
@@ -98,6 +99,9 @@ def concat(value):
9899
return pandas_concat(value, axis=0)
99100
if CUDF_INSTALLED and isinstance(value[0], (CUDF_DataFrame, CUDF_Series)):
100101
return CUDF_concat(value, axis=0)
102+
if lazy_isinstance(value[0], 'cupy.core.core', 'ndarray'):
103+
import cupy # pylint: disable=import-error
104+
return cupy.concatenate(value, axis=0)
101105
return dd.multi.concat(list(value), axis=0)
102106

103107

@@ -426,7 +430,6 @@ def dispatched_train(worker_addr):
426430
local_param['n_jobs'] is not None and \
427431
local_param['n_jobs'] != worker.nthreads:
428432
msg = '`n_jobs` is specified. ' + msg
429-
print('local_param[n_jobs]', local_param['n_jobs'])
430433
LOGGER.warning(msg)
431434
else:
432435
local_param['nthread'] = worker.nthreads
@@ -502,11 +505,10 @@ def mapped_predict(partition, is_df):
502505
).result()
503506
return predictions
504507
if isinstance(data, dd.DataFrame):
505-
import dask
506508
predictions = client.submit(
507509
dd.map_partitions,
508510
mapped_predict, data, True,
509-
meta=dask.dataframe.utils.make_meta({'prediction': 'f4'})
511+
meta=dd.utils.make_meta({'prediction': 'f4'})
510512
).result()
511513
return predictions.iloc[:, 0]
512514

@@ -607,27 +609,45 @@ def inplace_predict(client, model, data,
607609
booster = model['booster']
608610
else:
609611
raise TypeError(_expect([Booster, dict], type(model)))
612+
if not isinstance(data, (da.Array, dd.DataFrame)):
613+
raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
610614

611-
def dispatched_predict(data):
615+
def mapped_predict(data, is_df):
612616
worker = distributed_get_worker()
613617
booster.set_param({'nthread': worker.nthreads})
614618
prediction = booster.inplace_predict(
615619
data,
616620
iteration_range=iteration_range,
617621
predict_type=predict_type,
618622
missing=missing)
623+
if is_df:
624+
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
625+
import cudf
626+
# There's an error with cudf saying `concat_cudf` got an
627+
# expected argument `ignore_index`. So the test here is just
628+
# place holder. So this is not working yet.
629+
prediction = cudf.DataFrame({'prediction': prediction},
630+
dtype=numpy.float32)
631+
else:
632+
# If it's from pandas, the partition is a numpy array
633+
prediction = DataFrame(prediction, columns=['prediction'],
634+
dtype=numpy.float32)
619635
return prediction
620636

621-
msg = 'Only dask array is supported for inplace prediction'
622-
assert isinstance(data, da.Array), msg
623-
624-
def map_blocks():
625-
predictions = da.map_blocks(dispatched_predict, data, drop_axis=1)
637+
if isinstance(data, da.Array):
638+
predictions = client.submit(
639+
da.map_blocks,
640+
mapped_predict, data, False, drop_axis=1,
641+
dtype=numpy.float32
642+
).result()
626643
return predictions
627-
628-
predictions = client.submit(map_blocks)
629-
import dask
630-
return dask.delayed(predictions).compute()
644+
if isinstance(data, dd.DataFrame):
645+
predictions = client.submit(
646+
dd.map_partitions,
647+
mapped_predict, data, True,
648+
meta=dd.utils.make_meta({'prediction': 'f4'})
649+
).result()
650+
return predictions.iloc[:, 0]
631651

632652

633653
def _evaluation_matrices(client, validation_set, sample_weights, missing):

tests/python-gpu/test_gpu_with_dask.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33
import numpy as np
44
import unittest
5+
import xgboost
56

67
if sys.platform.startswith("win"):
78
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@@ -29,6 +30,7 @@ class TestDistributedGPU(unittest.TestCase):
2930
def test_dask_dataframe(self):
3031
with LocalCUDACluster() as cluster:
3132
with Client(cluster) as client:
33+
import cupy
3234
X, y = generate_array()
3335

3436
X = dd.from_dask_array(X)
@@ -49,6 +51,42 @@ def test_dask_dataframe(self):
4951
predictions = dxgb.predict(client, out, dtrain).compute()
5052
assert isinstance(predictions, np.ndarray)
5153

54+
# There's an error with cudf saying `concat_cudf` got an
55+
# expected argument `ignore_index`. So the test here is just
56+
# place holder.
57+
58+
# series_predictions = dxgb.inplace_predict(client, out, X)
59+
# assert isinstance(series_predictions, dd.Series)
60+
61+
single_node = out['booster'].predict(
62+
xgboost.DMatrix(X.compute()))
63+
cupy.testing.assert_allclose(single_node, predictions)
64+
65+
@pytest.mark.skipif(**tm.no_cupy())
66+
def test_dask_array(self):
67+
with LocalCUDACluster() as cluster:
68+
with Client(cluster) as client:
69+
import cupy
70+
X, y = generate_array()
71+
72+
X = X.map_blocks(cupy.asarray)
73+
y = y.map_blocks(cupy.asarray)
74+
dtrain = dxgb.DaskDMatrix(client, X, y)
75+
out = dxgb.train(client, {'tree_method': 'gpu_hist'},
76+
dtrain=dtrain,
77+
evals=[(dtrain, 'X')],
78+
num_boost_round=2)
79+
from_dmatrix = dxgb.predict(client, out, dtrain).compute()
80+
inplace_predictions = dxgb.inplace_predict(
81+
client, out, X).compute()
82+
single_node = out['booster'].predict(
83+
xgboost.DMatrix(X.compute()))
84+
np.testing.assert_allclose(single_node, from_dmatrix)
85+
cupy.testing.assert_allclose(
86+
cupy.array(single_node),
87+
inplace_predictions)
88+
89+
5290
@pytest.mark.skipif(**tm.no_dask())
5391
@pytest.mark.skipif(**tm.no_dask_cuda())
5492
@pytest.mark.mgpu

tests/python/test_with_dask.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,14 @@ def test_from_dask_dataframe():
6363
from_df = prediction.compute()
6464

6565
assert isinstance(prediction, dd.Series)
66+
assert np.all(prediction.compute().values == from_dmatrix)
6667
assert np.all(from_dmatrix == from_df.to_numpy())
6768

69+
series_predictions = xgb.dask.inplace_predict(client, booster, X)
70+
assert isinstance(series_predictions, dd.Series)
71+
np.testing.assert_allclose(series_predictions.compute().values,
72+
from_dmatrix)
73+
6874

6975
def test_from_dask_array():
7076
with LocalCluster(n_workers=5, threads_per_worker=5) as cluster:

0 commit comments

Comments
 (0)