@@ -365,23 +365,23 @@ void TestCategoricalPredictLeafColumnSplit(Context const *ctx) {
365
365
366
366
void TestIterationRange (Context const * ctx) {
367
367
size_t constexpr kRows = 1000 , kCols = 20 , kClasses = 4 , kForest = 3 , kIters = 10 ;
368
- auto dmat = RandomDataGenerator (kRows , kCols , 0 ).GenerateDMatrix (true , true , kClasses );
368
+ auto dmat = RandomDataGenerator (kRows , kCols , 0 )
369
+ .Device (ctx->gpu_id )
370
+ .GenerateDMatrix (true , true , kClasses );
369
371
auto learner = LearnerForTest (ctx, dmat, kIters , kForest );
370
372
371
373
bool bound = false ;
372
- std::unique_ptr<Learner> sliced {learner->Slice (0 , 3 , 1 , &bound)};
374
+ bst_layer_t lend{3 };
375
+ std::unique_ptr<Learner> sliced{learner->Slice (0 , lend, 1 , &bound)};
373
376
ASSERT_FALSE (bound);
374
377
375
378
HostDeviceVector<float > out_predt_sliced;
376
379
HostDeviceVector<float > out_predt_ranged;
377
380
378
381
// margin
379
382
{
380
- sliced->Predict (dmat, true , &out_predt_sliced, 0 , 0 , false , false , false ,
381
- false , false );
382
-
383
- learner->Predict (dmat, true , &out_predt_ranged, 0 , 3 , false , false , false ,
384
- false , false );
383
+ sliced->Predict (dmat, true , &out_predt_sliced, 0 , 0 , false , false , false , false , false );
384
+ learner->Predict (dmat, true , &out_predt_ranged, 0 , lend, false , false , false , false , false );
385
385
386
386
auto const &h_sliced = out_predt_sliced.HostVector ();
387
387
auto const &h_range = out_predt_ranged.HostVector ();
@@ -391,11 +391,8 @@ void TestIterationRange(Context const* ctx) {
391
391
392
392
// SHAP
393
393
{
394
- sliced->Predict (dmat, false , &out_predt_sliced, 0 , 0 , false , false ,
395
- true , false , false );
396
-
397
- learner->Predict (dmat, false , &out_predt_ranged, 0 , 3 , false , false , true ,
398
- false , false );
394
+ sliced->Predict (dmat, false , &out_predt_sliced, 0 , 0 , false , false , true , false , false );
395
+ learner->Predict (dmat, false , &out_predt_ranged, 0 , lend, false , false , true , false , false );
399
396
400
397
auto const &h_sliced = out_predt_sliced.HostVector ();
401
398
auto const &h_range = out_predt_ranged.HostVector ();
@@ -405,10 +402,8 @@ void TestIterationRange(Context const* ctx) {
405
402
406
403
// SHAP interaction
407
404
{
408
- sliced->Predict (dmat, false , &out_predt_sliced, 0 , 0 , false , false ,
409
- false , false , true );
410
- learner->Predict (dmat, false , &out_predt_ranged, 0 , 3 , false , false , false ,
411
- false , true );
405
+ sliced->Predict (dmat, false , &out_predt_sliced, 0 , 0 , false , false , false , false , true );
406
+ learner->Predict (dmat, false , &out_predt_ranged, 0 , lend, false , false , false , false , true );
412
407
auto const &h_sliced = out_predt_sliced.HostVector ();
413
408
auto const &h_range = out_predt_ranged.HostVector ();
414
409
ASSERT_EQ (h_sliced.size (), h_range.size ());
@@ -417,10 +412,8 @@ void TestIterationRange(Context const* ctx) {
417
412
418
413
// Leaf
419
414
{
420
- sliced->Predict (dmat, false , &out_predt_sliced, 0 , 0 , false , true ,
421
- false , false , false );
422
- learner->Predict (dmat, false , &out_predt_ranged, 0 , 3 , false , true , false ,
423
- false , false );
415
+ sliced->Predict (dmat, false , &out_predt_sliced, 0 , 0 , false , true , false , false , false );
416
+ learner->Predict (dmat, false , &out_predt_ranged, 0 , lend, false , true , false , false , false );
424
417
auto const &h_sliced = out_predt_sliced.HostVector ();
425
418
auto const &h_range = out_predt_ranged.HostVector ();
426
419
ASSERT_EQ (h_sliced.size (), h_range.size ());
0 commit comments