
Commit 39390cc

[breaking] Remove the predictor param, allow fallback to prediction using DMatrix. (#9129)
- A `DeviceOrd` struct is implemented to indicate the device. It will eventually replace the `gpu_id` parameter.
- The `predictor` parameter is removed.
- Fallback to `DMatrix` when `inplace_predict` is not available.
- The heuristic for choosing a predictor is only used during training.
1 parent 3a0f787 commit 39390cc


54 files changed: +1049 / -778 lines
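As a rough illustration of the user-facing change, here is a minimal Python sketch (not part of the commit; the toy data, parameter values, and round count are placeholders, and a CUDA-enabled build is assumed for `gpu_hist`):

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(256, 16)
    y = np.random.rand(256)

    # The device is selected via gpu_id/tree_method; the predictor parameter is removed.
    booster = xgb.train(
        {"tree_method": "gpu_hist", "gpu_id": 0},
        xgb.DMatrix(X, label=y),
        num_boost_round=10,
    )

    # X is in host memory while the booster is configured for a CUDA device, so this
    # call falls back to DMatrix-based prediction and emits a performance warning.
    predt = booster.inplace_predict(X)

    # Prediction through an explicit DMatrix is unchanged.
    predt_dm = booster.predict(xgb.DMatrix(X))

The same fallback applies in the other direction (device-resident data with a CPU booster), per the notes added to `include/xgboost/c_api.h` below.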

doc/gpu/index.rst

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ XGBoost makes use of `GPUTreeShap <https://github.com/rapidsai/gputreeshap>`_ as

 .. code-block:: python

-  model.set_param({"predictor": "gpu_predictor"})
+  model.set_param({"gpu_id": "0", "tree_method": "gpu_hist"})
   shap_values = model.predict(dtrain, pred_contribs=True)
   shap_interaction_values = model.predict(dtrain, pred_interactions=True)


doc/parameter.rst

Lines changed: 0 additions & 12 deletions
@@ -199,18 +199,6 @@ Parameters for Tree Booster
   - Maximum number of discrete bins to bucket continuous features.
   - Increasing this number improves the optimality of splits at the cost of higher computation time.

-* ``predictor``, [default= ``auto``]
-
-  - The type of predictor algorithm to use. Provides the same results but allows the use of GPU or CPU.
-
-    - ``auto``: Configure predictor based on heuristics.
-    - ``cpu_predictor``: Multicore CPU prediction algorithm.
-    - ``gpu_predictor``: Prediction using GPU. Used when ``tree_method`` is ``gpu_hist``.
-      When ``predictor`` is set to default value ``auto``, the ``gpu_hist`` tree method is
-      able to provide GPU based prediction without copying training data to GPU memory.
-      If ``gpu_predictor`` is explicitly specified, then all data is copied into GPU, only
-      recommended for performing prediction tasks.
-
 * ``num_parallel_tree``, [default=1]

   - Number of parallel trees constructed during each iteration. This option is used to support boosted random forest.

doc/prediction.rst

Lines changed: 0 additions & 19 deletions
@@ -87,15 +87,6 @@ with the native Python interface :py:meth:`xgboost.Booster.predict` and
 behavior. Also the ``save_best`` parameter from :py:obj:`xgboost.callback.EarlyStopping`
 might be useful.

-*********
-Predictor
-*********
-
-There are 2 predictors in XGBoost (3 if you have the one-api plugin enabled), namely
-``cpu_predictor`` and ``gpu_predictor``. The default option is ``auto`` so that XGBoost
-can employ some heuristics for saving GPU memory during training. They might have slight
-different outputs due to floating point errors.
-

 ***********
 Base Margin
@@ -134,15 +125,6 @@ it. Be aware that the output of in-place prediction depends on input data type,
 input is on GPU data output is :py:obj:`cupy.ndarray`, otherwise a :py:obj:`numpy.ndarray`
 is returned.

-****************
-Categorical Data
-****************
-
-Other than users performing encoding, XGBoost has experimental support for categorical
-data using ``gpu_hist`` and ``gpu_predictor``. No special operation needs to be done on
-input test data since the information about categories is encoded into the model during
-training.
-
 *************
 Thread Safety
 *************
@@ -159,7 +141,6 @@ instance we might accidentally call ``clf.set_params()`` inside a predict function

     def predict_fn(clf: xgb.XGBClassifier, X):
         X = preprocess(X)
-        clf.set_params(predictor="gpu_predictor")  # NOT safe!
         clf.set_params(n_jobs=1)  # NOT safe!
         return clf.predict_proba(X, iteration_range=(0, 10))

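The retained lines above state that the output type of in-place prediction follows the input type. A minimal sketch of that behaviour (illustrative only; `cupy` and a CUDA-enabled build are assumed for the GPU branch, which is left commented out):

    import numpy as np
    import xgboost as xgb

    booster = xgb.train(
        {"tree_method": "hist"},
        xgb.DMatrix(np.random.rand(64, 4), label=np.random.rand(64)),
        num_boost_round=5,
    )

    # numpy input -> numpy.ndarray output.
    cpu_predt = booster.inplace_predict(np.random.rand(8, 4))

    # cupy input -> cupy.ndarray output, per the documentation text kept above.
    # import cupy as cp
    # gpu_predt = booster.inplace_predict(cp.random.rand(8, 4))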

doc/tutorials/dask.rst

Lines changed: 2 additions & 2 deletions
@@ -148,8 +148,8 @@ Also for inplace prediction:

 .. code-block:: python

-    booster.set_param({'predictor': 'gpu_predictor'})
-    # where X is a dask DataFrame or dask Array containing cupy or cuDF backed data.
+    # where X is a dask DataFrame or dask Array backed by cupy or cuDF.
+    booster.set_param({"gpu_id": "0"})
     prediction = xgb.dask.inplace_predict(client, booster, X)

 When input is ``da.Array`` object, output is always ``da.Array``. However, if the input
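For context, a hedged sketch of the dask flow around the snippet above; the cluster setup, data, and parameters are placeholders, and only the `xgboost.dask` calls are the public API shown in the tutorial:

    import xgboost as xgb
    from dask import array as da
    from dask.distributed import Client, LocalCluster

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = da.random.random((1000, 10), chunks=(100, 10))
        y = da.random.random(1000, chunks=(100,))

        dtrain = xgb.dask.DaskDMatrix(client, X, y)
        output = xgb.dask.train(client, {"tree_method": "hist"}, dtrain, num_boost_round=10)
        booster = output["booster"]

        # With the predictor parameter removed, no extra configuration is needed for
        # CPU-backed collections; GPU-backed data uses gpu_id as shown above.
        prediction = xgb.dask.inplace_predict(client, booster, X)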

doc/tutorials/saving_model.rst

Lines changed: 0 additions & 1 deletion
@@ -173,7 +173,6 @@ Will print out something similar to (not actual output as it's too long for demo
       "gradient_booster": {
         "gbtree_train_param": {
           "num_parallel_tree": "1",
-          "predictor": "gpu_predictor",
           "process_type": "default",
           "tree_method": "gpu_hist",
           "updater": "grow_gpu_hist",

include/xgboost/base.h

Lines changed: 6 additions & 1 deletion
@@ -10,6 +10,7 @@
 #include <dmlc/omp.h>

 #include <cmath>
+#include <cstdint>
 #include <iostream>
 #include <string>
 #include <utility>
@@ -112,7 +113,7 @@ using bst_row_t = std::size_t;  // NOLINT
 /*! \brief Type for tree node index. */
 using bst_node_t = std::int32_t;  // NOLINT
 /*! \brief Type for ranking group index. */
-using bst_group_t = std::uint32_t;  // NOLINT
+using bst_group_t = std::uint32_t;    // NOLINT
 /**
  * \brief Type for indexing into output targets.
  */
@@ -125,6 +126,10 @@ using bst_layer_t = std::int32_t;  // NOLINT
  * \brief Type for indexing trees.
  */
 using bst_tree_t = std::int32_t;  // NOLINT
+/**
+ * @brief Ordinal of a CUDA device.
+ */
+using bst_d_ordinal_t = std::int16_t;  // NOLINT

 namespace detail {
 /*! \brief Implementation of gradient statistics pair. Template specialisation

include/xgboost/c_api.h

Lines changed: 12 additions & 0 deletions
@@ -1067,6 +1067,9 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle, DMatrixHandle dmat
 /**
  * \brief Inplace prediction from CPU dense matrix.
  *
+ * \note If the booster is configured to run on a CUDA device, XGBoost falls back to run
+ *       prediction with DMatrix with a performance warning.
+ *
  * \param handle Booster handle.
  * \param values JSON encoded __array_interface__ to values.
  * \param config See \ref XGBoosterPredictFromDMatrix for more info.
@@ -1091,6 +1094,9 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *values,
 /**
  * \brief Inplace prediction from CPU CSR matrix.
  *
+ * \note If the booster is configured to run on a CUDA device, XGBoost falls back to run
+ *       prediction with DMatrix with a performance warning.
+ *
  * \param handle Booster handle.
  * \param indptr JSON encoded __array_interface__ to row pointer in CSR.
  * \param indices JSON encoded __array_interface__ to column indices in CSR.
@@ -1116,6 +1122,9 @@ XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr, ch
 /**
  * \brief Inplace prediction from CUDA Dense matrix (cupy in Python).
  *
+ * \note If the booster is configured to run on a CPU, XGBoost falls back to run
+ *       prediction with DMatrix with a performance warning.
+ *
  * \param handle Booster handle
  * \param values JSON encoded __cuda_array_interface__ to values.
  * \param config See \ref XGBoosterPredictFromDMatrix for more info.
@@ -1137,6 +1146,9 @@ XGB_DLL int XGBoosterPredictFromCudaArray(BoosterHandle handle, char const *valu
 /**
  * \brief Inplace prediction from CUDA dense dataframe (cuDF in Python).
  *
+ * \note If the booster is configured to run on a CPU, XGBoost falls back to run
+ *       prediction with DMatrix with a performance warning.
+ *
  * \param handle Booster handle
  * \param values List of __cuda_array_interface__ for all columns encoded in JSON list.
  * \param config See \ref XGBoosterPredictFromDMatrix for more info.

include/xgboost/context.h

Lines changed: 114 additions & 25 deletions
@@ -1,20 +1,79 @@
-/*!
- * Copyright 2014-2022 by Contributors
+/**
+ * Copyright 2014-2023, XGBoost Contributors
  * \file context.h
  */
 #ifndef XGBOOST_CONTEXT_H_
 #define XGBOOST_CONTEXT_H_

-#include <xgboost/logging.h>
-#include <xgboost/parameter.h>
+#include <xgboost/base.h>       // for bst_d_ordinal_t
+#include <xgboost/logging.h>    // for CHECK_GE
+#include <xgboost/parameter.h>  // for XGBoostParameter

-#include <memory>  // std::shared_ptr
-#include <string>
+#include <cstdint>  // for int16_t, int32_t, int64_t
+#include <memory>   // for shared_ptr
+#include <string>   // for string, to_string

 namespace xgboost {

 struct CUDAContext;

+/**
+ * @brief A type for device ordinal. The type is packed into 32-bit for efficient use in
+ *        viewing types like `linalg::TensorView`.
+ */
+struct DeviceOrd {
+  enum Type : std::int16_t { kCPU = 0, kCUDA = 1 } device{kCPU};
+  // CUDA device ordinal.
+  bst_d_ordinal_t ordinal{-1};
+
+  [[nodiscard]] bool IsCUDA() const { return device == kCUDA; }
+  [[nodiscard]] bool IsCPU() const { return device == kCPU; }
+
+  DeviceOrd() = default;
+  constexpr DeviceOrd(Type type, bst_d_ordinal_t ord) : device{type}, ordinal{ord} {}
+
+  DeviceOrd(DeviceOrd const& that) = default;
+  DeviceOrd& operator=(DeviceOrd const& that) = default;
+  DeviceOrd(DeviceOrd&& that) = default;
+  DeviceOrd& operator=(DeviceOrd&& that) = default;
+
+  /**
+   * @brief Constructor for CPU.
+   */
+  [[nodiscard]] constexpr static auto CPU() { return DeviceOrd{kCPU, -1}; }
+  /**
+   * @brief Constructor for CUDA device.
+   *
+   * @param ordinal CUDA device ordinal.
+   */
+  [[nodiscard]] static auto CUDA(bst_d_ordinal_t ordinal) { return DeviceOrd{kCUDA, ordinal}; }
+
+  [[nodiscard]] bool operator==(DeviceOrd const& that) const {
+    return device == that.device && ordinal == that.ordinal;
+  }
+  [[nodiscard]] bool operator!=(DeviceOrd const& that) const { return !(*this == that); }
+  /**
+   * @brief Get a string representation of the device and the ordinal.
+   */
+  [[nodiscard]] std::string Name() const {
+    switch (device) {
+      case DeviceOrd::kCPU:
+        return "CPU";
+      case DeviceOrd::kCUDA:
+        return "CUDA:" + std::to_string(ordinal);
+      default: {
+        LOG(FATAL) << "Unknown device.";
+        return "";
+      }
+    }
+  }
+};
+
+static_assert(sizeof(DeviceOrd) == sizeof(std::int32_t));
+
+/**
+ * @brief Runtime context for XGBoost. Contains information like threads and device.
+ */
 struct Context : public XGBoostParameter<Context> {
  public:
   // Constant representing the device ID of CPU.
@@ -36,29 +95,59 @@ struct Context : public XGBoostParameter<Context> {
   // fail when gpu_id is invalid
   bool fail_on_invalid_gpu_id{false};
   bool validate_parameters{false};
-
-  /*!
-   * \brief Configure the parameter `gpu_id'.
+  /**
+   * @brief Configure the parameter `gpu_id'.
    *
-   * \param require_gpu Whether GPU is explicitly required from user.
+   * @param require_gpu Whether GPU is explicitly required by the user through other
+   *                    configurations.
    */
   void ConfigureGpuId(bool require_gpu);
-  /*!
-   * Return automatically chosen threads.
+  /**
+   * @brief Returns the automatically chosen number of threads based on the `nthread`
+   *        parameter and the system settting.
    */
-  std::int32_t Threads() const;
-
-  bool IsCPU() const { return gpu_id == kCpuId; }
-  bool IsCUDA() const { return !IsCPU(); }
-
-  CUDAContext const* CUDACtx() const;
-  // Make a CUDA context based on the current context.
-  Context MakeCUDA(std::int32_t device = 0) const {
+  [[nodiscard]] std::int32_t Threads() const;
+  /**
+   * @brief Is XGBoost running on CPU?
+   */
+  [[nodiscard]] bool IsCPU() const { return gpu_id == kCpuId; }
+  /**
+   * @brief Is XGBoost running on a CUDA device?
+   */
+  [[nodiscard]] bool IsCUDA() const { return !IsCPU(); }
+  /**
+   * @brief Get the current device and ordinal.
+   */
+  [[nodiscard]] DeviceOrd Device() const {
+    return IsCPU() ? DeviceOrd::CPU() : DeviceOrd::CUDA(static_cast<bst_d_ordinal_t>(gpu_id));
+  }
+  /**
+   * @brief Get the CUDA device ordinal. -1 if XGBoost is running on CPU.
+   */
+  [[nodiscard]] bst_d_ordinal_t Ordinal() const { return this->gpu_id; }
+  /**
+   * @brief Name of the current device.
+   */
+  [[nodiscard]] std::string DeviceName() const { return Device().Name(); }
+  /**
+   * @brief Get a CUDA device context for allocator and stream.
+   */
+  [[nodiscard]] CUDAContext const* CUDACtx() const;
+  /**
+   * @brief Make a CUDA context based on the current context.
+   *
+   * @param ordinal The CUDA device ordinal.
+   */
+  [[nodiscard]] Context MakeCUDA(std::int32_t ordinal = 0) const {
     Context ctx = *this;
-    ctx.gpu_id = device;
+    CHECK_GE(ordinal, 0);
+    ctx.gpu_id = ordinal;
     return ctx;
   }
-  Context MakeCPU() const {
+  /**
+   * @brief Make a CPU context based on the current context.
+   */
+  [[nodiscard]] Context MakeCPU() const {
     Context ctx = *this;
     ctx.gpu_id = kCpuId;
     return ctx;
@@ -87,9 +176,9 @@ struct Context : public XGBoostParameter<Context> {
   }

  private:
-  // mutable for lazy initialization for cuda context to avoid initializing CUDA at load.
-  // shared_ptr is used instead of unique_ptr as with unique_ptr it's difficult to define p_impl
-  // while trying to hide CUDA code from host compiler.
+  // mutable for lazy cuda context initialization. This avoids initializing CUDA at load.
+  // shared_ptr is used instead of unique_ptr as with unique_ptr it's difficult to define
+  // p_impl while trying to hide CUDA code from the host compiler.
   mutable std::shared_ptr<CUDAContext> cuctx_;
   // cached value for CFS CPU limit. (used in containerized env)
   std::int32_t cfs_cpu_count_;  // NOLINT

include/xgboost/gbm.h

Lines changed: 7 additions & 11 deletions
@@ -149,18 +149,14 @@ class GradientBooster : public Model, public Configurable {
   * \param layer_begin Beginning of boosted tree layer used for prediction.
   * \param layer_end End of booster layer. 0 means do not limit trees.
   * \param approximate use a faster (inconsistent) approximation of SHAP values
-  * \param condition condition on the condition_feature (0=no, -1=cond off, 1=cond on).
-  * \param condition_feature feature to condition on (i.e. fix) during calculations
   */
-  virtual void PredictContribution(DMatrix* dmat,
-                                   HostDeviceVector<bst_float>* out_contribs,
-                                   unsigned layer_begin, unsigned layer_end,
-                                   bool approximate = false, int condition = 0,
-                                   unsigned condition_feature = 0) = 0;
-
-  virtual void PredictInteractionContributions(
-      DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
-      unsigned layer_begin, unsigned layer_end, bool approximate) = 0;
+  virtual void PredictContribution(DMatrix* dmat, HostDeviceVector<float>* out_contribs,
+                                   bst_layer_t layer_begin, bst_layer_t layer_end,
+                                   bool approximate = false) = 0;
+
+  virtual void PredictInteractionContributions(DMatrix* dmat, HostDeviceVector<float>* out_contribs,
+                                               bst_layer_t layer_begin, bst_layer_t layer_end,
+                                               bool approximate) = 0;

   /*!
    * \brief dump the model in the requested format

jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/BoosterTest.java

Lines changed: 0 additions & 1 deletion
@@ -78,7 +78,6 @@ public void testBooster() throws XGBoostError {
         put("num_round", round);
         put("num_workers", 1);
         put("tree_method", "gpu_hist");
-        put("predictor", "gpu_predictor");
         put("max_bin", maxBin);
       }
     };

jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala

Lines changed: 0 additions & 1 deletion
@@ -281,7 +281,6 @@ object GpuPreXGBoost extends PreXGBoostProvider {
       // - predictor: Force to gpu predictor since native doesn't save predictor.
       val gpuId = if (!isLocal) XGBoost.getGPUAddrFromResources else 0
       booster.setParam("gpu_id", gpuId.toString)
-      booster.setParam("predictor", "gpu_predictor")
       logger.info("GPU transform on device: " + gpuId)
       boosterFlag.isGpuParamsSet = true;
     }
