diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
index 614555a054dfbcfe73f2c9073ddf2c1112d999ec..257e2f3a361154cead41989ad793d5ee58a2bfa7 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
@@ -79,30 +79,43 @@ object PageRank extends Logging {
   def run[VD: ClassTag, ED: ClassTag](
       graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15): Graph[Double, Double] =
   {
-    // Initialize the pagerankGraph with each edge attribute having
+    // Initialize the PageRank graph with each edge attribute having
     // weight 1/outDegree and each vertex with attribute 1.0.
-    val pagerankGraph: Graph[Double, Double] = graph
+    var rankGraph: Graph[Double, Double] = graph
       // Associate the degree with each vertex
       .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) }
       // Set the weight on the edges based on the degree
       .mapTriplets( e => 1.0 / e.srcAttr )
       // Set the vertex attributes to the initial pagerank values
-      .mapVertices( (id, attr) => 1.0 )
-      .cache()
+      .mapVertices( (id, attr) => resetProb )
 
-    // Define the three functions needed to implement PageRank in the GraphX
-    // version of Pregel
-    def vertexProgram(id: VertexId, attr: Double, msgSum: Double): Double =
-      resetProb + (1.0 - resetProb) * msgSum
-    def sendMessage(edge: EdgeTriplet[Double, Double]) =
-      Iterator((edge.dstId, edge.srcAttr * edge.attr))
-    def messageCombiner(a: Double, b: Double): Double = a + b
-    // The initial message received by all vertices in PageRank
-    val initialMessage = 0.0
+    var iteration = 0
+    var prevRankGraph: Graph[Double, Double] = null
+    while (iteration < numIter) {
+      rankGraph.cache()
 
-    // Execute pregel for a fixed number of iterations.
-    Pregel(pagerankGraph, initialMessage, numIter, activeDirection = EdgeDirection.Out)(
-      vertexProgram, sendMessage, messageCombiner)
+      // Compute the outgoing rank contributions of each vertex, perform local preaggregation, and
+      // do the final aggregation at the receiving vertices. Requires a shuffle for aggregation.
+      val rankUpdates = rankGraph.mapReduceTriplets[Double](
+        e => Iterator((e.dstId, e.srcAttr * e.attr)), _ + _)
+
+      // Apply the final rank updates to get the new ranks, using join to preserve ranks of vertices
+      // that didn't receive a message. Requires a shuffle for broadcasting updated ranks to the
+      // edge partitions.
+      prevRankGraph = rankGraph
+      rankGraph = rankGraph.joinVertices(rankUpdates) {
+        (id, oldRank, msgSum) => resetProb + (1.0 - resetProb) * msgSum
+      }.cache()
+
+      rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices
+      logInfo(s"PageRank finished iteration $iteration.")
+      prevRankGraph.vertices.unpersist(false)
+      prevRankGraph.edges.unpersist(false)
+
+      iteration += 1
+    }
+
+    rankGraph
   }
 
   /**