Skip to content

Commit 1a08012

Browse files
authored
Implement iterative DMatrix. (#5837)
1 parent 4d277d7 commit 1a08012

15 files changed

+855
-84
lines changed

include/xgboost/c_api.h

Lines changed: 203 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -26,39 +26,10 @@
2626
// manually define unsigned long
2727
typedef uint64_t bst_ulong; // NOLINT(*)
2828

29-
3029
/*! \brief handle to DMatrix */
3130
typedef void *DMatrixHandle; // NOLINT(*)
3231
/*! \brief handle to Booster */
3332
typedef void *BoosterHandle; // NOLINT(*)
34-
/*! \brief handle to a data iterator */
35-
typedef void *DataIterHandle; // NOLINT(*)
36-
/*! \brief handle to a internal data holder. */
37-
typedef void *DataHolderHandle; // NOLINT(*)
38-
39-
/*! \brief Mini batch used in XGBoost Data Iteration */
40-
typedef struct { // NOLINT(*)
41-
/*! \brief number of rows in the minibatch */
42-
size_t size;
43-
/* \brief number of columns in the minibatch. */
44-
size_t columns;
45-
/*! \brief row pointer to the rows in the data */
46-
#ifdef __APPLE__
47-
/* Necessary as Java on MacOS defines jlong as long int
48-
* and gcc defines int64_t as long long int. */
49-
long* offset; // NOLINT(*)
50-
#else
51-
int64_t* offset; // NOLINT(*)
52-
#endif // __APPLE__
53-
/*! \brief labels of each instance */
54-
float* label;
55-
/*! \brief weight of each instance, can be NULL */
56-
float* weight;
57-
/*! \brief feature index */
58-
int* index;
59-
/*! \brief feature values */
60-
float* value;
61-
} XGBoostBatchCSR;
6233

6334
/*!
6435
* \brief Return the version of the XGBoost library being currently used.
@@ -71,29 +42,6 @@ typedef struct { // NOLINT(*)
7142
*/
7243
XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch);
7344

74-
/*!
75-
* \brief Callback to set the data to handle,
76-
* \param handle The handle to the callback.
77-
* \param batch The data content to be set.
78-
*/
79-
XGB_EXTERN_C typedef int XGBCallbackSetData( // NOLINT(*)
80-
DataHolderHandle handle, XGBoostBatchCSR batch);
81-
82-
/*!
83-
* \brief The data reading callback function.
84-
* The iterator will be able to give subset of batch in the data.
85-
*
86-
* If there is data, the function will call set_function to set the data.
87-
*
88-
* \param data_handle The handle to the callback.
89-
* \param set_function The batch returned by the iterator
90-
* \param set_function_handle The handle to be passed to set function.
91-
* \return 0 if we are reaching the end and batch is not returned.
92-
*/
93-
XGB_EXTERN_C typedef int XGBCallbackDataIterNext( // NOLINT(*)
94-
DataIterHandle data_handle, XGBCallbackSetData *set_function,
95-
DataHolderHandle set_function_handle);
96-
9745
/*!
9846
* \brief get string message of the last error
9947
*
@@ -126,20 +74,6 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
12674
int silent,
12775
DMatrixHandle *out);
12876

129-
/*!
130-
* \brief Create a DMatrix from a data iterator.
131-
* \param data_handle The handle to the data.
132-
* \param callback The callback to get the data.
133-
* \param cache_info Additional information about cache file, can be null.
134-
* \param out The created DMatrix
135-
* \return 0 when success, -1 when failure happens.
136-
*/
137-
XGB_DLL int XGDMatrixCreateFromDataIter(
138-
DataIterHandle data_handle,
139-
XGBCallbackDataIterNext* callback,
140-
const char* cache_info,
141-
DMatrixHandle *out);
142-
14377
/*!
14478
* \brief create a matrix content from CSR format
14579
* \param indptr pointer to row headers
@@ -221,6 +155,189 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data,
221155
bst_ulong ncol,
222156
DMatrixHandle* out,
223157
int nthread);
158+
159+
/*
160+
* ========================== Begin data callback APIs =========================
161+
*
162+
* Short notes for data callback
163+
*
164+
* There are 2 sets of data callbacks for DMatrix. The first one is currently exclusively
165+
* used by JVM packages. It uses `XGBoostBatchCSR` to accept batches for CSR formatted
166+
* input, and concatenate them into 1 final big CSR. The related functions are:
167+
*
168+
* - XGBCallbackSetData
169+
* - XGBCallbackDataIterNext
170+
* - XGDMatrixCreateFromDataIter
171+
*
172+
* Another set is used by Quantile based DMatrix (used by hist algorithm) for reducing
173+
* memory usage. Currently only GPU implementation is available. It accepts foreign data
174+
* iterators as callbacks and works similar to external memory. For GPU Hist, the data is
175+
* first compressed by quantile sketching then merged. This is particularly useful for
176+
* distributed setting as it eliminates 2 copies of data. 1 by a `concat` from external
177+
* library to make the data into a blob for normal DMatrix initialization, another by the
178+
* internal CSR copy of DMatrix. Related functions are:
179+
*
180+
* - XGProxyDMatrixCreate
181+
* - XGDMatrixCallbackNext
182+
* - DataIterResetCallback
183+
* - XGDeviceQuantileDMatrixSetDataCudaArrayInterface
184+
* - XGDeviceQuantileDMatrixSetDataCudaColumnar
185+
* - ... (data setters)
186+
*/
187+
188+
/* ==== First set of callback functions, used exclusively by JVM packages. ==== */
189+
190+
/*! \brief handle to an external data iterator */
191+
typedef void *DataIterHandle; // NOLINT(*)
192+
/*! \brief handle to an internal data holder. */
193+
typedef void *DataHolderHandle; // NOLINT(*)
194+
195+
196+
/*! \brief Mini batch used in XGBoost Data Iteration */
197+
typedef struct { // NOLINT(*)
198+
/*! \brief number of rows in the minibatch */
199+
size_t size;
200+
/*! \brief number of columns in the minibatch. */
201+
size_t columns;
202+
/*! \brief row pointer to the rows in the data */
203+
#ifdef __APPLE__
204+
/* Necessary as Java on MacOS defines jlong as long int
205+
* and gcc defines int64_t as long long int. */
206+
long* offset; // NOLINT(*)
207+
#else
208+
int64_t* offset; // NOLINT(*)
209+
#endif // __APPLE__
210+
/*! \brief labels of each instance */
211+
float* label;
212+
/*! \brief weight of each instance, can be NULL */
213+
float* weight;
214+
/*! \brief feature index */
215+
int* index;
216+
/*! \brief feature values */
217+
float* value;
218+
} XGBoostBatchCSR;
219+
220+
/*!
221+
* \brief Callback to set the data to handle,
222+
* \param handle The handle to the callback.
223+
* \param batch The data content to be set.
224+
*/
225+
XGB_EXTERN_C typedef int XGBCallbackSetData( // NOLINT(*)
226+
DataHolderHandle handle, XGBoostBatchCSR batch);
227+
228+
/*!
229+
* \brief The data reading callback function.
230+
* The iterator will be able to give subset of batch in the data.
231+
*
232+
* If there is data, the function will call set_function to set the data.
233+
*
234+
* \param data_handle The handle to the callback.
235+
* \param set_function The batch returned by the iterator
236+
* \param set_function_handle The handle to be passed to set function.
237+
* \return 0 if we are reaching the end and batch is not returned.
238+
*/
239+
XGB_EXTERN_C typedef int XGBCallbackDataIterNext( // NOLINT(*)
240+
DataIterHandle data_handle, XGBCallbackSetData *set_function,
241+
DataHolderHandle set_function_handle);
242+
243+
/*!
244+
* \brief Create a DMatrix from a data iterator.
245+
* \param data_handle The handle to the data.
246+
* \param callback The callback to get the data.
247+
* \param cache_info Additional information about cache file, can be null.
248+
* \param out The created DMatrix
249+
* \return 0 when success, -1 when failure happens.
250+
*/
251+
XGB_DLL int XGDMatrixCreateFromDataIter(
252+
DataIterHandle data_handle,
253+
XGBCallbackDataIterNext* callback,
254+
const char* cache_info,
255+
DMatrixHandle *out);
256+
257+
/* == Second set of callback functions, used by constructing Quantile based DMatrix. ===
258+
*
259+
* Short note for how to use the second set of callback for GPU Hist tree method.
260+
*
261+
* Step 0: Define a data iterator with 2 methods `reset`, and `next`.
262+
* Step 1: Create a DMatrix proxy by `XGProxyDMatrixCreate` and hold the handle.
263+
* Step 2: Pass the iterator handle, proxy handle and 2 methods into
264+
* `XGDeviceQuantileDMatrixCreateFromCallback`.
265+
* Step 3: Call appropriate data setters in `next` functions.
266+
*
267+
* See test_iterative_device_dmatrix.cu or Python interface for examples.
268+
*/
269+
270+
/*!
271+
* \brief Create a DMatrix proxy for setting data, can be freed by XGDMatrixFree.
272+
*
273+
* \param out The created DMatrix proxy
274+
*
275+
* \return 0 when success, -1 when failure happens
276+
*/
277+
XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle* out);
278+
279+
/*!
280+
* \brief Callback function prototype for getting next batch of data.
281+
*
282+
* \param iter A handle to the user defined iterator.
283+
*
284+
* \return 0 when success, -1 when failure happens
285+
*/
286+
XGB_EXTERN_C typedef int XGDMatrixCallbackNext(DataIterHandle iter); // NOLINT(*)
287+
288+
/*!
289+
* \brief Callback function prototype for resetting external iterator
290+
*/
291+
XGB_EXTERN_C typedef void DataIterResetCallback(DataIterHandle handle); // NOLINT(*)
292+
293+
/*!
294+
* \brief Create a device DMatrix with data iterator.
295+
*
296+
* \param iter A handle to external data iterator.
297+
* \param proxy A DMatrix proxy handle created by `XGProxyDMatrixCreate`.
298+
* \param reset Callback function resetting the iterator state.
299+
* \param next Callback function yielding the next batch of data.
300+
* \param missing Which value to represent missing value
301+
* \param nthread Number of threads to use, 0 for default.
302+
* \param max_bin Maximum number of bins for building histogram.
303+
* \param out The created Device Quantile DMatrix
304+
*
305+
* \return 0 when success, -1 when failure happens
306+
*/
307+
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(
308+
DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
309+
XGDMatrixCallbackNext *next, float missing, int nthread, int max_bin,
310+
DMatrixHandle *out);
311+
/*!
312+
* \brief Set data on a DMatrix proxy.
313+
*
314+
* \param handle A DMatrix proxy created by XGProxyDMatrixCreate
315+
* \param c_interface_str Null terminated JSON document string representation of CUDA
316+
* array interface.
317+
*
318+
* \return 0 when success, -1 when failure happens
319+
*/
320+
XGB_DLL int XGDeviceQuantileDMatrixSetDataCudaArrayInterface(
321+
DMatrixHandle handle,
322+
const char* c_interface_str);
323+
/*!
324+
* \brief Set data on a DMatrix proxy.
325+
*
326+
* \param handle A DMatrix proxy created by XGProxyDMatrixCreate
327+
* \param c_interface_str Null terminated JSON document string representation of CUDA
328+
* array interface, with an array of columns.
329+
*
330+
* \return 0 when success, -1 when failure happens
331+
*/
332+
XGB_DLL int XGDeviceQuantileDMatrixSetDataCudaColumnar(
333+
DMatrixHandle handle,
334+
const char* c_interface_str);
335+
/*
336+
* =========================== End data callback APIs ==========================
337+
*/
338+
339+
340+
224341
/*!
225342
* \brief create a new dmatrix from sliced content of existing matrix
226343
* \param handle instance of data matrix to be sliced
@@ -261,6 +378,18 @@ XGB_DLL int XGDMatrixFree(DMatrixHandle handle);
261378
*/
262379
XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle,
263380
const char *fname, int silent);
381+
382+
/*!
383+
* \brief Set a field in info from content described by an array interface.
384+
* \param handle an instance of data matrix
385+
* \param field field name.
386+
* \param c_interface_str JSON string representation of array interface.
387+
* \return 0 when success, -1 when failure happens
388+
*/
389+
XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle,
390+
char const* field,
391+
char const* c_interface_str);
392+
264393
/*!
265394
* \brief set float vector to a content in info
266395
* \param handle a instance of data matrix
@@ -437,6 +566,10 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
437566
int training,
438567
bst_ulong *out_len,
439568
const float **out_result);
569+
570+
/*
571+
* ========================== Begin Serialization APIs =========================
572+
*/
440573
/*
441574
* Short note for serialization APIs. There are 3 different sets of serialization API.
442575
*
@@ -559,6 +692,10 @@ XGB_DLL int XGBoosterSaveJsonConfig(BoosterHandle handle, bst_ulong *out_len,
559692
*/
560693
XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle,
561694
char const *json_parameters);
695+
/*
696+
* =========================== End Serialization APIs ==========================
697+
*/
698+
562699

563700
/*!
564701
* \brief dump model, return array of strings representing model dump

include/xgboost/data.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,33 @@ class DMatrix {
502502
const std::string& cache_prefix = "",
503503
size_t page_size = kPageSize);
504504

505-
virtual DMatrix* Slice(common::Span<int32_t const> ridxs) = 0;
505+
/**
506+
* \brief Create a new Quantile based DMatrix used for histogram based algorithm.
507+
*
508+
* \tparam DataIterHandle External iterator type, defined in C API.
509+
* \tparam DMatrixHandle DMatrix handle, defined in C API.
510+
* \tparam DataIterResetCallback Callback for reset, prototype defined in C API.
511+
* \tparam XGDMatrixCallbackNext Callback for next, prototype defined in C API.
512+
*
513+
* \param iter External data iterator
514+
* \param proxy A handle to ProxyDMatrix
515+
* \param reset Callback for reset
516+
* \param next Callback for next
517+
* \param missing Value that should be treated as missing.
518+
* \param nthread number of threads used for initialization.
519+
* \param max_bin Maximum number of bins.
520+
*
521+
* \return A created quantile based DMatrix.
522+
*/
523+
template <typename DataIterHandle, typename DMatrixHandle,
524+
typename DataIterResetCallback, typename XGDMatrixCallbackNext>
525+
static DMatrix *Create(DataIterHandle iter, DMatrixHandle proxy,
526+
DataIterResetCallback *reset,
527+
XGDMatrixCallbackNext *next, float missing,
528+
int nthread,
529+
int max_bin);
530+
531+
virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;
506532
/*! \brief page size 32 MB */
507533
static const size_t kPageSize = 32UL << 20UL;
508534

0 commit comments

Comments
 (0)