Skip to content

Commit d268a2a

Browse files
boxdottrivialfis
andauthored
Thread-safe prediction by making the prediction cache thread-local. (#5853)
Co-authored-by: Jiaming Yuan <[email protected]>
1 parent fa3715f commit d268a2a

File tree

5 files changed

+71
-14
lines changed

5 files changed

+71
-14
lines changed

include/xgboost/predictor.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ struct PredictionCacheEntry {
5555
class PredictionContainer {
5656
std::unordered_map<DMatrix *, PredictionCacheEntry> container_;
5757
void ClearExpiredEntries();
58-
std::mutex cache_lock_;
5958

6059
public:
6160
PredictionContainer() = default;

src/learner.cc

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -221,13 +221,13 @@ void GenericParameter::ConfigureGpuId(bool require_gpu) {
221221
using LearnerAPIThreadLocalStore =
222222
dmlc::ThreadLocalStore<std::map<Learner const *, XGBAPIThreadLocalEntry>>;
223223

224+
using ThreadLocalPredictionCache =
225+
dmlc::ThreadLocalStore<std::map<Learner const *, PredictionContainer>>;
226+
224227
class LearnerConfiguration : public Learner {
225228
protected:
226229
static std::string const kEvalMetric; // NOLINT
227230

228-
protected:
229-
PredictionContainer cache_;
230-
231231
protected:
232232
std::atomic<bool> need_configuration_;
233233
std::map<std::string, std::string> cfg_;
@@ -244,12 +244,19 @@ class LearnerConfiguration : public Learner {
244244
explicit LearnerConfiguration(std::vector<std::shared_ptr<DMatrix> > cache)
245245
: need_configuration_{true} {
246246
monitor_.Init("Learner");
247+
auto& local_cache = (*ThreadLocalPredictionCache::Get())[this];
247248
for (std::shared_ptr<DMatrix> const& d : cache) {
248-
cache_.Cache(d, GenericParameter::kCpuId);
249+
local_cache.Cache(d, GenericParameter::kCpuId);
250+
}
251+
}
252+
~LearnerConfiguration() override {
253+
auto local_cache = ThreadLocalPredictionCache::Get();
254+
if (local_cache->find(this) != local_cache->cend()) {
255+
local_cache->erase(this);
249256
}
250257
}
251-
// Configuration before data is known.
252258

259+
// Configuration before data is known.
253260
void Configure() override {
254261
// Varient of double checked lock
255262
if (!this->need_configuration_) { return; }
@@ -316,6 +323,10 @@ class LearnerConfiguration : public Learner {
316323
monitor_.Stop("Configure");
317324
}
318325

326+
virtual PredictionContainer* GetPredictionCache() const {
327+
return &((*ThreadLocalPredictionCache::Get())[this]);
328+
}
329+
319330
void LoadConfig(Json const& in) override {
320331
CHECK(IsA<Object>(in));
321332
Version::Load(in, true);
@@ -511,7 +522,8 @@ class LearnerConfiguration : public Learner {
511522
if (mparam_.num_feature == 0) {
512523
// TODO(hcho3): Change num_feature to 64-bit integer
513524
unsigned num_feature = 0;
514-
for (auto& matrix : cache_.Container()) {
525+
auto local_cache = this->GetPredictionCache();
526+
for (auto& matrix : local_cache->Container()) {
515527
CHECK(matrix.first);
516528
CHECK(!matrix.second.ref.expired());
517529
const uint64_t num_col = matrix.first->Info().num_col_;
@@ -948,7 +960,8 @@ class LearnerImpl : public LearnerIO {
948960
this->CheckDataSplitMode();
949961
this->ValidateDMatrix(train.get(), true);
950962

951-
auto& predt = this->cache_.Cache(train, generic_parameters_.gpu_id);
963+
auto local_cache = this->GetPredictionCache();
964+
auto& predt = local_cache->Cache(train, generic_parameters_.gpu_id);
952965

953966
monitor_.Start("PredictRaw");
954967
this->PredictRaw(train.get(), &predt, true);
@@ -973,9 +986,10 @@ class LearnerImpl : public LearnerIO {
973986
}
974987
this->CheckDataSplitMode();
975988
this->ValidateDMatrix(train.get(), true);
976-
this->cache_.Cache(train, generic_parameters_.gpu_id);
989+
auto local_cache = this->GetPredictionCache();
990+
local_cache->Cache(train, generic_parameters_.gpu_id);
977991

978-
gbm_->DoBoost(train.get(), in_gpair, &cache_.Entry(train.get()));
992+
gbm_->DoBoost(train.get(), in_gpair, &local_cache->Entry(train.get()));
979993
monitor_.Stop("BoostOneIter");
980994
}
981995

@@ -991,9 +1005,11 @@ class LearnerImpl : public LearnerIO {
9911005
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric(), &generic_parameters_));
9921006
metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
9931007
}
1008+
1009+
auto local_cache = this->GetPredictionCache();
9941010
for (size_t i = 0; i < data_sets.size(); ++i) {
9951011
std::shared_ptr<DMatrix> m = data_sets[i];
996-
auto &predt = this->cache_.Cache(m, generic_parameters_.gpu_id);
1012+
auto &predt = local_cache->Cache(m, generic_parameters_.gpu_id);
9971013
this->ValidateDMatrix(m.get(), false);
9981014
this->PredictRaw(m.get(), &predt, false);
9991015

@@ -1030,7 +1046,8 @@ class LearnerImpl : public LearnerIO {
10301046
} else if (pred_leaf) {
10311047
gbm_->PredictLeaf(data.get(), &out_preds->HostVector(), ntree_limit);
10321048
} else {
1033-
auto& prediction = cache_.Cache(data, generic_parameters_.gpu_id);
1049+
auto local_cache = this->GetPredictionCache();
1050+
auto& prediction = local_cache->Cache(data, generic_parameters_.gpu_id);
10341051
this->PredictRaw(data.get(), &prediction, training, ntree_limit);
10351052
// Copy the prediction cache to output prediction. out_preds comes from C API
10361053
out_preds->SetDevice(generic_parameters_.gpu_id);

src/predictor/predictor.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ void PredictionContainer::ClearExpiredEntries() {
2626
}
2727

2828
PredictionCacheEntry &PredictionContainer::Cache(std::shared_ptr<DMatrix> m, int32_t device) {
29-
std::lock_guard<std::mutex> guard { cache_lock_ };
3029
this->ClearExpiredEntries();
3130
container_[m.get()].ref = m;
3231
if (device != GenericParameter::kCpuId) {

src/tree/updater_quantile_hist.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,5 @@ XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
13841384
[]() {
13851385
return new QuantileHistMaker();
13861386
});
1387-
13881387
} // namespace tree
13891388
} // namespace xgboost

tests/cpp/test_learner.cc

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
*/
44
#include <gtest/gtest.h>
55
#include <vector>
6+
#include <thread>
67
#include "helpers.h"
78
#include <dmlc/filesystem.h>
89

@@ -176,6 +177,48 @@ TEST(Learner, JsonModelIO) {
176177
}
177178
}
178179

180+
// Crashes the test runner if there are race condiditions.
181+
//
182+
// Build with additional cmake flags to enable thread sanitizer
183+
// which definitely catches problems. Note that OpenMP needs to be
184+
// disabled, otherwise thread sanitizer will also report false
185+
// positives.
186+
//
187+
// ```
188+
// -DUSE_SANITIZER=ON -DENABLED_SANITIZERS=thread -DUSE_OPENMP=OFF
189+
// ```
190+
TEST(Learner, MultiThreadedPredict) {
191+
size_t constexpr kRows = 1000;
192+
size_t constexpr kCols = 1000;
193+
194+
std::shared_ptr<DMatrix> p_dmat{
195+
RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix()};
196+
p_dmat->Info().labels_.Resize(kRows);
197+
CHECK_NE(p_dmat->Info().num_col_, 0);
198+
199+
std::shared_ptr<DMatrix> p_data{
200+
RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix()};
201+
CHECK_NE(p_data->Info().num_col_, 0);
202+
203+
std::shared_ptr<Learner> learner{Learner::Create({p_dmat})};
204+
learner->Configure();
205+
206+
std::vector<std::thread> threads;
207+
for (uint32_t thread_id = 0;
208+
thread_id < 2 * std::thread::hardware_concurrency(); ++thread_id) {
209+
threads.emplace_back([learner, p_data] {
210+
size_t constexpr kIters = 10;
211+
auto &entry = learner->GetThreadLocal().prediction_entry;
212+
for (size_t iter = 0; iter < kIters; ++iter) {
213+
learner->Predict(p_data, false, &entry.predictions);
214+
}
215+
});
216+
}
217+
for (auto &thread : threads) {
218+
thread.join();
219+
}
220+
}
221+
179222
TEST(Learner, BinaryModelIO) {
180223
size_t constexpr kRows = 8;
181224
int32_t constexpr kIters = 4;

0 commit comments

Comments
 (0)