Skip to content

Commit 7968bca

Browse files
committed
Set dense data
1 parent 9e79c2e commit 7968bca

File tree

3 files changed

+12
-2
lines changed

3 files changed

+12
-2
lines changed

src/c_api/c_api.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,7 @@ XGB_DLL int XGBoosterInplacePredict(BoosterHandle handle,
10271027
API_BEGIN();
10281028
CHECK_HANDLE();
10291029
xgboost::bst_ulong out_dim;
1030+
std::shared_ptr<xgboost::data::DenseAdapter> x{new xgboost::data::DenseAdapter(data, num_rows, num_features)};
10301031
//std::shared_ptr<DMatrix> p_m(dMatrixHandle);
10311032
std::shared_ptr<DMatrix> p_m{nullptr};
10321033
if (!dMatrixHandle) {
@@ -1042,15 +1043,14 @@ XGB_DLL int XGBoosterInplacePredict(BoosterHandle handle,
10421043
}
10431044
} else {
10441045
p_m = *static_cast<std::shared_ptr<DMatrix> *>(dMatrixHandle);
1045-
10461046
fprintf (stdout, "dmatrix handle is not null");
10471047
if (!p_m) {
10481048
fprintf (stderr, "p_m 2 is null");
10491049
exit(1);
10501050
}
10511051
}
10521052
fprintf (stdout, reinterpret_cast<const char *>(p_m.get()));
1053-
DMatrixProxy* stuff = dynamic_cast<data::DMatrixProxy *>(p_m.get());
1053+
// DMatrixProxy* stuff = dynamic_cast<data::DMatrixProxy *>(p_m.get());
10541054
auto proxy = new std::shared_ptr<xgboost::DMatrix>(new xgboost::data::DMatrixProxy);
10551055
if (!proxy) {
10561056
fprintf (stderr, "proxy is null");
@@ -1062,6 +1062,7 @@ XGB_DLL int XGBoosterInplacePredict(BoosterHandle handle,
10621062
}
10631063
auto *learner = static_cast<xgboost::Learner *>(handle);
10641064
auto iteration_end = GetIterationFromTreeLimit(ntree_limit, learner);
1065+
proxy->SetDenseData(data)
10651066
InplacePredictImplCore(p_m, learner, (xgboost::PredictionType)0, missing, num_rows, num_features,
10661067
0, iteration_end, true, len, &out_dim, out_result);
10671068
// printf("XGBoosterInplacePredict len = %u, dim = %u\n", **len, out_dim);

src/data/proxy_dmatrix.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ void DMatrixProxy::SetArrayData(char const *c_interface) {
1515
this->ctx_.gpu_id = Context::kCpuId;
1616
}
1717

18+
void DMatrixProxy::SetDenseData(const float *data) {
19+
std::shared_ptr<xgboost::data::DenseAdapter> adapter{new xgboost::data::DenseAdapter(data)};
20+
this->batch_ = adapter;
21+
this->Info().num_col_ = adapter->NumColumns();
22+
this->Info().num_row_ = adapter->NumRows();
23+
this->ctx_.gpu_id = Context::kCpuId;
24+
}
25+
1826
void DMatrixProxy::SetCSRData(char const *c_indptr, char const *c_indices,
1927
char const *c_values, bst_feature_t n_features, bool on_host) {
2028
CHECK(on_host) << "Not implemented on device.";

src/data/proxy_dmatrix.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class DMatrixProxy : public DMatrix {
7070
}
7171

7272
void SetArrayData(char const* c_interface);
73+
void SetDenseData(const float *data);
7374
void SetCSRData(char const *c_indptr, char const *c_indices,
7475
char const *c_values, bst_feature_t n_features,
7576
bool on_host);

0 commit comments

Comments
 (0)