1
1
// Copyright (c) 2014-2020 by Contributors
2
- #include < dmlc/thread_local.h>
3
2
#include < rabit/rabit.h>
4
3
#include < rabit/c_api.h>
5
4
26
25
27
26
using namespace xgboost ; // NOLINT(*);
28
27
29
- /* ! \brief entry to to easily hold returning information */
30
- struct XGBAPIThreadLocalEntry {
31
- /* ! \brief result holder for returning string */
32
- std::string ret_str;
33
- /* ! \brief result holder for returning strings */
34
- std::vector<std::string> ret_vec_str;
35
- /* ! \brief result holder for returning string pointers */
36
- std::vector<const char *> ret_vec_charp;
37
- /* ! \brief returning float vector. */
38
- std::vector<bst_float> ret_vec_float;
39
- /* ! \brief temp variable of gradient pairs. */
40
- std::vector<GradientPair> tmp_gpair;
41
- };
42
-
43
28
XGB_DLL void XGBoostVersion (int * major, int * minor, int * patch) {
44
29
if (major) {
45
30
*major = XGBOOST_VER_MAJOR;
@@ -52,9 +37,6 @@ XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
52
37
}
53
38
}
54
39
55
- // define the threadlocal store.
56
- using XGBAPIThreadLocalStore = dmlc::ThreadLocalStore<XGBAPIThreadLocalEntry>;
57
-
58
40
int XGBRegisterLogCallback (void (*callback)(const char *)) {
59
41
API_BEGIN ();
60
42
LogCallbackRegistry* registry = LogCallbackRegistryStore::Get ();
@@ -102,16 +84,16 @@ XGB_DLL int XGDMatrixCreateFromArrayInterfaceColumns(char const* c_json_strs,
102
84
int nthread,
103
85
DMatrixHandle* out) {
104
86
API_BEGIN ();
105
- LOG (FATAL) << " Xgboost not compiled with cuda " ;
87
+ LOG (FATAL) << " XGBoost not compiled with CUDA " ;
106
88
API_END ();
107
89
}
108
90
109
91
XGB_DLL int XGDMatrixCreateFromArrayInterface (char const * c_json_strs,
110
- bst_float missing,
111
- int nthread,
112
- DMatrixHandle* out) {
92
+ bst_float missing,
93
+ int nthread,
94
+ DMatrixHandle* out) {
113
95
API_BEGIN ();
114
- LOG (FATAL) << " Xgboost not compiled with cuda " ;
96
+ LOG (FATAL) << " XGBoost not compiled with CUDA " ;
115
97
API_END ();
116
98
}
117
99
@@ -375,7 +357,7 @@ XGB_DLL int XGBoosterSaveJsonConfig(BoosterHandle handle,
375
357
auto * learner = static_cast <Learner*>(handle);
376
358
learner->Configure ();
377
359
learner->SaveConfig (&config);
378
- std::string& raw_str = XGBAPIThreadLocalStore::Get ()-> ret_str ;
360
+ std::string& raw_str = learner-> GetThreadLocal (). ret_str ;
379
361
Json::Dump (config, &raw_str);
380
362
*out_str = raw_str.c_str ();
381
363
*out_len = static_cast <xgboost::bst_ulong>(raw_str.length ());
@@ -422,10 +404,11 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
422
404
const char * evnames[],
423
405
xgboost::bst_ulong len,
424
406
const char ** out_str) {
425
- std::string& eval_str = XGBAPIThreadLocalStore::Get ()->ret_str ;
426
407
API_BEGIN ();
427
408
CHECK_HANDLE ();
428
409
auto * bst = static_cast <Learner*>(handle);
410
+ std::string& eval_str = bst->GetThreadLocal ().ret_str ;
411
+
429
412
std::vector<std::shared_ptr<DMatrix>> data_sets;
430
413
std::vector<std::string> data_names;
431
414
@@ -446,24 +429,22 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
446
429
int32_t training,
447
430
xgboost::bst_ulong *len,
448
431
const bst_float **out_result) {
449
- std::vector<bst_float>& preds =
450
- XGBAPIThreadLocalStore::Get ()->ret_vec_float ;
451
432
API_BEGIN ();
452
433
CHECK_HANDLE ();
453
- auto *bst = static_cast <Learner*>(handle);
434
+ auto *learner = static_cast <Learner*>(handle);
435
+ auto & entry = learner->GetThreadLocal ().prediction_entry ;
454
436
HostDeviceVector<bst_float> tmp_preds;
455
- bst ->Predict (
437
+ learner ->Predict (
456
438
*static_cast <std::shared_ptr<DMatrix>*>(dmat),
457
439
(option_mask & 1 ) != 0 ,
458
- &tmp_preds , ntree_limit,
440
+ &entry. predictions , ntree_limit,
459
441
static_cast <bool >(training),
460
442
(option_mask & 2 ) != 0 ,
461
443
(option_mask & 4 ) != 0 ,
462
444
(option_mask & 8 ) != 0 ,
463
445
(option_mask & 16 ) != 0 );
464
- preds = tmp_preds.HostVector ();
465
- *out_result = dmlc::BeginPtr (preds);
466
- *len = static_cast <xgboost::bst_ulong>(preds.size ());
446
+ *out_result = dmlc::BeginPtr (entry.predictions .ConstHostVector ());
447
+ *len = static_cast <xgboost::bst_ulong>(entry.predictions .Size ());
467
448
API_END ();
468
449
}
469
450
@@ -515,13 +496,14 @@ XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
515
496
XGB_DLL int XGBoosterGetModelRaw (BoosterHandle handle,
516
497
xgboost::bst_ulong* out_len,
517
498
const char ** out_dptr) {
518
- std::string& raw_str = XGBAPIThreadLocalStore::Get ()->ret_str ;
519
- raw_str.resize (0 );
520
-
521
499
API_BEGIN ();
522
500
CHECK_HANDLE ();
523
- common::MemoryBufferStream fo (&raw_str);
524
501
auto *learner = static_cast <Learner*>(handle);
502
+ std::string& raw_str = learner->GetThreadLocal ().ret_str ;
503
+ raw_str.resize (0 );
504
+
505
+ common::MemoryBufferStream fo (&raw_str);
506
+
525
507
learner->Configure ();
526
508
learner->SaveModel (&fo);
527
509
*out_dptr = dmlc::BeginPtr (raw_str);
@@ -534,13 +516,12 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
534
516
XGB_DLL int XGBoosterSerializeToBuffer (BoosterHandle handle,
535
517
xgboost::bst_ulong *out_len,
536
518
const char **out_dptr) {
537
- std::string &raw_str = XGBAPIThreadLocalStore::Get ()->ret_str ;
538
- raw_str.resize (0 );
539
-
540
519
API_BEGIN ();
541
520
CHECK_HANDLE ();
542
- common::MemoryBufferStream fo (&raw_str);
543
521
auto *learner = static_cast <Learner*>(handle);
522
+ std::string &raw_str = learner->GetThreadLocal ().ret_str ;
523
+ raw_str.resize (0 );
524
+ common::MemoryBufferStream fo (&raw_str);
544
525
learner->Configure ();
545
526
learner->Save (&fo);
546
527
*out_dptr = dmlc::BeginPtr (raw_str);
@@ -583,16 +564,13 @@ XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
583
564
API_END ();
584
565
}
585
566
586
- inline void XGBoostDumpModelImpl (
587
- BoosterHandle handle,
588
- const FeatureMap& fmap,
589
- int with_stats,
590
- const char *format,
591
- xgboost::bst_ulong* len,
592
- const char *** out_models) {
593
- std::vector<std::string>& str_vecs = XGBAPIThreadLocalStore::Get ()->ret_vec_str ;
594
- std::vector<const char *>& charp_vecs = XGBAPIThreadLocalStore::Get ()->ret_vec_charp ;
567
+ inline void XGBoostDumpModelImpl (BoosterHandle handle, const FeatureMap &fmap,
568
+ int with_stats, const char *format,
569
+ xgboost::bst_ulong *len,
570
+ const char ***out_models) {
595
571
auto *bst = static_cast <Learner*>(handle);
572
+ std::vector<std::string>& str_vecs = bst->GetThreadLocal ().ret_vec_str ;
573
+ std::vector<const char *>& charp_vecs = bst->GetThreadLocal ().ret_vec_charp ;
596
574
bst->Configure ();
597
575
str_vecs = bst->DumpModel (fmap, with_stats != 0 , format);
598
576
charp_vecs.resize (str_vecs.size ());
@@ -608,7 +586,10 @@ XGB_DLL int XGBoosterDumpModel(BoosterHandle handle,
608
586
int with_stats,
609
587
xgboost::bst_ulong* len,
610
588
const char *** out_models) {
589
+ API_BEGIN ();
590
+ CHECK_HANDLE ();
611
591
return XGBoosterDumpModelEx (handle, fmap, with_stats, " text" , len, out_models);
592
+ API_END ();
612
593
}
613
594
614
595
XGB_DLL int XGBoosterDumpModelEx (BoosterHandle handle,
@@ -664,7 +645,7 @@ XGB_DLL int XGBoosterGetAttr(BoosterHandle handle,
664
645
const char ** out,
665
646
int * success) {
666
647
auto * bst = static_cast <Learner*>(handle);
667
- std::string& ret_str = XGBAPIThreadLocalStore::Get ()-> ret_str ;
648
+ std::string& ret_str = bst-> GetThreadLocal (). ret_str ;
668
649
API_BEGIN ();
669
650
CHECK_HANDLE ();
670
651
if (bst->GetAttr (key, &ret_str)) {
@@ -680,9 +661,9 @@ XGB_DLL int XGBoosterGetAttr(BoosterHandle handle,
680
661
XGB_DLL int XGBoosterSetAttr (BoosterHandle handle,
681
662
const char * key,
682
663
const char * value) {
683
- auto * bst = static_cast <Learner*>(handle);
684
664
API_BEGIN ();
685
665
CHECK_HANDLE ();
666
+ auto * bst = static_cast <Learner*>(handle);
686
667
if (value == nullptr ) {
687
668
bst->DelAttr (key);
688
669
} else {
@@ -694,12 +675,13 @@ XGB_DLL int XGBoosterSetAttr(BoosterHandle handle,
694
675
XGB_DLL int XGBoosterGetAttrNames (BoosterHandle handle,
695
676
xgboost::bst_ulong* out_len,
696
677
const char *** out) {
697
- std::vector<std::string>& str_vecs = XGBAPIThreadLocalStore::Get ()->ret_vec_str ;
698
- std::vector<const char *>& charp_vecs = XGBAPIThreadLocalStore::Get ()->ret_vec_charp ;
699
- auto *bst = static_cast <Learner*>(handle);
700
678
API_BEGIN ();
701
679
CHECK_HANDLE ();
702
- str_vecs = bst->GetAttrNames ();
680
+ auto *learner = static_cast <Learner *>(handle);
681
+ std::vector<std::string> &str_vecs = learner->GetThreadLocal ().ret_vec_str ;
682
+ std::vector<const char *> &charp_vecs =
683
+ learner->GetThreadLocal ().ret_vec_charp ;
684
+ str_vecs = learner->GetAttrNames ();
703
685
charp_vecs.resize (str_vecs.size ());
704
686
for (size_t i = 0 ; i < str_vecs.size (); ++i) {
705
687
charp_vecs[i] = str_vecs[i].c_str ();
0 commit comments