diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 42bb3884c8b8267fed814cb95d5694a3b2af8767..4457525ac87da9818e0a716ba7336719e975aef8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -110,6 +110,9 @@ class DAGScheduler( // resubmit failed stages val POLL_TIMEOUT = 10L + // Warns the user if a stage contains a task with size greater than this value (in KB) + val TASK_SIZE_TO_WARN = 100 + private val eventProcessActor: ActorRef = env.actorSystem.actorOf(Props(new Actor { override def preStart() { context.system.scheduler.schedule(RESUBMIT_TIMEOUT milliseconds, RESUBMIT_TIMEOUT milliseconds) { @@ -430,6 +433,18 @@ class DAGScheduler( handleExecutorLost(execId) case BeginEvent(task, taskInfo) => + for ( + job <- idToActiveJob.get(task.stageId); + stage <- stageIdToStage.get(task.stageId); + stageInfo <- stageToInfos.get(stage) + ) { + if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && !stageInfo.emittedTaskSizeWarning) { + stageInfo.emittedTaskSizeWarning = true + logWarning(("Stage %d (%s) contains a task of very large " + + "size (%d KB). The maximum recommended task size is %d KB.").format( + task.stageId, stageInfo.name, taskInfo.serializedSize / 1024, TASK_SIZE_TO_WARN)) + } + } listenerBus.post(SparkListenerTaskStart(task, taskInfo)) case GettingResultEvent(task, taskInfo) => diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 93599dfdc85fff358d497176bdf92519fbea9426..e9f2198a007e526a237f7190e542761bc191c8af 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -33,4 +33,5 @@ class StageInfo( val name = stage.name val numPartitions = stage.numPartitions val numTasks = stage.numTasks + var emittedTaskSizeWarning = false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 4bae26f3a6a885c73bd1639d61d226cbd06a5ea2..3c22edd5248f403190a2f543597728d08dba92a2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -46,6 +46,8 @@ class TaskInfo( var failed = false + var serializedSize: Int = 0 + def markGettingResult(time: Long = System.currentTimeMillis) { gettingResultTime = time } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala index 4c5eca8537cd62044a27d835bea520691d74931f..8884ea85a34e980796c891a14575f2983216f708 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -377,6 +377,7 @@ private[spark] class ClusterTaskSetManager( logInfo("Serialized task %s:%d as %d bytes in %d ms".format( taskSet.id, index, serializedTask.limit, timeTaken)) val taskName = "task %s:%d".format(taskSet.id, index) + info.serializedSize = serializedTask.limit if (taskAttempts(index).size == 1) taskStarted(task,info) return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))