Commit 4b16bd3

cancel job instead of killing SparkContext
This PR changes the default behavior of killing the SparkContext on task failure. Instead, it cancels the running jobs when a task failure is encountered, which means the SparkContext stays alive even when exceptions occur.
1 parent 9adb812 commit 4b16bd3
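
As a rough illustration of the difference, here is a hypothetical standalone sketch (not code from this commit; the master URL, app name, and data are invented): SparkContext.cancelAllJobs() aborts the jobs currently running but leaves the context usable, whereas SparkContext.stop() tears the whole application down.

// Hypothetical sketch contrasting cancelAllJobs() with stop() after a
// failing job. Assumes only a local Spark installation on the classpath.
import org.apache.spark.{SparkConf, SparkContext}

object CancelVsStop {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("cancel-vs-stop"))
    try {
      // Every task of this job throws, so the job is aborted.
      sc.parallelize(1 to 10).foreach(_ => sys.error("simulated task failure"))
    } catch {
      case _: Throwable =>
        // Old behavior was sc.stop(): the context dies and no further job
        // can run. The new behavior only cancels the active jobs.
        sc.cancelAllJobs()
    }
    // The context survives the failure and can still schedule work.
    assert(!sc.isStopped)
    println(sc.parallelize(1 to 10).reduce(_ + _)) // prints 55
    sc.stop()
  }
}

With stop(), the assertion above would fail because the context is gone; after cancelAllJobs() the same context keeps serving jobs, which is exactly what the updated tests below verify.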

File tree: 5 files changed, +26 −29 lines

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala

1 addition, 2 deletions

@@ -626,9 +626,8 @@ object XGBoost extends Serializable {
       (booster, metrics)
     } catch {
       case t: Throwable =>
-        // if the job was aborted due to an exception
+        // if the job was aborted due to an exception, just throw the exception
         logger.error("the job was aborted due to ", t)
-        trainingData.sparkContext.stop()
         throw t
     } finally {
       uncacheTrainingData(xgbExecParams.cacheTrainingSet, transformedTrainingData)

jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala

12 additions, 13 deletions

@@ -79,7 +79,7 @@ class SparkParallelismTracker(
   def execute[T](body: => T): T = {
     if (timeout <= 0) {
       logger.info("starting training without setting timeout for waiting for resources")
-      body
+      safeExecute(body)
     } else {
       logger.info(s"starting training with timeout set as $timeout ms for waiting for resources")
       if (!waitForCondition(numAliveCores >= requestedCores, timeout)) {
@@ -98,31 +98,30 @@ private[spark] class TaskFailedListener extends SparkListener {
     taskEnd.reason match {
       case taskEndReason: TaskFailedReason =>
         logger.error(s"Training Task Failed during XGBoost Training: " +
-          s"$taskEndReason, stopping SparkContext")
-        TaskFailedListener.startedSparkContextKiller()
+          s"$taskEndReason, cancelling all jobs")
+        TaskFailedListener.cancelAllJobs()
       case _ =>
     }
   }
 }

 object TaskFailedListener {

-  var killerStarted = false
+  var cancelJobStarted = false

-  private def startedSparkContextKiller(): Unit = this.synchronized {
-    if (!killerStarted) {
-      // Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it
-      // in a separate thread
-      val sparkContextKiller = new Thread() {
+  private def cancelAllJobs(): Unit = this.synchronized {
+    if (!cancelJobStarted) {
+      val cancelJob = new Thread() {
         override def run(): Unit = {
           LiveListenerBus.withinListenerThread.withValue(false) {
-            SparkContext.getOrCreate().stop()
+            SparkContext.getOrCreate().cancelAllJobs()
           }
         }
       }
-      sparkContextKiller.setDaemon(true)
-      sparkContextKiller.start()
-      killerStarted = true
+      cancelJob.setDaemon(true)
+      cancelJob.start()
+      cancelJobStarted = true
     }
   }
+
 }
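
Note that the cancellation is still handed off to a separate daemon thread, wrapped in LiveListenerBus.withinListenerThread.withValue(false): the onTaskEnd callback runs on Spark's listener-bus thread, and the commit keeps the same hand-off pattern that previously motivated the SparkContext-killer thread. A condensed sketch of the pattern follows (hypothetical class and variable names, and without the package-private LiveListenerBus guard, which plain user code cannot access):

// Hypothetical, condensed illustration of the listener pattern above;
// the class and variable names are invented for this sketch.
import org.apache.spark.{SparkContext, TaskFailedReason}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

class CancelJobsOnTaskFailure extends SparkListener {
  @volatile private var triggered = false

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit =
    taskEnd.reason match {
      case reason: TaskFailedReason if !triggered =>
        triggered = true
        // Hand the cancellation off to a daemon thread instead of doing it
        // from the listener-bus thread that invoked this callback.
        val canceller = new Thread(new Runnable {
          override def run(): Unit = SparkContext.getOrCreate().cancelAllJobs()
        })
        canceller.setDaemon(true)
        canceller.start()
      case _ =>
    }
}

// Usage sketch, e.g. in a test: sc.addSparkListener(new CancelJobsOnTaskFailure)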

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala

7 additions, 7 deletions

@@ -40,13 +40,13 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
   }

-  private def waitForSparkContextShutdown(): Unit = {
+  private def sparkContextShouldNotShutDown(): Unit = {
     var totalWaitedTime = 0L
-    while (!ss.sparkContext.isStopped && totalWaitedTime <= 120000) {
-      Thread.sleep(10000)
-      totalWaitedTime += 10000
+    while (!ss.sparkContext.isStopped && totalWaitedTime <= 10000) {
+      Thread.sleep(1000)
+      totalWaitedTime += 1000
     }
-    assert(ss.sparkContext.isStopped === true)
+    assert(ss.sparkContext.isStopped === false)
   }

   test("fail training elegantly with unsupported objective function") {
@@ -60,7 +60,7 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     } catch {
       case e: Throwable => // swallow anything
     } finally {
-      waitForSparkContextShutdown()
+      sparkContextShouldNotShutDown()
     }
   }

@@ -75,7 +75,7 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
     } catch {
       case e: Throwable => // swallow anything
     } finally {
-      waitForSparkContextShutdown()
+      sparkContextShouldNotShutDown()
     }
   }
 }

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala

1 addition, 1 deletion

@@ -51,7 +51,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
       cleanExternalCache(currentSession.sparkContext.appName)
       currentSession = null
     }
-    TaskFailedListener.killerStarted = false
+    TaskFailedListener.cancelJobStarted = false
   }
 }
jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala

5 additions, 6 deletions

@@ -91,8 +91,8 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
   }

   test("test rabit timeout fail handle") {
-    // disable spark kill listener to verify if rabit_timeout take effect and kill tasks
-    TaskFailedListener.killerStarted = true
+    // disable job cancel listener to verify if rabit_timeout take effect and kill tasks
+    TaskFailedListener.cancelJobStarted = true

     val training = buildDataFrame(Classification.train)
     // mock rank 0 failure during 8th allreduce synchronization
@@ -109,11 +109,10 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
         "rabit_timeout" -> 0))
       .fit(training)
     } catch {
-      case e: Throwable => // swallow anything
+      case e: Throwable => println("----- " + e) // swallow anything
     } finally {
-      // assume all tasks throw exception almost same time
-      // 100ms should be enough to exhaust all retries
-      assert(waitAndCheckSparkShutdown(100) == true)
+      // wait 2s to check if SparkContext is killed
+      assert(waitAndCheckSparkShutdown(2000) == false)
     }
   }
 }
