
Commit 2d36d80

[SPARK-52459][SQL] Fix a subtle thread-safety issue with SQLAppStatusListener
Parent: 269584b


sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala

Lines changed: 16 additions & 15 deletions
@@ -113,10 +113,10 @@ class SQLAppStatusListener(
     // Record the accumulator IDs and metric types for the stages of this job, so that the code
     // that keeps track of the metrics knows which accumulators to look at.
     val accumIdsAndType = exec.metricAccumulatorIdToMetricType
-    if (accumIdsAndType.nonEmpty) {
+    if (accumIdsAndType.asScala.nonEmpty) {
       event.stageInfos.foreach { stage =>
         stageMetrics.put(stage.stageId, new LiveStageMetrics(stage.stageId, 0,
-          stage.numTasks, accumIdsAndType))
+          stage.numTasks, accumIdsAndType.asScala))
       }
     }
 
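Note on the hunk above: once metricAccumulatorIdToMetricType is a java.util.concurrent.ConcurrentHashMap (see the LiveExecutionData hunk further down), it no longer exposes Scala's collection API directly, so the listener wraps it with asScala wherever a Scala Map is expected. The following is a minimal standalone sketch of that conversion pattern, not Spark code; it assumes the scala.jdk.CollectionConverters import available in Scala 2.13, and the names are invented.

import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters._

object AsScalaSketch {
  def main(args: Array[String]): Unit = {
    // A Java concurrent map, playing the role of metricAccumulatorIdToMetricType.
    val accumIdToType = new ConcurrentHashMap[Long, String]()
    accumIdToType.put(1L, "sum")
    accumIdToType.put(2L, "size")

    // asScala returns a Scala Map view backed by the same ConcurrentHashMap,
    // so it can be handed to code expecting a Scala collection without copying.
    val view: scala.collection.mutable.Map[Long, String] = accumIdToType.asScala
    println(view.nonEmpty)   // true
    println(view.get(2L))    // Some(size)

    // Writes through either handle are visible through the other.
    accumIdToType.put(3L, "timing")
    println(view.size)       // 3
  }
}

Because the conversion is a wrapper rather than a snapshot, the same underlying ConcurrentHashMap can be shared across the LiveStageMetrics instances mentioned in the comment below while concurrent writers remain safe.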
@@ -207,12 +207,12 @@ class SQLAppStatusListener(
   private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = {
     val accumIds = exec.metrics.map(_.accumulatorId).toSet
 
-    val metricAggregationMap = new mutable.HashMap[String, (Array[Long], Array[Long]) => String]()
+    val metricAggregationMap = new ConcurrentHashMap[String, (Array[Long], Array[Long]) => String]()
     val metricAggregationMethods = exec.metrics.map { m =>
       val optClassName = CustomMetrics.parseV2CustomMetricType(m.metricType)
       val metricAggMethod = optClassName.map { className =>
         if (metricAggregationMap.contains(className)) {
-          metricAggregationMap(className)
+          metricAggregationMap.get(className)
         } else {
           // Try to initiate custom metric object
           try {
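Note on metricAggregationMap.get(className) above: on a java.util.concurrent.ConcurrentHashMap the lookup is Java's Map.get, which returns the stored value itself, or null when the key is absent, rather than a Scala Option, and the apply-style map(key) sugar is gone. Below is a hedged, standalone sketch of the look-up-or-build pattern under that API; the cache key and aggregation functions are made up for illustration and are not part of the Spark source.

import java.util.concurrent.ConcurrentHashMap
import java.util.function.{Function => JFunction}

object AggregatorCacheSketch {
  type AggFn = (Array[Long], Array[Long]) => String

  def main(args: Array[String]): Unit = {
    val cache = new ConcurrentHashMap[String, AggFn]()
    val key = "org.example.SumMetric"   // illustrative custom metric class name

    // Java-style lookup: get returns the cached function or null, never an Option.
    val cached = cache.get(key)
    val aggFn: AggFn =
      if (cached != null) {
        cached
      } else {
        val built: AggFn = (taskValues, _) => taskValues.sum.toString
        cache.put(key, built)
        built
      }
    println(aggFn(Array(1L, 2L, 3L), Array.empty[Long]))   // 6

    // computeIfAbsent is an atomic alternative to the check-then-insert above,
    // so two threads racing on the same class name build the aggregator only once.
    val builder: JFunction[String, AggFn] =
      (_: String) => (taskValues: Array[Long], _: Array[Long]) => taskValues.max.toString
    val maxFn = cache.computeIfAbsent("org.example.MaxMetric", builder)
    println(maxFn(Array(4L, 9L, 2L), Array.empty[Long]))   // 9
  }
}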
@@ -247,41 +247,42 @@ class SQLAppStatusListener(
 
     val maxMetrics = liveStageMetrics.flatMap(_.maxMetricValues())
 
-    val allMetrics = new mutable.HashMap[Long, Array[Long]]()
+    val allMetrics = new ConcurrentHashMap[Long, Array[Long]]()
 
-    val maxMetricsFromAllStages = new mutable.HashMap[Long, Array[Long]]()
+    val maxMetricsFromAllStages = new ConcurrentHashMap[Long, Array[Long]]()
 
     taskMetrics.filter(m => accumIds.contains(m._1)).foreach { case (id, values) =>
-      val prev = allMetrics.getOrElse(id, null)
+      val prev = allMetrics.getOrDefault(id, null)
       val updated = if (prev != null) {
         prev ++ values
       } else {
         values
       }
-      allMetrics(id) = updated
+      allMetrics.put(id, updated)
     }
 
     // Find the max for each metric id between all stages.
     val validMaxMetrics = maxMetrics.filter(m => accumIds.contains(m._1))
     validMaxMetrics.foreach { case (id, value, taskId, stageId, attemptId) =>
-      val updated = maxMetricsFromAllStages.getOrElse(id, Array(value, stageId, attemptId, taskId))
+      val updated = maxMetricsFromAllStages
+        .getOrDefault(id, Array(value, stageId, attemptId, taskId))
       if (value > updated(0)) {
         updated(0) = value
         updated(1) = stageId
         updated(2) = attemptId
         updated(3) = taskId
       }
-      maxMetricsFromAllStages(id) = updated
+      maxMetricsFromAllStages.put(id, updated)
     }
 
     exec.driverAccumUpdates.foreach { case (id, value) =>
       if (accumIds.contains(id)) {
-        val prev = allMetrics.getOrElse(id, null)
+        val prev = allMetrics.getOrDefault(id, null)
         val updated = if (prev != null) {
           // If the driver updates same metrics as tasks and has higher value then remove
           // that entry from maxMetricsFromAllStage. This would make stringValue function default
           // to "driver" that would be displayed on UI.
-          if (maxMetricsFromAllStages.contains(id) && value > maxMetricsFromAllStages(id)(0)) {
+          if (maxMetricsFromAllStages.contains(id) && value > maxMetricsFromAllStages.get(id)(0)) {
             maxMetricsFromAllStages.remove(id)
           }
           val _copy = Arrays.copyOf(prev, prev.length + 1)
@@ -290,11 +291,11 @@ class SQLAppStatusListener(
         } else {
           Array(value)
         }
-        allMetrics(id) = updated
+        allMetrics.put(id, updated)
       }
     }
 
-    val aggregatedMetrics = allMetrics.map { case (id, values) =>
+    val aggregatedMetrics = allMetrics.asScala.map { case (id, values) =>
       id -> metricAggregationMethods(id)(values, maxMetricsFromAllStages.getOrElse(id,
         Array.empty[Long]))
     }.toMap
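The two aggregation hunks above apply the same substitution throughout: Scala map sugar is replaced by the equivalent java.util.Map calls (getOrElse becomes getOrDefault, allMetrics(id) = updated becomes allMetrics.put(id, updated)), and the final transformation goes through allMetrics.asScala so the usual map/toMap combinators still apply, with weakly consistent iteration over the live map. A condensed standalone sketch of that accumulate-then-aggregate flow follows; the accumulator ids are invented and the sum aggregation is a placeholder, not the listener's real aggregation methods.

import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters._

object AccumulateAndAggregateSketch {
  def main(args: Array[String]): Unit = {
    val allMetrics = new ConcurrentHashMap[Long, Array[Long]]()

    // Illustrative per-task updates: (accumulator id, values observed in one task).
    val taskMetrics = Seq(1L -> Array(10L, 20L), 2L -> Array(5L), 1L -> Array(30L))

    taskMetrics.foreach { case (id, values) =>
      // getOrDefault stands in for Scala's getOrElse on a java.util.Map ...
      val prev = allMetrics.getOrDefault(id, null)
      val updated = if (prev != null) prev ++ values else values
      // ... and put replaces the update sugar allMetrics(id) = updated.
      allMetrics.put(id, updated)
    }

    // Placeholder aggregation (plain sum) keyed by accumulator id.
    val sumAgg: (Array[Long], Array[Long]) => String = (values, _) => values.sum.toString
    val metricAggregationMethods = Map(1L -> sumAgg, 2L -> sumAgg)
    val maxMetricsFromAllStages = new ConcurrentHashMap[Long, Array[Long]]()

    // asScala lets the Java map flow through ordinary Scala combinators.
    val aggregated: Map[Long, String] = allMetrics.asScala.map { case (id, values) =>
      id -> metricAggregationMethods(id)(values,
        maxMetricsFromAllStages.getOrDefault(id, Array.empty[Long]))
    }.toMap

    println(aggregated)   // e.g. Map(1 -> 60, 2 -> 5)
  }
}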
@@ -496,7 +497,7 @@ private class LiveExecutionData(val executionId: Long) extends LiveEntity {
   // This mapping is shared across all LiveStageMetrics instances associated with
   // this LiveExecutionData, helping to reduce memory overhead by avoiding waste
   // from separate immutable maps with largely overlapping sets of entries.
-  val metricAccumulatorIdToMetricType = new mutable.HashMap[Long, String]()
+  val metricAccumulatorIdToMetricType = new ConcurrentHashMap[Long, String]()
   var submissionTime = -1L
   var completionTime: Option[Date] = None
   var errorMessage: Option[String] = None
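Taken together with the commit title, the pattern of the change suggests these maps can be written by the listener's event-handling thread while other threads read them during metric aggregation, which is the kind of unsynchronized shared mutation that moving from scala.collection.mutable.HashMap to java.util.concurrent.ConcurrentHashMap avoids. The sketch below is a deliberately simplified illustration of that hazard class, with one writer and one reader thread sharing a map; it is not Spark code, and the sizes and metric names are invented.

import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters._

object SharedMapSketch {
  def main(args: Array[String]): Unit = {
    // With a plain mutable.HashMap, a concurrent put and iteration from two threads
    // is undefined behaviour (lost updates, corrupted buckets, spurious exceptions).
    // A ConcurrentHashMap gives atomic updates and weakly consistent iteration.
    val metricTypes = new ConcurrentHashMap[Long, String]()

    val writer = new Thread(() => {
      // Stands in for the listener thread registering accumulator metric types.
      (1L to 10000L).foreach(id => metricTypes.put(id, "sum"))
    })
    val reader = new Thread(() => {
      // Stands in for a thread aggregating metrics while registration is in flight;
      // iterating the asScala view never throws ConcurrentModificationException.
      var seen = 0
      while (seen < 10000) {
        seen = metricTypes.asScala.count(_._2 == "sum")
      }
    })

    writer.start(); reader.start()
    writer.join(); reader.join()
    println(s"final size = ${metricTypes.size()}")   // final size = 10000
  }
}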
