Commit 13b10a6
Device dmatrix (#5420)
1 parent 780de49
24 files changed: +913 −308 lines

python-package/xgboost/__init__.py (4 additions, 3 deletions)

@@ -8,12 +8,13 @@
 import sys
 import warnings
 
-from .core import DMatrix, Booster
+from .core import DMatrix, DeviceQuantileDMatrix, Booster
 from .training import train, cv
-from . import rabit  # noqa
+from . import rabit  # noqa
 from . import tracker  # noqa
 from .tracker import RabitTracker  # noqa
 from . import dask
+
 try:
     from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
     from .sklearn import XGBRFClassifier, XGBRFRegressor
@@ -31,7 +32,7 @@
 with open(VERSION_FILE) as f:
     __version__ = f.read().strip()
 
-__all__ = ['DMatrix', 'Booster',
+__all__ = ['DMatrix', 'DeviceQuantileDMatrix', 'Booster',
            'train', 'cv',
            'RabitTracker',
            'XGBModel', 'XGBClassifier', 'XGBRegressor', 'XGBRanker',

python-package/xgboost/core.py (72 additions, 7 deletions)

@@ -291,6 +291,18 @@ def _maybe_pandas_data(data, feature_names, feature_types,
     return data, feature_names, feature_types
 
 
+def _cudf_array_interfaces(df):
+    '''Extract cuDF __cuda_array_interface__.'''
+    interfaces = []
+    for col in df:
+        interface = df[col].__cuda_array_interface__
+        if 'mask' in interface:
+            interface['mask'] = interface['mask'].__cuda_array_interface__
+        interfaces.append(interface)
+    interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8')
+    return interfaces_str
+
+
 def _maybe_cudf_dataframe(data, feature_names, feature_types):
     """Extract internal data from cudf.DataFrame for DMatrix data."""
     if not (CUDF_INSTALLED and isinstance(data,
@@ -596,16 +608,10 @@ def _init_from_dt(self, data, nthread):
 
     def _init_from_array_interface_columns(self, df, missing, nthread):
         """Initialize DMatrix from columnar memory format."""
-        interfaces = []
-        for col in df:
-            interface = df[col].__cuda_array_interface__
-            if 'mask' in interface:
-                interface['mask'] = interface['mask'].__cuda_array_interface__
-            interfaces.append(interface)
+        interfaces_str = _cudf_array_interfaces(df)
         handle = ctypes.c_void_p()
         missing = missing if missing is not None else np.nan
         nthread = nthread if nthread is not None else 1
-        interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8')
         _check_call(
             _LIB.XGDMatrixCreateFromArrayInterfaceColumns(
                 interfaces_str,
@@ -1005,6 +1011,65 @@ def feature_types(self, feature_types):
         self._feature_types = feature_types
 
 
+class DeviceQuantileDMatrix(DMatrix):
+    """Device memory Data Matrix used in XGBoost for training with
+    tree_method='gpu_hist'. Do not use this for test/validation tasks, as
+    some information may be lost in quantisation. This DMatrix is primarily
+    designed to save memory in training: it skips intermediate steps and
+    directly creates the compressed representation used for training,
+    without allocating additional memory. The implementation does not
+    currently consider weights in the quantisation process (unlike DMatrix).
+
+    You can construct a DeviceQuantileDMatrix from cupy/cudf.
+    """
+
+    def __init__(self, data, label=None, weight=None, base_margin=None,
+                 missing=None,
+                 silent=False,
+                 feature_names=None,
+                 feature_types=None,
+                 nthread=None, max_bin=256):
+        self.max_bin = max_bin
+        if not (hasattr(data, "__cuda_array_interface__") or (
+                CUDF_INSTALLED and isinstance(data, CUDF_DataFrame))):
+            raise ValueError(
+                'Only cupy/cudf currently supported for DeviceQuantileDMatrix')
+
+        super().__init__(data, label=label, weight=weight,
+                         base_margin=base_margin,
+                         missing=missing,
+                         silent=silent,
+                         feature_names=feature_names,
+                         feature_types=feature_types,
+                         nthread=nthread)
+
+    def _init_from_array_interface_columns(self, df, missing, nthread):
+        """Initialize DMatrix from columnar memory format."""
+        interfaces_str = _cudf_array_interfaces(df)
+        handle = ctypes.c_void_p()
+        missing = missing if missing is not None else np.nan
+        nthread = nthread if nthread is not None else 1
+        _check_call(
+            _LIB.XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(
+                interfaces_str,
+                ctypes.c_float(missing), ctypes.c_int(nthread),
+                ctypes.c_int(self.max_bin), ctypes.byref(handle)))
+        self.handle = handle
+
+    def _init_from_array_interface(self, data, missing, nthread):
+        """Initialize DMatrix from cupy ndarray."""
+        interface = data.__cuda_array_interface__
+        if 'mask' in interface:
+            interface['mask'] = interface['mask'].__cuda_array_interface__
+        interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
+
+        handle = ctypes.c_void_p()
+        missing = missing if missing is not None else np.nan
+        nthread = nthread if nthread is not None else 1
+        _check_call(
+            _LIB.XGDeviceQuantileDMatrixCreateFromArrayInterface(
+                interface_str,
+                ctypes.c_float(missing), ctypes.c_int(nthread),
+                ctypes.c_int(self.max_bin), ctypes.byref(handle)))
+        self.handle = handle
+
 class Booster(object):
     # pylint: disable=too-many-public-methods
     """A Booster of XGBoost.

src/c_api/c_api.cu (23 additions, 0 deletions)

@@ -4,6 +4,7 @@
 #include "xgboost/learner.h"
 #include "c_api_error.h"
 #include "../data/device_adapter.cuh"
+#include "../data/device_dmatrix.h"
 
 using namespace xgboost;  // NOLINT
 
@@ -29,3 +30,25 @@ XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs,
       new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
   API_END();
 }
+
+XGB_DLL int XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(
    char const* c_json_strs, bst_float missing, int nthread, int max_bin,
    DMatrixHandle* out) {
+  API_BEGIN();
+  std::string json_str{c_json_strs};
+  data::CudfAdapter adapter(json_str);
+  *out = new std::shared_ptr<DMatrix>(
      new data::DeviceDMatrix(&adapter, missing, nthread, max_bin));
+  API_END();
+}
+
+XGB_DLL int XGDeviceQuantileDMatrixCreateFromArrayInterface(
    char const* c_json_strs, bst_float missing, int nthread, int max_bin,
    DMatrixHandle* out) {
+  API_BEGIN();
+  std::string json_str{c_json_strs};
+  data::CupyAdapter adapter(json_str);
+  *out = new std::shared_ptr<DMatrix>(
      new data::DeviceDMatrix(&adapter, missing, nthread, max_bin));
+  API_END();
+}
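
Both entry points consume the JSON-serialised __cuda_array_interface__ built by the Python layer above. A rough sketch of what that payload looks like for a small cupy array (the pointer value and exact set of fields vary by cupy version):

import json
import cupy as cp

arr = cp.arange(6, dtype=cp.float32).reshape(3, 2)
# Tuples serialise as JSON lists; 'data' carries the raw device pointer.
print(json.dumps(arr.__cuda_array_interface__, indent=2))
# e.g. {"shape": [3, 2], "typestr": "<f4", "data": [139872256, false], ...}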

src/common/compressed_iterator.h (8 additions, 14 deletions)

@@ -32,8 +32,8 @@ static const int kPadding = 4;  // Assign padding so we can read slightly off
 // the beginning of the array
 
 // The number of bits required to represent a given unsigned range
-static size_t SymbolBits(size_t num_symbols) {
-  auto bits = std::ceil(std::log2(num_symbols));
+inline XGBOOST_DEVICE size_t SymbolBits(size_t num_symbols) {
+  auto bits = std::ceil(log2(static_cast<double>(num_symbols)));
   return std::max(static_cast<size_t>(bits), size_t(1));
 }
 }  // namespace detail
@@ -50,14 +50,11 @@ static size_t SymbolBits(size_t num_symbols) {
 */
 
 class CompressedBufferWriter {
- private:
   size_t symbol_bits_;
-  size_t offset_;
 
  public:
-  explicit CompressedBufferWriter(size_t num_symbols) : offset_(0) {
-    symbol_bits_ = detail::SymbolBits(num_symbols);
-  }
+  XGBOOST_DEVICE explicit CompressedBufferWriter(size_t num_symbols)
+      : symbol_bits_(detail::SymbolBits(num_symbols)) {}
 
   /**
    * \fn static size_t CompressedBufferWriter::CalculateBufferSize(int
@@ -164,18 +161,15 @@
 }
 };
 
-template <typename T>
-
 /**
- * \class CompressedIterator
- *
- * \brief Read symbols from a bit compressed memory buffer. Usable on device and
- * host.
+ * \brief Read symbols from a bit compressed memory buffer. Usable on device and host.
  *
  * \author Rory
  * \date 7/9/2017
+ *
+ * \tparam T Generic type parameter.
  */
-
+template <typename T>
 class CompressedIterator {
  public:
   // Type definitions for thrust
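
SymbolBits determines how many bits each compressed symbol occupies. The same formula host-side, as a quick sanity check (a sketch, not part of the commit):

import math

def symbol_bits(num_symbols):
    # Bits needed to represent symbols in [0, num_symbols), with a floor of 1.
    return max(math.ceil(math.log2(num_symbols)), 1)

assert symbol_bits(2) == 1    # one bit distinguishes two symbols
assert symbol_bits(256) == 8  # 256 histogram bins fit in a byte
assert symbol_bits(257) == 9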

src/common/device_helpers.cuh (8 additions, 0 deletions)

@@ -1540,4 +1540,12 @@ DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
                  static_cast<typename OutputGradientT::ValueT>(gpair.GetHess()));
 }
 
+
+// Thrust version of this function causes error on Windows
+template <typename ReturnT, typename IterT, typename FuncT>
+thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
+    IterT iter, FuncT func) {
+  return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
+}
+
 }  // namespace dh

src/common/hist_util.cu (3 additions, 28 deletions)

@@ -338,31 +338,6 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
   return cuts;
 }
 
-struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
-  explicit IsValidFunctor(float missing) : missing(missing) {}
-
-  float missing;
-  __device__ bool operator()(const data::COOTuple& e) const {
-    if (common::CheckNAN(e.value) || e.value == missing) {
-      return false;
-    }
-    return true;
-  }
-  __device__ bool operator()(const Entry& e) const {
-    if (common::CheckNAN(e.fvalue) || e.fvalue == missing) {
-      return false;
-    }
-    return true;
-  }
-};
-
-// Thrust version of this function causes error on Windows
-template <typename ReturnT, typename IterT, typename FuncT>
-thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
-    IterT iter, FuncT func) {
-  return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
-}
-
 template <typename AdapterT>
 void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
                   SketchContainer* sketch_container, int num_cuts) {
@@ -372,10 +347,10 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
   auto &batch = adapter->Value();
   // Enforce single batch
   CHECK(!adapter->Next());
-  auto batch_iter = MakeTransformIterator<data::COOTuple>(
+  auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
       thrust::make_counting_iterator(0llu),
       [=] __device__(size_t idx) { return batch.GetElement(idx); });
-  auto entry_iter = MakeTransformIterator<Entry>(
+  auto entry_iter = dh::MakeTransformIterator<Entry>(
       thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
         return Entry(batch.GetElement(idx).column_idx,
                      batch.GetElement(idx).value);
@@ -385,7 +360,7 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
       0);
 
   auto d_column_sizes_scan = column_sizes_scan.data().get();
-  IsValidFunctor is_valid(missing);
+  data::IsValidFunctor is_valid(missing);
   dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
     auto e = batch_iter[begin + idx];
     if (is_valid(e)) {

src/common/hist_util.h (3 additions, 3 deletions)

@@ -105,10 +105,10 @@ class HistogramCuts {
     auto end = cut_ptrs_.ConstHostVector().at(column_id + 1);
     const auto &values = cut_values_.ConstHostVector();
     auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
-    if (it == values.cend()) {
-      it = values.cend() - 1;
-    }
     BinIdx idx = it - values.cbegin();
+    if (idx == end) {
+      idx -= 1;
+    }
     return idx;
   }
 
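
This change fixes an off-by-one in the per-column bin search: the old code clamped the iterator only at the very end of the flattened cut-value array, so a value past the last cut of any earlier column would land in the next column's first bin. The intended logic, sketched host-side with bisect (hypothetical helper names, not part of the commit):

import bisect

def search_bin(cut_ptrs, cut_values, column_id, value):
    # Cut values for column_id occupy cut_values[beg:end), as in HistogramCuts.
    beg = cut_ptrs[column_id]
    end = cut_ptrs[column_id + 1]
    idx = bisect.bisect_right(cut_values, value, beg, end)  # ~ std::upper_bound
    if idx == end:  # clamp to this column's last bin, not the global end
        idx -= 1
    return idx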

src/data/data.cu (1 addition, 1 deletion)

@@ -10,7 +10,7 @@
 #include "array_interface.h"
 #include "../common/device_helpers.cuh"
 #include "device_adapter.cuh"
-#include "simple_dmatrix.h"
+#include "device_dmatrix.h"
 
 namespace xgboost {

src/data/device_adapter.cuh (19 additions, 0 deletions)

@@ -8,12 +8,31 @@
 #include <memory>
 #include <string>
 #include "../common/device_helpers.cuh"
+#include "../common/math.h"
 #include "adapter.h"
 #include "array_interface.h"
 
 namespace xgboost {
 namespace data {
 
+struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
+  explicit IsValidFunctor(float missing) : missing(missing) {}
+
+  float missing;
+  __device__ bool operator()(const data::COOTuple& e) const {
+    if (common::CheckNAN(e.value) || e.value == missing) {
+      return false;
+    }
+    return true;
+  }
+  __device__ bool operator()(const Entry& e) const {
+    if (common::CheckNAN(e.fvalue) || e.fvalue == missing) {
+      return false;
+    }
+    return true;
+  }
+};
+
 class CudfAdapterBatch : public detail::NoMetaInfo {
  public:
   CudfAdapterBatch() = default;
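
IsValidFunctor moves here so the sketching code and the new DeviceDMatrix can share it. Its semantics, restated host-side for clarity (a sketch; the real functor runs per element on the device):

import math

def is_valid(value, missing):
    # Keep an entry unless it is NaN or equals the user-supplied missing
    # sentinel. NaN never compares equal to anything, hence the explicit check.
    return not (math.isnan(value) or value == missing)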
