Skip to content

Commit 47842d4

Browse files
committed
More tests.
1 parent e62ecc4 commit 47842d4

File tree

5 files changed

+43
-3
lines changed

5 files changed

+43
-3
lines changed

R-package/tests/testthat/test_dmatrix.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ test_that("xgb.DMatrix: getinfo & setinfo", {
5959
expect_error(setinfo(dtest, 'group', test_label))
6060

6161
# providing character values will give a warning
62-
expect_warning( setinfo(dtest, 'weight', rep('a', nrow(test_data))) )
62+
expect_warning(setinfo(dtest, 'weight', rep('a', nrow(test_data))))
6363

6464
# any other label should error
6565
expect_error(setinfo(dtest, 'asdf', test_label))

src/data/data.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ std::vector<T> Gather(const std::vector<T> &in, common::Span<int const> ridxs, s
223223

224224
MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
225225
MetaInfo out;
226+
out.num_row_ = ridxs.size();
227+
out.num_col_ = this->num_col_;
226228
// Groups is maintained by a higher level Python function. We should aim at deprecating
227229
// the slice function.
228230
out.labels_.HostVector() = Gather(this->labels_.HostVector(), ridxs);

src/data/simple_dmatrix.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
3131
std::copy(inst.begin(), inst.end(), std::back_inserter(h_data));
3232
h_offset.emplace_back(rptr);
3333
}
34+
out->Info() = this->Info().Slice(ridxs);
35+
out->Info().num_nonzero_ = h_offset.back();
3436
}
35-
out->Info() = this->Info().Slice(ridxs);
3637
return out;
3738
}
3839

tests/cpp/data/test_simple_dmatrix.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,10 @@ TEST(SimpleDMatrix, Slice) {
222222
size_t constexpr kCols {8};
223223
size_t constexpr kClasses {3};
224224
auto p_m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
225+
auto& weights = p_m->Info().weights_.HostVector();
226+
weights.resize(kRows);
227+
std::iota(weights.begin(), weights.end(), 0.0f);
228+
225229
auto& lower = p_m->Info().labels_lower_bound_.HostVector();
226230
auto& upper = p_m->Info().labels_upper_bound_.HostVector();
227231
lower.resize(kRows);
@@ -256,6 +260,8 @@ TEST(SimpleDMatrix, Slice) {
256260
out->Info().labels_lower_bound_.HostVector().at(i));
257261
ASSERT_EQ(p_m->Info().labels_upper_bound_.HostVector().at(ridx),
258262
out->Info().labels_upper_bound_.HostVector().at(i));
263+
ASSERT_EQ(p_m->Info().weights_.HostVector().at(ridx),
264+
out->Info().weights_.HostVector().at(i));
259265

260266
auto& out_margin = out->Info().base_margin_.HostVector();
261267
for (size_t j = 0; j < kClasses; ++j) {
@@ -265,6 +271,10 @@ TEST(SimpleDMatrix, Slice) {
265271
}
266272
}
267273
}
274+
275+
// Slicing selects rows only, so the column count must match the source matrix.
ASSERT_EQ(p_m->Info().num_col_, out->Info().num_col_);
276+
ASSERT_EQ(out->Info().num_row_, ridxs.size());
277+
ASSERT_EQ(out->Info().num_nonzero_, ridxs.size() * kCols); // dense
268278
}
269279

270280
TEST(SimpleDMatrix, SaveLoadBinary) {

tests/python/test_dmatrix.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,34 @@ def test_np_view(self):
7171
assert (from_view.shape == from_array.shape)
7272
assert (from_view == from_array).all()
7373

74-
def test_feature_names(self):
74+
def test_slice(self):
75+
X = rng.randn(100, 100)
76+
y = rng.randint(low=0, high=3, size=100)
77+
d = xgb.DMatrix(X, y)
78+
eval_res_0 = {}
79+
booster = xgb.train(
80+
{'num_class': 3, 'objective': 'multi:softprob'}, d,
81+
num_boost_round=2, evals=[(d, 'd')], evals_result=eval_res_0)
82+
83+
predt = booster.predict(d)
84+
predt = predt.reshape(100 * 3, 1)
85+
d.set_base_margin(predt)
86+
87+
ridxs = [1, 2, 3, 4, 5, 6]
88+
d = d.slice(ridxs)
89+
sliced_margin = d.get_float_info('base_margin')
90+
assert sliced_margin.shape[0] == len(ridxs) * 3
91+
92+
eval_res_1 = {}
93+
xgb.train({'num_class': 3, 'objective': 'multi:softprob'}, d,
94+
num_boost_round=2, evals=[(d, 'd')], evals_result=eval_res_1)
95+
96+
eval_res_0 = eval_res_0['d']['merror']
97+
eval_res_1 = eval_res_1['d']['merror']
98+
for i in range(len(eval_res_0)):
99+
assert abs(eval_res_0[i] - eval_res_1[i]) < 0.02
100+
101+
def test_feature_names_slice(self):
75102
data = np.random.randn(5, 5)
76103

77104
# different length

0 commit comments

Comments
 (0)