Skip to content

Commit b47b5ac

Browse files
authored
Use hypothesis (dmlc#5759)
* Use hypothesis * Allow int64 array interface for groups * Add packages to Windows CI * Add to travis * Make sure device index is set correctly * Fix dask-cudf test * appveyor
1 parent 02884b0 commit b47b5ac

17 files changed

+414
-442
lines changed

Jenkinsfile-win64

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def TestWin64CPU() {
113113
"""
114114
echo "Installing Python dependencies..."
115115
bat """
116-
conda activate && conda upgrade scikit-learn pandas numpy
116+
conda activate && conda install -y hypothesis && conda upgrade scikit-learn pandas numpy hypothesis
117117
"""
118118
echo "Running Python tests..."
119119
bat "conda activate && python -m pytest -v -s --fulltrace tests\\python"
@@ -138,7 +138,7 @@ def TestWin64GPU(args) {
138138
"""
139139
echo "Installing Python dependencies..."
140140
bat """
141-
conda activate && conda upgrade scikit-learn pandas numpy
141+
conda activate && conda install -y hypothesis && conda upgrade scikit-learn pandas numpy hypothesis
142142
"""
143143
echo "Running Python tests..."
144144
bat """

appveyor.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ install:
4444
- if /i "%DO_PYTHON%" == "on" (
4545
conda config --set always_yes true &&
4646
conda update -q conda &&
47-
conda install -y numpy scipy pandas matplotlib pytest scikit-learn graphviz python-graphviz
47+
conda install -y numpy scipy pandas matplotlib pytest scikit-learn graphviz python-graphviz hypothesis
4848
)
4949
- set PATH=C:\Miniconda3-x64\Library\bin\graphviz;%PATH%
5050
# R: based on https://github.com/krlmlr/r-appveyor

src/data/data.cu

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,30 @@ void CopyInfoImpl(ArrayInterface column, HostDeviceVector<float>* out) {
3434
});
3535
}
3636

37+
void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
38+
CHECK(column.type[1] == 'i' || column.type[1] == 'u')
39+
<< "Expected integer metainfo";
40+
auto SetDeviceToPtr = [](void* ptr) {
41+
cudaPointerAttributes attr;
42+
dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
43+
int32_t ptr_device = attr.device;
44+
dh::safe_cuda(cudaSetDevice(ptr_device));
45+
return ptr_device;
46+
};
47+
auto ptr_device = SetDeviceToPtr(column.data);
48+
dh::TemporaryArray<bst_group_t> temp(column.num_rows);
49+
auto d_tmp = temp.data();
50+
51+
dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) {
52+
d_tmp[idx] = column.GetElement(idx);
53+
});
54+
auto length = column.num_rows;
55+
out->resize(length + 1);
56+
out->at(0) = 0;
57+
thrust::copy(temp.data(), temp.data() + length, out->begin() + 1);
58+
std::partial_sum(out->begin(), out->end(), out->begin());
59+
}
60+
3761
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
3862
Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()});
3963
auto const& j_arr = get<Array>(j_interface);
@@ -53,16 +77,7 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
5377
} else if (key == "base_margin") {
5478
CopyInfoImpl(array_interface, &base_margin_);
5579
} else if (key == "group") {
56-
// Ranking is not performed on device.
57-
thrust::device_ptr<uint32_t> p_src{
58-
reinterpret_cast<uint32_t*>(array_interface.data)};
59-
60-
auto length = array_interface.num_rows;
61-
group_ptr_.resize(length + 1);
62-
group_ptr_[0] = 0;
63-
thrust::copy(p_src, p_src + length, group_ptr_.begin() + 1);
64-
std::partial_sum(group_ptr_.begin(), group_ptr_.end(), group_ptr_.begin());
65-
80+
CopyGroupInfoImpl(array_interface, &group_ptr_);
6681
return;
6782
} else {
6883
LOG(FATAL) << "Unknown metainfo: " << key;

tests/ci_build/Dockerfile.cpu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ ENV GOSU_VERSION 1.10
2222
# Install Python packages in default env
2323
RUN \
2424
pip install pyyaml cpplint pylint astroid sphinx numpy scipy pandas matplotlib sh \
25-
recommonmark guzzle_sphinx_theme mock breathe graphviz \
25+
recommonmark guzzle_sphinx_theme mock breathe graphviz hypothesis\
2626
pytest scikit-learn wheel kubernetes urllib3 jsonschema boto3 && \
2727
pip install https://h2o-release.s3.amazonaws.com/datatable/stable/datatable-0.7.0/datatable-0.7.0-cp37-cp37m-linux_x86_64.whl && \
2828
pip install "dask[complete]"

tests/ci_build/Dockerfile.cudf

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ ENV PATH=/opt/python/bin:$PATH
1919
RUN \
2020
conda create -n cudf_test -c rapidsai -c nvidia -c conda-forge -c defaults \
2121
python=3.7 cudf cudatoolkit=$CUDA_VERSION dask dask-cuda dask-cudf cupy \
22-
numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz
22+
numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis
2323

2424
ENV GOSU_VERSION 1.10
2525

tests/ci_build/Dockerfile.gpu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ ENV PATH=/opt/python/bin:$PATH
1818
RUN \
1919
conda create -n gpu_test -c rapidsai -c nvidia -c conda-forge -c defaults \
2020
python=3.7 dask dask-cuda numpy pytest scipy scikit-learn pandas \
21-
matplotlib wheel python-kubernetes urllib3 graphviz
21+
matplotlib wheel python-kubernetes urllib3 graphviz hypothesis
2222

2323
ENV GOSU_VERSION 1.10
2424

tests/cpp/data/test_metainfo.cu

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ std::string PrepareData(std::string typestr, thrust::device_vector<T>* out, cons
2121

2222
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
2323
column["shape"] = Array(j_shape);
24-
column["strides"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(4)))});
24+
column["strides"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(sizeof(T))))});
2525
column["version"] = Integer(static_cast<Integer::Int>(1));
2626
column["typestr"] = String(typestr);
2727

@@ -78,16 +78,32 @@ TEST(MetaInfo, FromInterface) {
7878

7979
TEST(MetaInfo, Group) {
8080
cudaSetDevice(0);
81-
thrust::device_vector<uint32_t> d_data;
82-
std::string str = PrepareData<uint32_t>("<u4", &d_data);
8381

8482
MetaInfo info;
8583

86-
info.SetInfo("group", str.c_str());
87-
auto const& h_group = info.group_ptr_;
88-
ASSERT_EQ(h_group.size(), d_data.size() + 1);
84+
thrust::device_vector<uint32_t> d_uint;
85+
std::string uint_str = PrepareData<uint32_t>("<u4", &d_uint);
86+
info.SetInfo("group", uint_str.c_str());
87+
auto& h_group = info.group_ptr_;
88+
ASSERT_EQ(h_group.size(), d_uint.size() + 1);
8989
for (size_t i = 1; i < h_group.size(); ++i) {
90-
ASSERT_EQ(h_group[i], d_data[i-1] + h_group[i-1]) << "i: " << i;
90+
ASSERT_EQ(h_group[i], d_uint[i - 1] + h_group[i - 1]) << "i: " << i;
9191
}
92+
93+
thrust::device_vector<int64_t> d_int64;
94+
std::string int_str = PrepareData<int64_t>("<i8", &d_int64);
95+
info = MetaInfo();
96+
info.SetInfo("group", int_str.c_str());
97+
h_group = info.group_ptr_;
98+
ASSERT_EQ(h_group.size(), d_uint.size() + 1);
99+
for (size_t i = 1; i < h_group.size(); ++i) {
100+
ASSERT_EQ(h_group[i], d_uint[i - 1] + h_group[i - 1]) << "i: " << i;
101+
}
102+
103+
// Incorrect type
104+
thrust::device_vector<float> d_float;
105+
std::string float_str = PrepareData<float>("<f4", &d_float);
106+
info = MetaInfo();
107+
EXPECT_ANY_THROW(info.SetInfo("group", float_str.c_str()));
92108
}
93109
} // namespace xgboost

tests/python-gpu/test_gpu_linear.py

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,50 @@
11
import sys
2-
import pytest
3-
import unittest
2+
from hypothesis import strategies, given, settings, assume
3+
import xgboost as xgb
4+
sys.path.append("tests/python")
5+
import testing as tm
46

5-
sys.path.append('tests/python/')
6-
import test_linear # noqa: E402
7-
import testing as tm # noqa: E402
87

8+
parameter_strategy = strategies.fixed_dictionaries({
9+
'booster': strategies.just('gblinear'),
10+
'eta': strategies.floats(0.01, 0.25),
11+
'tolerance': strategies.floats(1e-5, 1e-2),
12+
'nthread': strategies.integers(1, 4),
13+
'feature_selector': strategies.sampled_from(['cyclic', 'shuffle',
14+
'greedy', 'thrifty']),
15+
'top_k': strategies.integers(1, 10),
16+
})
917

10-
class TestGPULinear(unittest.TestCase):
11-
datasets = ["Boston", "Digits", "Cancer", "Sparse regression"]
12-
common_param = {
13-
'booster': ['gblinear'],
14-
'updater': ['gpu_coord_descent'],
15-
'eta': [0.5],
16-
'top_k': [10],
17-
'tolerance': [1e-5],
18-
'alpha': [.1],
19-
'lambda': [0.005],
20-
'coordinate_selection': ['cyclic', 'random', 'greedy']}
18+
def train_result(param, dmat, num_rounds):
19+
result = {}
20+
xgb.train(param, dmat, num_rounds, [(dmat, 'train')], verbose_eval=False,
21+
evals_result=result)
22+
return result
2123

22-
@pytest.mark.skipif(**tm.no_sklearn())
23-
def test_gpu_coordinate(self):
24-
parameters = self.common_param.copy()
25-
parameters['gpu_id'] = [0]
26-
for param in test_linear.parameter_combinations(parameters):
27-
results = test_linear.run_suite(
28-
param, 100, self.datasets, scale_features=True)
29-
test_linear.assert_regression_result(results, 1e-2)
30-
test_linear.assert_classification_result(results)
24+
25+
class TestGPULinear:
26+
@given(parameter_strategy, strategies.integers(10, 50),
27+
tm.dataset_strategy)
28+
@settings(deadline=None)
29+
def test_gpu_coordinate(self, param, num_rounds, dataset):
30+
assume(len(dataset.y) > 0)
31+
param['updater'] = 'gpu_coord_descent'
32+
param = dataset.set_params(param)
33+
result = train_result(param, dataset.get_dmat(), num_rounds)['train'][dataset.metric]
34+
assert tm.non_increasing(result)
35+
36+
# Loss is not guaranteed to always decrease because of regularisation parameters
37+
# We test a weaker condition that the loss has not increased between the first and last
38+
# iteration
39+
@given(parameter_strategy, strategies.integers(10, 50),
40+
tm.dataset_strategy, strategies.floats(1e-5, 2.0),
41+
strategies.floats(1e-5, 2.0))
42+
@settings(deadline=None)
43+
def test_gpu_coordinate_regularised(self, param, num_rounds, dataset, alpha, lambd):
44+
assume(len(dataset.y) > 0)
45+
param['updater'] = 'gpu_coord_descent'
46+
param['alpha'] = alpha
47+
param['lambda'] = lambd
48+
param = dataset.set_params(param)
49+
result = train_result(param, dataset.get_dmat(), num_rounds)['train'][dataset.metric]
50+
assert tm.non_increasing([result[0], result[-1]])

tests/python-gpu/test_gpu_pickling.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@
44
import numpy as np
55
import subprocess
66
import os
7+
import sys
78
import json
89
import pytest
910

11+
sys.path.append("tests/python")
12+
import testing as tm
13+
1014
import xgboost as xgb
1115
from xgboost import XGBClassifier
1216

@@ -90,7 +94,6 @@ def test_wrap_gpu_id(self):
9094
)
9195
status = subprocess.call(args, env=env)
9296
assert status == 0
93-
9497
os.remove(model_path)
9598

9699
def test_pickled_predictor(self):

tests/python-gpu/test_gpu_prediction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,10 @@ def test_inplace_predict_cudf(self):
158158
rows = 1000
159159
cols = 10
160160
rng = np.random.RandomState(1994)
161+
cp.cuda.runtime.setDevice(0)
161162
X = rng.randn(rows, cols)
162163
X = pd.DataFrame(X)
163164
y = rng.randn(rows)
164-
165165
X = cudf.from_pandas(X)
166166

167167
dtrain = xgb.DMatrix(X, y)

0 commit comments

Comments
 (0)