Skip to content

Commit a93b80d

Browse files
committed
Fix.
1 parent d2844e0 commit a93b80d

File tree

3 files changed

+17
-21
lines changed

3 files changed

+17
-21
lines changed

src/gbm/gbtree.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,7 @@ std::unique_ptr<Predictor> const& GBTree::GetPredictor(HostDeviceVector<float> c
620620
auto on_device = is_ellpack || is_from_device;
621621

622622
// Use GPU Predictor if data is already on device and gpu_id is set.
623-
if (on_device && ctx_->gpu_id >= 0) {
623+
if (on_device && ctx_->IsCUDA()) {
624624
#if defined(XGBOOST_USE_CUDA)
625625
CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
626626
CHECK(gpu_predictor_);

tests/cpp/helpers.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,9 @@ std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label, b
395395
for (auto const& page : out->GetBatches<SparsePage>()) {
396396
page.data.SetDevice(device_);
397397
page.offset.SetDevice(device_);
398+
// pull to device
399+
page.data.ConstDeviceSpan();
400+
page.offset.ConstDeviceSpan();
398401
}
399402
}
400403
if (!ft_.empty()) {

tests/cpp/predictor/test_predictor.cc

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -365,23 +365,23 @@ void TestCategoricalPredictLeafColumnSplit(Context const *ctx) {
365365

366366
void TestIterationRange(Context const* ctx) {
367367
size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10;
368-
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
368+
auto dmat = RandomDataGenerator(kRows, kCols, 0)
369+
.Device(ctx->gpu_id)
370+
.GenerateDMatrix(true, true, kClasses);
369371
auto learner = LearnerForTest(ctx, dmat, kIters, kForest);
370372

371373
bool bound = false;
372-
std::unique_ptr<Learner> sliced {learner->Slice(0, 3, 1, &bound)};
374+
bst_layer_t lend{3};
375+
std::unique_ptr<Learner> sliced{learner->Slice(0, lend, 1, &bound)};
373376
ASSERT_FALSE(bound);
374377

375378
HostDeviceVector<float> out_predt_sliced;
376379
HostDeviceVector<float> out_predt_ranged;
377380

378381
// margin
379382
{
380-
sliced->Predict(dmat, true, &out_predt_sliced, 0, 0, false, false, false,
381-
false, false);
382-
383-
learner->Predict(dmat, true, &out_predt_ranged, 0, 3, false, false, false,
384-
false, false);
383+
sliced->Predict(dmat, true, &out_predt_sliced, 0, 0, false, false, false, false, false);
384+
learner->Predict(dmat, true, &out_predt_ranged, 0, lend, false, false, false, false, false);
385385

386386
auto const &h_sliced = out_predt_sliced.HostVector();
387387
auto const &h_range = out_predt_ranged.HostVector();
@@ -391,11 +391,8 @@ void TestIterationRange(Context const* ctx) {
391391

392392
// SHAP
393393
{
394-
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false,
395-
true, false, false);
396-
397-
learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, true,
398-
false, false);
394+
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false, true, false, false);
395+
learner->Predict(dmat, false, &out_predt_ranged, 0, lend, false, false, true, false, false);
399396

400397
auto const &h_sliced = out_predt_sliced.HostVector();
401398
auto const &h_range = out_predt_ranged.HostVector();
@@ -405,10 +402,8 @@ void TestIterationRange(Context const* ctx) {
405402

406403
// SHAP interaction
407404
{
408-
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false,
409-
false, false, true);
410-
learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, false,
411-
false, true);
405+
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false, false, false, true);
406+
learner->Predict(dmat, false, &out_predt_ranged, 0, lend, false, false, false, false, true);
412407
auto const &h_sliced = out_predt_sliced.HostVector();
413408
auto const &h_range = out_predt_ranged.HostVector();
414409
ASSERT_EQ(h_sliced.size(), h_range.size());
@@ -417,10 +412,8 @@ void TestIterationRange(Context const* ctx) {
417412

418413
// Leaf
419414
{
420-
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, true,
421-
false, false, false);
422-
learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, true, false,
423-
false, false);
415+
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, true, false, false, false);
416+
learner->Predict(dmat, false, &out_predt_ranged, 0, lend, false, true, false, false, false);
424417
auto const &h_sliced = out_predt_sliced.HostVector();
425418
auto const &h_range = out_predt_ranged.HostVector();
426419
ASSERT_EQ(h_sliced.size(), h_range.size());

0 commit comments

Comments
 (0)