Skip to content

Commit 8d3bae4

Browse files
committed
Add test.
1 parent b383798 commit 8d3bae4

File tree

1 file changed

+22
-15
lines changed

1 file changed

+22
-15
lines changed

tests/python-gpu/test_gpu_basic_models.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,15 @@
1212
class TestGPUBasicModels(unittest.TestCase):
1313
cputest = test_bm.TestModels()
1414

15-
def test_eta_decay_gpu_hist(self):
16-
self.cputest.run_eta_decay('gpu_hist')
17-
18-
def test_deterministic_gpu_hist(self):
19-
kRows = 1000
20-
kCols = 64
21-
kClasses = 4
22-
# Create large values to force rounding.
23-
X = np.random.randn(kRows, kCols) * 1e4
24-
y = np.random.randint(0, kClasses, size=kRows)
25-
15+
def run_cls(self, X, y, deterministic):
2616
cls = xgb.XGBClassifier(tree_method='gpu_hist',
27-
deterministic_histogram=True,
17+
deterministic_histogram=deterministic,
2818
single_precision_histogram=True)
2919
cls.fit(X, y)
3020
cls.get_booster().save_model('test_deterministic_gpu_hist-0.json')
3121

3222
cls = xgb.XGBClassifier(tree_method='gpu_hist',
33-
deterministic_histogram=True,
23+
deterministic_histogram=deterministic,
3424
single_precision_histogram=True)
3525
cls.fit(X, y)
3626
cls.get_booster().save_model('test_deterministic_gpu_hist-1.json')
@@ -40,7 +30,24 @@ def test_deterministic_gpu_hist(self):
4030
with open('test_deterministic_gpu_hist-1.json', 'r') as fd:
4131
model_1 = fd.read()
4232

43-
assert hash(model_0) == hash(model_1)
44-
4533
os.remove('test_deterministic_gpu_hist-0.json')
4634
os.remove('test_deterministic_gpu_hist-1.json')
35+
36+
return hash(model_0), hash(model_1)
37+
38+
def test_eta_decay_gpu_hist(self):
39+
self.cputest.run_eta_decay('gpu_hist')
40+
41+
def test_deterministic_gpu_hist(self):
42+
kRows = 1000
43+
kCols = 64
44+
kClasses = 4
45+
# Create large values to force rounding.
46+
X = np.random.randn(kRows, kCols) * 1e4
47+
y = np.random.randint(0, kClasses, size=kRows) * 1e4
48+
49+
model_0, model_1 = self.run_cls(X, y, True)
50+
assert hash(model_0) == hash(model_1)
51+
52+
model_0, model_1 = self.run_cls(X, y, False)
53+
assert hash(model_0) != hash(model_1)

0 commit comments

Comments
 (0)