Skip to content

Commit bbf8dc1

Browse files
committed
Add tracker info to training api.
1 parent 1a08012 commit bbf8dc1

File tree

5 files changed

+64
-12
lines changed

5 files changed

+64
-12
lines changed

demo/dask/cpu_training.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from dask.distributed import Client
44
from dask.distributed import LocalCluster
55
from dask import array as da
6+
import logging
67

78

89
def main(client):
910
# generate some random data for demonstration
11+
logging.basicConfig(level=logging.INFO)
1012
m = 100000
1113
n = 100
1214
X = da.random.random(size=(m, n), chunks=100)

doc/parameter.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ Parameters for Tree Booster
226226
See tutorial for more information
227227

228228
Additional parameters for `hist` and 'gpu_hist' tree method
229-
================================================
229+
===========================================================
230230

231231
* ``single_precision_histogram``, [default=``false``]
232232

doc/tutorials/saving_model.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ or in R:
121121
122122
Will print out something similar to (not actual output as it's too long for demonstration):
123123

124-
.. code-block:: json
124+
.. code-block:: js
125125
126126
{
127127
"Learner": {
@@ -201,7 +201,7 @@ Difference between saving model and dumping model
201201
XGBoost has a function called ``dump_model`` in Booster object, which lets you to export
202202
the model in a readable format like ``text``, ``json`` or ``dot`` (graphviz). The primary
203203
use case for it is for model interpretation or visualization, and is not supposed to be
204-
loaded back to XGBoost. The JSON version has a `schema
204+
loaded back to XGBoost. The JSON version has a `Schema
205205
<https://github.com/dmlc/xgboost/blob/master/doc/dump.schema>`_. See next section for
206206
more info.
207207

python-package/xgboost/dask.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,14 @@
4747
LOGGER = logging.getLogger('[xgboost.dask]')
4848

4949

50-
def _start_tracker(host, n_workers):
50+
def _start_tracker(host, port, n_workers):
5151
"""Start Rabit tracker """
5252
env = {'DMLC_NUM_WORKER': n_workers}
53-
rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
53+
if port:
54+
rabit_context = RabitTracker(hostIP=host, port=port, port_end=port+1,
55+
nslave=n_workers)
56+
else:
57+
rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
5458
env.update(rabit_context.slave_envs())
5559

5660
rabit_context.start(n_workers)
@@ -356,13 +360,21 @@ def get_worker_data_shape(self, worker):
356360
cols = c
357361
return (rows, cols)
358362

359-
360-
def _get_rabit_args(worker_map, client):
363+
from distributed import Client
364+
def _get_rabit_args(worker_map, client: Client, host_ip=None, port=None):
361365
'''Get rabit context arguments from data distribution in DaskDMatrix.'''
362-
host = distributed_comm.get_address_host(client.scheduler.address)
366+
msg = 'Please provide both IP and port'
367+
assert (host_ip and port) or (host_ip is None and port is None), msg
363368

364-
env = client.run_on_scheduler(_start_tracker, host.strip('/:'),
365-
len(worker_map))
369+
if host_ip:
370+
LOGGER.info('Running tracker on: %s, %s', host_ip, str(port))
371+
env = client.run_on_scheduler(_start_tracker, host_ip, port,
372+
len(worker_map))
373+
else:
374+
host = distributed_comm.get_address_host(client.scheduler.address)
375+
LOGGER.info('Running tracker on: %s', host.strip('/:'))
376+
env = client.run_on_scheduler(_start_tracker, host.strip('/:'), port,
377+
len(worker_map))
366378
rabit_args = [('%s=%s' % item).encode() for item in env.items()]
367379
return rabit_args
368380

@@ -373,7 +385,8 @@ def _get_rabit_args(worker_map, client):
373385
# evaluation history is instead returned.
374386

375387

376-
def train(client, params, dtrain, *args, evals=(), **kwargs):
388+
def train(client, params, dtrain, *args, evals=(), tracker_ip=None,
389+
tracker_port=None, **kwargs):
377390
'''Train XGBoost model.
378391
379392
.. versionadded:: 1.0.0
@@ -383,6 +396,19 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
383396
client: dask.distributed.Client
384397
Specify the dask client used for training. Use default client
385398
returned from dask if it's set to None.
399+
400+
tracker_ip:
401+
Address for rabit tracker that runs on dask scheduler. Use
402+
`client.scheduler.address` if unspecified.
403+
404+
.. versionadded:: 1.2.0
405+
406+
tracker_port:
407+
Port for the tracker. Search for available ports automatically if
408+
unspecified.
409+
410+
.. versionadded:: 1.2.0
411+
386412
\\*\\*kwargs:
387413
Other parameters are the same as `xgboost.train` except for
388414
`evals_result`, which is returned as part of function return value
@@ -410,7 +436,7 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
410436

411437
workers = list(_get_client_workers(client).keys())
412438

413-
rabit_args = _get_rabit_args(workers, client)
439+
rabit_args = _get_rabit_args(workers, client, tracker_ip, tracker_port)
414440

415441
def dispatched_train(worker_addr):
416442
'''Perform training on a single worker.'''

tests/python/test_with_dask.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import dask.dataframe as dd
1616
import dask.array as da
1717
from xgboost.dask import DaskDMatrix
18+
from dask.distributed import comm
1819
except ImportError:
1920
LocalCluster = None
2021
Client = None
@@ -286,3 +287,26 @@ def test_empty_dmatrix_approx():
286287
with Client(cluster) as client:
287288
parameters = {'tree_method': 'approx'}
288289
run_empty_dmatrix(client, parameters)
290+
291+
292+
def test_explicit_rabit_tracker():
293+
with LocalCluster() as cluster:
294+
with Client(cluster) as client:
295+
X, y = generate_array()
296+
host = comm.get_address_host(client.scheduler.address)
297+
port = 9091
298+
dtrain = xgb.dask.DaskDMatrix(client, X, y)
299+
300+
out = xgb.dask.train(client, {'tree_method': 'hist'}, dtrain,
301+
tracker_ip=host, tracker_port=port)
302+
prediction = xgb.dask.predict(client, out, dtrain)
303+
assert prediction.shape[0] == kRows
304+
305+
assert isinstance(prediction, da.Array)
306+
prediction = prediction.compute()
307+
308+
booster = out['booster']
309+
single_node_predt = booster.predict(
310+
xgb.DMatrix(X.compute())
311+
)
312+
np.testing.assert_allclose(prediction, single_node_predt)

0 commit comments

Comments (0)