diff --git a/core/src/main/scala/spark/ui/jobs/StagePage.scala b/core/src/main/scala/spark/ui/jobs/StagePage.scala
index 1b071a91e55584cbd54865aa835b54b5eefb9ca6..884c065deecac9959dcb88704a866968e099e436 100644
--- a/core/src/main/scala/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/spark/ui/jobs/StagePage.scala
@@ -87,7 +87,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
           {if (hasShuffleWrite) Seq("Shuffle Write") else Nil} ++
         Seq("Details")
 
-      val taskTable = listingTable(taskHeaders, taskRow, tasks)
+      val taskTable = listingTable(taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite), tasks)
 
       // Excludes tasks which failed and have incomplete metrics
       val validTasks = tasks.filter(t => t._1.status == "SUCCESS" && (t._2.isDefined))
@@ -135,7 +135,8 @@ private[spark] class StagePage(parent: JobProgressUI) {
   }
 
 
-  def taskRow(taskData: (TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])): Seq[Node] = {
+  def taskRow(shuffleRead: Boolean, shuffleWrite: Boolean)
+             (taskData: (TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])): Seq[Node] = {
     def fmtStackTrace(trace: Seq[StackTraceElement]): Seq[Node] =
       trace.map(e => <span style="display:block;">{e.toString}</span>)
     val (info, metrics, exception) = taskData
@@ -154,10 +155,14 @@ private[spark] class StagePage(parent: JobProgressUI) {
       <td>{info.taskLocality}</td>
       <td>{info.hostPort}</td>
       <td>{dateFmt.format(new Date(info.launchTime))}</td>
-      {metrics.flatMap{m => m.shuffleReadMetrics}.map{s =>
-        <td>{Utils.memoryBytesToString(s.remoteBytesRead)}</td>}.getOrElse("")}
-      {metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
-        <td>{Utils.memoryBytesToString(s.shuffleBytesWritten)}</td>}.getOrElse("")}
+      {if (shuffleRead) {
+        <td>{metrics.flatMap{m => m.shuffleReadMetrics}.map{s =>
+          Utils.memoryBytesToString(s.remoteBytesRead)}.getOrElse("")}</td>
+      }}
+      {if (shuffleWrite) {
+        <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
+          Utils.memoryBytesToString(s.shuffleBytesWritten)}.getOrElse("")}</td>
+      }}
       <td>{exception.map(e =>
         <span>
           {e.className} ({e.description})<br/>