Skip to content

Commit 5609102

Browse files
committed
Remove need to pass DMatrix
1 parent e31c53d commit 5609102

File tree

2 files changed

+13
-21
lines changed

2 files changed

+13
-21
lines changed

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -332,9 +332,9 @@ private synchronized float[][] predict(DMatrix data,
332332
*/
333333
public float[][] inplace_predict(float[] data,
334334
int num_rows,
335-
int num_features,
336-
DMatrix d_matrix) throws XGBoostError {
337-
return this.inplace_predict(data, num_rows, num_features, d_matrix,
335+
int num_features) throws XGBoostError {
336+
return this.inplace_predict(data, num_rows, num_features,
337+
Float.NaN, false, 0, false, false);
338338
Float.NaN, false, 0, false, false);
339339
}
340340

@@ -356,9 +356,8 @@ public float[][] inplace_predict(float[] data,
356356
public float[][] inplace_predict(float[] data,
357357
int num_rows,
358358
int num_features,
359-
DMatrix d_matrix,
360359
float missing) throws XGBoostError {
361-
return this.inplace_predict(data, num_rows, num_features, d_matrix,
360+
return this.inplace_predict(data, num_rows, num_features,
362361
missing, false, 0, false, false);
363362
}
364363

@@ -383,10 +382,9 @@ public float[][] inplace_predict(float[] data,
383382
public float[][] inplace_predict(float[] data,
384383
int num_rows,
385384
int num_features,
386-
DMatrix d_matrix,
387385
float missing,
388386
boolean outputMargin) throws XGBoostError {
389-
return this.inplace_predict(data, num_rows, num_features, d_matrix, missing,
387+
return this.inplace_predict(data, num_rows, num_features, missing,
390388
outputMargin, 0, false, false);
391389
}
392390

@@ -411,11 +409,10 @@ public float[][] inplace_predict(float[] data,
411409
public float[][] inplace_predict(float[] data,
412410
int num_rows,
413411
int num_features,
414-
DMatrix d_matrix,
415412
float missing,
416413
boolean outputMargin,
417414
int treeLimit) throws XGBoostError {
418-
return this.inplace_predict(data, num_rows, num_features, d_matrix, missing,
415+
return this.inplace_predict(data, num_rows, num_features, missing,
419416
outputMargin, treeLimit, false, false);
420417
}
421418

@@ -437,7 +434,6 @@ public float[][] inplace_predict(float[] data,
437434
public float[][] inplace_predict(float[] data,
438435
int num_rows,
439436
int num_features,
440-
DMatrix d_matrix,
441437
float missing,
442438
boolean outputMargin,
443439
int treeLimit,
@@ -453,10 +449,10 @@ public float[][] inplace_predict(float[] data,
453449
if (predContribs) {
454450
optionMask = 4;
455451
}
456-
452+
DMatrix d_mat = new DMatrix(data, num_rows, num_features, missing);
457453
float[][] rawPredicts = new float[1][];
458454
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterInplacePredict(handle, data, num_rows, num_features,
459-
d_matrix.getHandle(), missing,
455+
d_mat.getHandle(), missing,
460456
optionMask, treeLimit, rawPredicts)); // pass missing and treelimit here?
461457

462458
// System.out.println("Booster.inplace_predict rawPredicts[0].length = " +

jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,18 +95,16 @@ class InplacePredictThread extends Thread {
9595
float[][] testX;
9696
int test_rows;
9797
int features;
98-
DMatrix dMatrix;
9998
float[][] true_predicts;
10099
Booster booster;
101100
Random rng = new Random();
102101
int n_preds = 100;
103102

104-
public InplacePredictThread(int n, Booster booster, float[][] testX, int test_rows, int features, DMatrix dMatrix, float[][] true_predicts) {
103+
public InplacePredictThread(int n, Booster booster, float[][] testX, int test_rows, int features, float[][] true_predicts) {
105104
this.thread_num = n;
106105
this.booster = booster;
107106
this.testX = testX;
108107
this.test_rows = test_rows;
109-
this.dMatrix = dMatrix;
110108
this.features = features;
111109
this.true_predicts = true_predicts;
112110
}
@@ -122,7 +120,7 @@ public void run() {
122120
int r = this.rng.nextInt(this.test_rows);
123121

124122
// In-place predict a single random row
125-
float[][] predictions = booster.inplace_predict(this.testX[r], 1, this.features, this.dMatrix);
123+
float[][] predictions = booster.inplace_predict(this.testX[r], 1, this.features);
126124

127125
// Confirm results as expected
128126
if (predictions[0][0] != this.true_predicts[r][0]) {
@@ -146,19 +144,17 @@ class InplacePredictionTask implements Callable<Boolean> {
146144
float[][] testX;
147145
int test_rows;
148146
int features;
149-
DMatrix dMatrix;
150147
float[][] true_predicts;
151148
Booster booster;
152149
Random rng = new Random();
153150
int n_preds = 100;
154151

155-
public InplacePredictionTask(int n, Booster booster, float[][] testX, int test_rows, int features, DMatrix dMatrix, float[][] true_predicts) {
152+
public InplacePredictionTask(int n, Booster booster, float[][] testX, int test_rows, int features, float[][] true_predicts) {
156153
this.task_num = n;
157154
this.booster = booster;
158155
this.testX = testX;
159156
this.test_rows = test_rows;
160157
this.features = features;
161-
this.dMatrix = dMatrix;
162158
this.true_predicts = true_predicts;
163159
}
164160

@@ -172,7 +168,7 @@ public Boolean call() throws Exception {
172168
int r = this.rng.nextInt(this.test_rows);
173169

174170
// In-place predict a single random row
175-
float[][] predictions = booster.inplace_predict(this.testX[r], 1, this.features, this.dMatrix);
171+
float[][] predictions = booster.inplace_predict(this.testX[r], 1, this.features);
176172

177173
// Confirm results as expected
178174
if (predictions[0][0] != this.true_predicts[r][0]) {
@@ -335,7 +331,7 @@ public void testBoosterInplacePredict() throws XGBoostError, IOException {
335331
float[][] predicts = booster.predict(testMat);
336332

337333
// inplace prediction
338-
float[][] inplace_predicts = booster.inplace_predict(testX, test_rows, features, testMat);
334+
float[][] inplace_predicts = booster.inplace_predict(testX, test_rows, features);
339335

340336
// Confirm that the two prediction results are identical
341337
TestCase.assertTrue(ArrayComparator.compare(predicts, inplace_predicts));

0 commit comments

Comments
 (0)