@@ -221,13 +221,13 @@ void GenericParameter::ConfigureGpuId(bool require_gpu) {
221
221
using LearnerAPIThreadLocalStore =
222
222
dmlc::ThreadLocalStore<std::map<Learner const *, XGBAPIThreadLocalEntry>>;
223
223
224
+ using ThreadLocalPredictionCache =
225
+ dmlc::ThreadLocalStore<std::map<Learner const *, PredictionContainer>>;
226
+
224
227
class LearnerConfiguration : public Learner {
225
228
protected:
226
229
static std::string const kEvalMetric ; // NOLINT
227
230
228
- protected:
229
- PredictionContainer cache_;
230
-
231
231
protected:
232
232
std::atomic<bool > need_configuration_;
233
233
std::map<std::string, std::string> cfg_;
@@ -244,12 +244,19 @@ class LearnerConfiguration : public Learner {
244
244
explicit LearnerConfiguration (std::vector<std::shared_ptr<DMatrix> > cache)
245
245
: need_configuration_{true } {
246
246
monitor_.Init (" Learner" );
247
+ auto & local_cache = (*ThreadLocalPredictionCache::Get ())[this ];
247
248
for (std::shared_ptr<DMatrix> const & d : cache) {
248
- cache_.Cache (d, GenericParameter::kCpuId );
249
+ local_cache.Cache (d, GenericParameter::kCpuId );
250
+ }
251
+ }
252
+ ~LearnerConfiguration () override {
253
+ auto local_cache = ThreadLocalPredictionCache::Get ();
254
+ if (local_cache->find (this ) != local_cache->cend ()) {
255
+ local_cache->erase (this );
249
256
}
250
257
}
251
- // Configuration before data is known.
252
258
259
+ // Configuration before data is known.
253
260
void Configure () override {
254
261
// Varient of double checked lock
255
262
if (!this ->need_configuration_ ) { return ; }
@@ -316,6 +323,10 @@ class LearnerConfiguration : public Learner {
316
323
monitor_.Stop (" Configure" );
317
324
}
318
325
326
+ virtual PredictionContainer* GetPredictionCache () const {
327
+ return &((*ThreadLocalPredictionCache::Get ())[this ]);
328
+ }
329
+
319
330
void LoadConfig (Json const & in) override {
320
331
CHECK (IsA<Object>(in));
321
332
Version::Load (in, true );
@@ -511,7 +522,8 @@ class LearnerConfiguration : public Learner {
511
522
if (mparam_.num_feature == 0 ) {
512
523
// TODO(hcho3): Change num_feature to 64-bit integer
513
524
unsigned num_feature = 0 ;
514
- for (auto & matrix : cache_.Container ()) {
525
+ auto local_cache = this ->GetPredictionCache ();
526
+ for (auto & matrix : local_cache->Container ()) {
515
527
CHECK (matrix.first );
516
528
CHECK (!matrix.second .ref .expired ());
517
529
const uint64_t num_col = matrix.first ->Info ().num_col_ ;
@@ -948,7 +960,8 @@ class LearnerImpl : public LearnerIO {
948
960
this ->CheckDataSplitMode ();
949
961
this ->ValidateDMatrix (train.get (), true );
950
962
951
- auto & predt = this ->cache_ .Cache (train, generic_parameters_.gpu_id );
963
+ auto local_cache = this ->GetPredictionCache ();
964
+ auto & predt = local_cache->Cache (train, generic_parameters_.gpu_id );
952
965
953
966
monitor_.Start (" PredictRaw" );
954
967
this ->PredictRaw (train.get (), &predt, true );
@@ -973,9 +986,10 @@ class LearnerImpl : public LearnerIO {
973
986
}
974
987
this ->CheckDataSplitMode ();
975
988
this ->ValidateDMatrix (train.get (), true );
976
- this ->cache_ .Cache (train, generic_parameters_.gpu_id );
989
+ auto local_cache = this ->GetPredictionCache ();
990
+ local_cache->Cache (train, generic_parameters_.gpu_id );
977
991
978
- gbm_->DoBoost (train.get (), in_gpair, &cache_. Entry (train.get ()));
992
+ gbm_->DoBoost (train.get (), in_gpair, &local_cache-> Entry (train.get ()));
979
993
monitor_.Stop (" BoostOneIter" );
980
994
}
981
995
@@ -991,9 +1005,11 @@ class LearnerImpl : public LearnerIO {
991
1005
metrics_.emplace_back (Metric::Create (obj_->DefaultEvalMetric (), &generic_parameters_));
992
1006
metrics_.back ()->Configure ({cfg_.begin (), cfg_.end ()});
993
1007
}
1008
+
1009
+ auto local_cache = this ->GetPredictionCache ();
994
1010
for (size_t i = 0 ; i < data_sets.size (); ++i) {
995
1011
std::shared_ptr<DMatrix> m = data_sets[i];
996
- auto &predt = this -> cache_ . Cache (m, generic_parameters_.gpu_id );
1012
+ auto &predt = local_cache-> Cache (m, generic_parameters_.gpu_id );
997
1013
this ->ValidateDMatrix (m.get (), false );
998
1014
this ->PredictRaw (m.get (), &predt, false );
999
1015
@@ -1030,7 +1046,8 @@ class LearnerImpl : public LearnerIO {
1030
1046
} else if (pred_leaf) {
1031
1047
gbm_->PredictLeaf (data.get (), &out_preds->HostVector (), ntree_limit);
1032
1048
} else {
1033
- auto & prediction = cache_.Cache (data, generic_parameters_.gpu_id );
1049
+ auto local_cache = this ->GetPredictionCache ();
1050
+ auto & prediction = local_cache->Cache (data, generic_parameters_.gpu_id );
1034
1051
this ->PredictRaw (data.get (), &prediction, training, ntree_limit);
1035
1052
// Copy the prediction cache to output prediction. out_preds comes from C API
1036
1053
out_preds->SetDevice (generic_parameters_.gpu_id );
0 commit comments