Skip to content

Commit cbe4594

Browse files
committed
Drop RABIT single point model recovery.
* Pass rabit params in JVM package. * Implement timeout using poll timeout parameter. * Remove OOB data check.
1 parent 81c37c2 commit cbe4594

File tree

22 files changed

+63
-2864
lines changed

22 files changed

+63
-2864
lines changed

Jenkinsfile

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -321,20 +321,6 @@ def TestPythonGPU(args) {
321321
}
322322
}
323323

324-
def TestCppRabit() {
325-
node(nodeReq) {
326-
unstash name: 'xgboost_rabit_tests'
327-
unstash name: 'srcs'
328-
echo "Test C++, rabit mock on"
329-
def container_type = "cpu"
330-
def docker_binary = "docker"
331-
sh """
332-
${dockerRun} ${container_type} ${docker_binary} tests/ci_build/runxgb.sh xgboost tests/ci_build/approx.conf.in
333-
"""
334-
deleteDir()
335-
}
336-
}
337-
338324
def TestCppGPU(args) {
339325
def nodeReq = 'linux && mgpu'
340326
def artifact_cuda_version = (args.artifact_cuda_version) ?: ref_cuda_ver

R-package/src/Makevars.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,4 @@ PKG_LIBS = @OPENMP_CXXFLAGS@ @OPENMP_LIB@ @ENDIAN_FLAG@ @BACKTRACE_LIB@ -pthread
2222
OBJECTS= ./xgboost_R.o ./xgboost_custom.o ./xgboost_assert.o ./init.o \
2323
$(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o \
2424
$(PKGROOT)/rabit/src/engine.o $(PKGROOT)/rabit/src/c_api.o \
25-
$(PKGROOT)/rabit/src/allreduce_base.o $(PKGROOT)/rabit/src/allreduce_robust.o
25+
$(PKGROOT)/rabit/src/allreduce_base.o

R-package/src/Makevars.win

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,6 @@ PKG_LIBS = $(SHLIB_OPENMP_CXXFLAGS) $(SHLIB_PTHREAD_FLAGS)
3434
OBJECTS= ./xgboost_R.o ./xgboost_custom.o ./xgboost_assert.o ./init.o \
3535
$(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o \
3636
$(PKGROOT)/rabit/src/engine.o $(PKGROOT)/rabit/src/c_api.o \
37-
$(PKGROOT)/rabit/src/allreduce_base.o $(PKGROOT)/rabit/src/allreduce_robust.o
37+
$(PKGROOT)/rabit/src/allreduce_base.o
3838

3939
$(OBJECTS) : xgblib

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,7 @@ object XGBoost extends Serializable {
577577
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
578578
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
579579
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
580+
val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
580581
val sc = trainingData.sparkContext
581582
val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet,
582583
hasGroup, xgbExecParams.numWorkers)
@@ -595,6 +596,8 @@ object XGBoost extends Serializable {
595596
xgbExecParams.timeoutRequestWorkers,
596597
xgbExecParams.numWorkers,
597598
xgbExecParams.killSparkContextOnWorkerFailure)
599+
600+
tracker.getWorkerEnvs().putAll(xgbRabitParams)
598601
val rabitEnv = tracker.getWorkerEnvs
599602
val boostersAndMetrics = if (hasGroup) {
600603
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,

rabit/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ cmake_minimum_required(VERSION 3.3)
22

33
find_package(Threads REQUIRED)
44

5-
add_library(rabit src/allreduce_base.cc src/allreduce_robust.cc src/engine.cc src/c_api.cc)
6-
add_library(rabit_mock_static src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc)
5+
add_library(rabit src/allreduce_base.cc src/engine.cc src/c_api.cc)
6+
add_library(rabit_mock_static src/allreduce_base.cc src/engine_mock.cc src/c_api.cc)
7+
78
target_link_libraries(rabit Threads::Threads dmlc)
89
target_link_libraries(rabit_mock_static Threads::Threads dmlc)
910

rabit/include/rabit/internal/socket.h

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <string>
3131
#include <cstring>
3232
#include <vector>
33+
#include <chrono>
3334
#include <unordered_map>
3435
#include "utils.h"
3536

@@ -95,18 +96,18 @@ namespace utils {
9596
static constexpr int kInvalidSocket = -1;
9697

9798
template <typename PollFD>
98-
int PollImpl(PollFD *pfd, int nfds, int timeout) {
99+
int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) {
99100
#if defined(_WIN32)
100101

101102
#if IS_MINGW()
102103
MingWError();
103104
return -1;
104105
#else
105-
return WSAPoll(pfd, nfds, timeout);
106+
return WSAPoll(pfd, nfds, std::chrono::milliseconds(timeout).count());
106107
#endif // IS_MINGW()
107108

108109
#else
109-
return poll(pfd, nfds, timeout);
110+
return poll(pfd, nfds, std::chrono::milliseconds(timeout).count());
110111
#endif // IS_MINGW()
111112
}
112113

@@ -616,32 +617,20 @@ struct PollHelper {
616617
const auto& pfd = fds.find(fd);
617618
return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0);
618619
}
619-
/*!
620-
* \brief wait for exception event on a single descriptor
621-
* \param fd the file descriptor to wait the event for
622-
* \param timeout the timeout counter, can be negative, which means wait until the event happen
623-
* \return 1 if success, 0 if timeout, and -1 if error occurs
624-
*/
625-
inline static int WaitExcept(SOCKET fd, long timeout = -1) { // NOLINT(*)
626-
pollfd pfd;
627-
pfd.fd = fd;
628-
pfd.events = POLLPRI;
629-
return PollImpl(&pfd, 1, timeout);
630-
}
631620

632621
/*!
633622
* \brief perform poll on the set defined, read, write, exception
634623
* \param timeout specify timeout in seconds (converted to milliseconds for the underlying poll call); if negative, means poll will block
635624
* \return
636625
*/
637-
inline void Poll(long timeout = -1) { // NOLINT(*)
626+
inline void Poll(std::chrono::seconds timeout) { // NOLINT(*)
638627
std::vector<pollfd> fdset;
639628
fdset.reserve(fds.size());
640629
for (auto kv : fds) {
641630
fdset.push_back(kv.second);
642631
}
643632
int ret = PollImpl(fdset.data(), fdset.size(), timeout);
644-
if (ret == -1) {
633+
if (ret <= 0) {
645634
Socket::Error("Poll");
646635
} else {
647636
for (auto& pfd : fdset) {

rabit/src/CMakeLists.txt

Lines changed: 0 additions & 31 deletions
This file was deleted.

rabit/src/README.md

Lines changed: 0 additions & 6 deletions
This file was deleted.

rabit/src/allreduce_base.cc

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
77
*/
88
#define NOMINMAX
9+
#include "rabit/base.h"
10+
#include "rabit/internal/rabit-inl.h"
911
#include "allreduce_base.h"
10-
#include <rabit/base.h>
1112

1213
#ifndef _WIN32
1314
#include <netinet/tcp.h>
@@ -208,8 +209,8 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
208209
rabit_timeout = utils::StringToBool(val);
209210
}
210211
if (!strcmp(name, "rabit_timeout_sec")) {
211-
timeout_sec = atoi(val);
212-
utils::Assert(timeout_sec >= 0, "rabit_timeout_sec should be non negative second");
212+
timeout_sec = std::chrono::seconds(atoi(val));
213+
utils::Assert(timeout_sec.count() >= 0, "rabit_timeout_sec should be non negative second");
213214
}
214215
if (!strcmp(name, "rabit_enable_tcp_no_delay")) {
215216
if (!strcmp(val, "true")) {
@@ -549,14 +550,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
549550
// finish running allreduce
550551
if (finished) break;
551552
// select must return
552-
watcher.Poll();
553-
// exception handling
554-
for (int i = 0; i < nlink; ++i) {
555-
// recive OOB message from some link
556-
if (watcher.CheckExcept(links[i].sock)) {
557-
return ReportError(&links[i], kGetExcept);
558-
}
559-
}
553+
watcher.Poll(timeout_sec);
560554
// read data from children
561555
for (int i = 0; i < nlink; ++i) {
562556
if (i != parent_index && watcher.CheckRead(links[i].sock)) {
@@ -729,7 +723,7 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
729723
// finish running
730724
if (finished) break;
731725
// select
732-
watcher.Poll();
726+
watcher.Poll(timeout_sec);
733727
// exception handling
734728
for (int i = 0; i < nlink; ++i) {
735729
// receive OOB message from some link
@@ -819,7 +813,7 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
819813
finished = false;
820814
}
821815
if (finished) break;
822-
watcher.Poll();
816+
watcher.Poll(timeout_sec);
823817
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
824818
size_t size = stop_read - read_ptr;
825819
size_t start = read_ptr % total_size;
@@ -831,7 +825,10 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
831825
read_ptr += static_cast<size_t>(len);
832826
} else {
833827
ReturnType ret = Errno2Return();
834-
if (ret != kSuccess) return ReportError(&next, ret);
828+
if (ret != kSuccess) {
829+
auto err = ReportError(&next, ret);
830+
return err;
831+
}
835832
}
836833
}
837834
if (write_ptr < read_ptr && write_ptr != stop_write) {
@@ -845,7 +842,10 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
845842
write_ptr += static_cast<size_t>(len);
846843
} else {
847844
ReturnType ret = Errno2Return();
848-
if (ret != kSuccess) return ReportError(&prev, ret);
845+
if (ret != kSuccess) {
846+
auto err = ReportError(&prev, ret);
847+
return err;
848+
}
849849
}
850850
}
851851
}
@@ -913,7 +913,7 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
913913
finished = false;
914914
}
915915
if (finished) break;
916-
watcher.Poll();
916+
watcher.Poll(timeout_sec);
917917
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
918918
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
919919
if (ret != kSuccess) {

rabit/src/allreduce_base.h

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#ifndef RABIT_ALLREDUCE_BASE_H_
1313
#define RABIT_ALLREDUCE_BASE_H_
1414

15+
#include <functional>
16+
#include <future>
1517
#include <vector>
1618
#include <string>
1719
#include <algorithm>
@@ -35,6 +37,7 @@ class Datatype {
3537
}
3638
namespace rabit {
3739
namespace engine {
40+
3841
/*! \brief implementation of basic Allreduce engine */
3942
class AllreduceBase : public IEngine {
4043
public:
@@ -103,9 +106,11 @@ class AllreduceBase : public IEngine {
103106
size_t slice_end, size_t size_prev_slice,
104107
const char *_file = _FILE, const int _line = _LINE,
105108
const char *_caller = _CALLER) override {
106-
if (world_size == 1 || world_size == -1) return;
107-
utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size,
108-
slice_begin, slice_end, size_prev_slice) == kSuccess,
109+
if (world_size == 1 || world_size == -1) {
110+
return;
111+
}
112+
utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size, slice_begin,
113+
slice_end, size_prev_slice) == kSuccess,
109114
"AllgatherRing failed");
110115
}
111116
/*!
@@ -130,8 +135,8 @@ class AllreduceBase : public IEngine {
130135
const char *_caller = _CALLER) override {
131136
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
132137
if (world_size == 1 || world_size == -1) return;
133-
utils::Assert(TryAllreduce(sendrecvbuf_,
134-
type_nbytes, count, reducer) == kSuccess,
138+
utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) ==
139+
kSuccess,
135140
"Allreduce failed");
136141
}
137142
/*!
@@ -518,9 +523,9 @@ class AllreduceBase : public IEngine {
518523
//---- data structure related to model ----
519524
// call sequence counter, records how many calls we made so far
520525
// from last call to CheckPoint, LoadCheckPoint
521-
int seq_counter; // NOLINT
526+
int seq_counter{0}; // NOLINT
522527
// version number of model
523-
int version_number; // NOLINT
528+
int version_number {0}; // NOLINT
524529
// whether the job is running in hadoop
525530
bool hadoop_mode; // NOLINT
526531
//---- local data related to link ----
@@ -571,7 +576,7 @@ class AllreduceBase : public IEngine {
571576
// enable detailed logging
572577
bool rabit_debug = false; // NOLINT
573578
// by default, if rabit worker not recover in half an hour exit
574-
int timeout_sec = 1800; // NOLINT
579+
std::chrono::seconds timeout_sec{std::chrono::seconds{1800}}; // NOLINT
575580
// flag to enable rabit_timeout
576581
bool rabit_timeout = false; // NOLINT
577582
// Enable TCP node delay

0 commit comments

Comments
 (0)