Commit 9179e68 (1 parent: 06487d3)

[backport] Fix inplace predict with fallback when base margin is used. (dmlc#9536)

- Copy meta info from proxy DMatrix.
- Use `std::call_once` to emit fewer warnings.
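For context, a minimal sketch (not part of the commit) of the scenario this backport fixes; it assumes a CUDA build of XGBoost 2.0 with cupy installed:

import cupy as cp
import xgboost as xgb

# Train a small model with data resident on the GPU.
X, y = cp.random.randn(256, 8), cp.random.randn(256)
booster = xgb.train({"device": "cuda:0"}, xgb.DMatrix(X, label=y), num_boost_round=4)

# Move the booster to the CPU while the data and margin stay on the GPU.
# inplace_predict now warns and falls back to a regular DMatrix; before this
# fix the fallback did not carry the base margin over, so the two predictions
# below could disagree.
margin = cp.random.randn(256)
booster.set_param({"device": "cpu"})
from_inplace = booster.inplace_predict(data=X, base_margin=margin)

dm = xgb.DMatrix(X)
dm.set_info(base_margin=margin)
cp.testing.assert_allclose(from_inplace, booster.predict(dm), rtol=1e-6)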

File tree: 6 files changed, +62 −63 lines


src/common/error_msg.cc

Lines changed: 34 additions & 23 deletions
@@ -3,9 +3,11 @@
  */
 #include "error_msg.h"
 
+#include <mutex>    // for call_once, once_flag
 #include <sstream>  // for stringstream
 
 #include "../collective/communicator-inl.h"  // for GetRank
+#include "xgboost/context.h"                 // for Context
 #include "xgboost/logging.h"
 
 namespace xgboost::error {
@@ -26,34 +28,43 @@ void WarnDeprecatedGPUHist() {
 }
 
 void WarnManualUpdater() {
-  bool static thread_local logged{false};
-  if (logged) {
-    return;
-  }
-  LOG(WARNING)
-      << "You have manually specified the `updater` parameter. The `tree_method` parameter "
-         "will be ignored. Incorrect sequence of updaters will produce undefined "
-         "behavior. For common uses, we recommend using `tree_method` parameter instead.";
-  logged = true;
+  static std::once_flag flag;
+  std::call_once(flag, [] {
+    LOG(WARNING)
+        << "You have manually specified the `updater` parameter. The `tree_method` parameter "
+           "will be ignored. Incorrect sequence of updaters will produce undefined "
+           "behavior. For common uses, we recommend using `tree_method` parameter instead.";
+  });
 }
 
 void WarnDeprecatedGPUId() {
-  static thread_local bool logged{false};
-  if (logged) {
-    return;
-  }
-  auto msg = DeprecatedFunc("gpu_id", "2.0.0", "device");
-  msg += " E.g. device=cpu/cuda/cuda:0";
-  LOG(WARNING) << msg;
-  logged = true;
+  static std::once_flag flag;
+  std::call_once(flag, [] {
+    auto msg = DeprecatedFunc("gpu_id", "2.0.0", "device");
+    msg += " E.g. device=cpu/cuda/cuda:0";
+    LOG(WARNING) << msg;
+  });
 }
 
 void WarnEmptyDataset() {
-  static thread_local bool logged{false};
-  if (logged) {
-    return;
-  }
-  LOG(WARNING) << "Empty dataset at worker: " << collective::GetRank();
-  logged = true;
+  static std::once_flag flag;
+  std::call_once(flag,
+                 [] { LOG(WARNING) << "Empty dataset at worker: " << collective::GetRank(); });
+}
+
+void MismatchedDevices(Context const* booster, Context const* data) {
+  static std::once_flag flag;
+  std::call_once(flag, [&] {
+    LOG(WARNING)
+        << "Falling back to prediction using DMatrix due to mismatched devices. This might "
+           "lead to higher memory usage and slower performance. XGBoost is running on: "
+        << booster->DeviceName() << ", while the input data is on: " << data->DeviceName() << ".\n"
+        << R"(Potential solutions:
+- Use a data structure that matches the device ordinal in the booster.
+- Set the device for booster before call to inplace_predict.
+
+This warning will only be shown once.
+)";
+  });
 }
 }  // namespace xgboost::error
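The guard change is easy to miss in the diff: a thread_local bool suppresses repeats only within each thread, while a static std::once_flag with std::call_once runs its callable exactly once per process and is safe under concurrent use, which is why the warning text now says simply "shown once". A self-contained sketch of the new pattern (standard C++, not xgboost code):

#include <iostream>
#include <mutex>   // for std::call_once, std::once_flag
#include <thread>

void WarnOnce() {
  static std::once_flag flag;
  // The lambda runs at most once per process, even if several threads
  // reach this line concurrently.
  std::call_once(flag, [] { std::cerr << "warned once\n"; });
}

int main() {
  WarnOnce();                    // emits the warning
  WarnOnce();                    // suppressed
  std::thread{WarnOnce}.join();  // still suppressed: the flag is not per-thread
}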

src/common/error_msg.h

Lines changed: 4 additions & 1 deletion
@@ -10,7 +10,8 @@
 #include <limits>  // for numeric_limits
 #include <string>  // for string
 
-#include "xgboost/base.h"  // for bst_feature_t
+#include "xgboost/base.h"     // for bst_feature_t
+#include "xgboost/context.h"  // for Context
 #include "xgboost/logging.h"
 #include "xgboost/string_view.h"  // for StringView
 
@@ -94,5 +95,7 @@ constexpr StringView InvalidCUDAOrdinal() {
   return "Invalid device. `device` is required to be CUDA and there must be at least one GPU "
          "available for using GPU.";
 }
+
+void MismatchedDevices(Context const* booster, Context const* data);
 }  // namespace xgboost::error
 #endif  // XGBOOST_COMMON_ERROR_MSG_H_

src/data/proxy_dmatrix.cc

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const *ctx,
   }
 
   CHECK(p_fmat) << "Failed to fallback.";
+  p_fmat->Info() = proxy->Info().Copy();
   return p_fmat;
 }
 }  // namespace xgboost::data
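The single added line is the heart of the base-margin fix: when the proxy is materialized into a real DMatrix for the fallback path, its meta info (base margin, labels, weights) must be copied across explicitly. A toy illustration of the bug class, using hypothetical stand-in types rather than XGBoost's actual DMatrixProxy/MetaInfo:

#include <cassert>
#include <vector>

// Hypothetical stand-ins for DMatrixProxy / DMatrix.
struct MetaInfo { std::vector<float> base_margin; };
struct Proxy    { std::vector<float> data; MetaInfo info; };
struct Matrix   { std::vector<float> data; MetaInfo info; };

Matrix FromProxy(Proxy const& p) {
  Matrix m{p.data, {}};  // the feature data is adopted...
  m.info = p.info;       // ...but the meta info must be copied explicitly,
                         // or the fallback silently drops the base margin.
  return m;
}

int main() {
  Proxy p{{1.f, 2.f}, {{0.25f, 0.75f}}};
  Matrix m = FromProxy(p);
  assert(m.info.base_margin == p.info.base_margin);
}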

src/gbm/gbtree.cc

Lines changed: 2 additions & 21 deletions
@@ -85,25 +85,6 @@ bool UpdatersMatched(std::vector<std::string> updater_seq,
     return name == up->Name();
   });
 }
-
-void MismatchedDevices(Context const* booster, Context const* data) {
-  bool thread_local static logged{false};
-  if (logged) {
-    return;
-  }
-  LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. This might "
-                  "lead to higher memory usage and slower performance. XGBoost is running on: "
-               << booster->DeviceName() << ", while the input data is on: " << data->DeviceName()
-               << ".\n"
-               << R"(Potential solutions:
-- Use a data structure that matches the device ordinal in the booster.
-- Set the device for booster before call to inplace_predict.
-
-This warning will only be shown once for each thread. Subsequent warnings made by the
-current thread will be suppressed.
-)";
-  logged = true;
-}
 }  // namespace
 
 void GBTree::Configure(Args const& cfg) {
@@ -554,7 +535,7 @@ void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
   auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
   CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
   if (p_m->Ctx()->Device() != this->ctx_->Device()) {
-    MismatchedDevices(this->ctx_, p_m->Ctx());
+    error::MismatchedDevices(this->ctx_, p_m->Ctx());
     CHECK_EQ(out_preds->version, 0);
     auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
     CHECK(proxy) << error::InplacePredictProxy();
@@ -807,7 +788,7 @@ class Dart : public GBTree {
     auto n_groups = model_.learner_model_param->num_output_group;
 
     if (ctx_->Device() != p_fmat->Ctx()->Device()) {
-      MismatchedDevices(ctx_, p_fmat->Ctx());
+      error::MismatchedDevices(ctx_, p_fmat->Ctx());
       auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_fmat);
      CHECK(proxy) << error::InplacePredictProxy();
       auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing);
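Apart from the error:: qualifier, the call sites are unchanged: the helper simply moved out of gbtree.cc's anonymous namespace into the shared xgboost::error utilities, where it picked up the once-per-process std::call_once guard shown above, so the GBTree and Dart fallback paths now share a single warning flag.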

tests/cpp/gbm/test_gbtree.cu

Lines changed: 0 additions & 15 deletions
@@ -58,21 +58,6 @@ void TestInplaceFallback(Context const* ctx) {
   HostDeviceVector<float>* out_predt{nullptr};
   ConsoleLogger::Configure(Args{{"verbosity", "1"}});
   std::string output;
-  // test whether the warning is raised
-#if !defined(_WIN32)
-  // Windows has issue with CUDA and thread local storage. For some reason, on Windows a
-  // cudaInitializationError is raised during destruction of `HostDeviceVector`. This
-  // might be related to https://github.com/dmlc/xgboost/issues/5793
-  ::testing::internal::CaptureStderr();
-  std::thread{[&] {
-    // Launch a new thread to ensure a warning is raised as we prevent over-verbose
-    // warning by using thread-local flags.
-    learner->InplacePredict(p_m, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),
-                            &out_predt, 0, 0);
-  }}.join();
-  output = testing::internal::GetCapturedStderr();
-  ASSERT_NE(output.find("Falling back"), std::string::npos);
-#endif
 
   learner->InplacePredict(p_m, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),
                           &out_predt, 0, 0);
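The deleted block spawned a fresh thread precisely because the old thread_local guard let each new thread raise the warning once more, making it capturable here; with the process-wide std::once_flag (see the sketch under error_msg.cc above) a second thread no longer re-emits the warning, so the stderr-capture assertion was dropped along with its Windows-specific workaround.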

tests/python-gpu/test_gpu_prediction.py

Lines changed: 21 additions & 3 deletions
@@ -191,14 +191,32 @@ def test_inplace_predict_device_type(self, device: str) -> None:
         np.testing.assert_allclose(predt_0, predt_3)
         np.testing.assert_allclose(predt_0, predt_4)
 
-    def run_inplace_base_margin(self, booster, dtrain, X, base_margin):
+    def run_inplace_base_margin(
+        self, device: int, booster: xgb.Booster, dtrain: xgb.DMatrix, X, base_margin
+    ) -> None:
         import cupy as cp
 
+        booster.set_param({"device": f"cuda:{device}"})
         dtrain.set_info(base_margin=base_margin)
         from_inplace = booster.inplace_predict(data=X, base_margin=base_margin)
         from_dmatrix = booster.predict(dtrain)
         cp.testing.assert_allclose(from_inplace, from_dmatrix)
 
+        booster = booster.copy()  # clear prediction cache.
+        booster.set_param({"device": "cpu"})
+        from_inplace = booster.inplace_predict(data=X, base_margin=base_margin)
+        from_dmatrix = booster.predict(dtrain)
+        cp.testing.assert_allclose(from_inplace, from_dmatrix)
+
+        booster = booster.copy()  # clear prediction cache.
+        base_margin = cp.asnumpy(base_margin)
+        if hasattr(X, "values"):
+            X = cp.asnumpy(X.values)
+        booster.set_param({"device": f"cuda:{device}"})
+        from_inplace = booster.inplace_predict(data=X, base_margin=base_margin)
+        from_dmatrix = booster.predict(dtrain)
+        cp.testing.assert_allclose(from_inplace, from_dmatrix, rtol=1e-6)
+
     def run_inplace_predict_cupy(self, device: int) -> None:
         import cupy as cp
 
@@ -244,7 +262,7 @@ def predict_dense(x):
             run_threaded_predict(X, rows, predict_dense)
 
         base_margin = cp_rng.randn(rows)
-        self.run_inplace_base_margin(booster, dtrain, X, base_margin)
+        self.run_inplace_base_margin(device, booster, dtrain, X, base_margin)
 
         # Create a wide dataset
         X = cp_rng.randn(100, 10000)
@@ -318,7 +336,7 @@ def predict_df(x):
             run_threaded_predict(X, rows, predict_df)
 
         base_margin = cudf.Series(rng.randn(rows))
-        self.run_inplace_base_margin(booster, dtrain, X, base_margin)
+        self.run_inplace_base_margin(0, booster, dtrain, X, base_margin)
 
     @given(
         strategies.integers(1, 10), tm.make_dataset_strategy(), shap_parameter_strategy
