Commit c0736f6f authored by Ankur Dave

Add Bagel, an implementation of Pregel on Spark

parent 94ba95bc
package bagel

import spark._
import spark.SparkContext._

import scala.collection.mutable.HashMap
import scala.collection.mutable.ArrayBuffer
object Pregel extends Logging {
  /**
   * Runs a Pregel job on the given vertices, running the specified
   * compute function on each vertex in every superstep. Before
   * beginning the first superstep, sends the given messages to their
   * destination vertices. In the join stage, launches `splits`
   * separate tasks (`splits` is specified manually to work around a
   * bug in Spark).
   *
   * Halts when no more messages are being sent between vertices and
   * all vertices have voted to halt by setting `active` to false.
   */
  def run[V <: Vertex : Manifest, M <: Message : Manifest, C](
      sc: SparkContext,
      verts: RDD[(String, V)],
      msgs: RDD[(String, M)],
      splits: Int,
      messageCombiner: (C, M) => C,
      defaultCombined: () => C,
      mergeCombined: (C, C) => C,
      superstep: Int = 0
    )(compute: (V, C, Int) => (V, Iterable[M])): RDD[V] = {
println("Starting superstep "+superstep+".")
val startTime = System.currentTimeMillis
// Bring together vertices and messages
println("Joining vertices and messages...")
val combinedMsgs = msgs.combineByKey({x => messageCombiner(defaultCombined(), x)}, messageCombiner, mergeCombined, splits)
println("verts.splits.size = " + verts.splits.size)
println("combinedMsgs.splits.size = " + combinedMsgs.splits.size)
println("verts.partitioner = " + verts.partitioner)
println("combinedMsgs.partitioner = " + combinedMsgs.partitioner)
val joined = verts.groupWith(combinedMsgs)
println("joined.splits.size = " + joined.splits.size)
println("joined.partitioner = " + joined.partitioner)
//val joined = graph.groupByKeyAsymmetrical(messageCombiner, defaultCombined, mergeCombined, splits)
println("Done joining vertices and messages.")
    // Run compute on each vertex
    println("Running compute on each vertex...")
    val messageCount = sc.accumulator(0)
    val activeVertexCount = sc.accumulator(0)
    val processed = joined.flatMapValues {
      case (Seq(), _) => None
      case (Seq(v), Seq(comb)) =>
        val (newVertex, newMessages) = compute(v, comb, superstep)
        messageCount += newMessages.size
        if (newVertex.active)
          activeVertexCount += 1
        Some((newVertex, newMessages))
        //val result = ArrayBuffer[(String, Either[V, M])]((newVertex.id, Left(newVertex)))
        //result ++= newMessages.map(m => (m.targetId, Right(m)))
      case (Seq(v), Seq()) =>
        // No incoming messages for this vertex, so pass the default
        val (newVertex, newMessages) = compute(v, defaultCombined(), superstep)
        messageCount += newMessages.size
        if (newVertex.active)
          activeVertexCount += 1
        Some((newVertex, newMessages))
    }.cache

    // MATEI: Added this to force evaluation, so the accumulators are
    // populated before we read them below
    processed.foreach(x => {})
println("Done running compute on each vertex.")
println("Checking stopping condition...")
val stop = messageCount.value == 0 && activeVertexCount.value == 0
val timeTaken = System.currentTimeMillis - startTime
println("Superstep %d took %d s".format(superstep, timeTaken / 1000))
val newVerts = processed.mapValues(_._1)
val newMsgs = processed.flatMap(x => x._2._2.map(m => (m.targetId, m)))
if (superstep >= 10)
processed.map { _._2._1 }
else
run(sc, newVerts, newMsgs, splits, messageCombiner, defaultCombined, mergeCombined, superstep + 1)(compute)
}
}
/**
 * Represents a Pregel vertex. Must be subclassed to store state
 * along with each vertex. Must be annotated with @serializable.
 */
trait Vertex {
  def id: String
  def active: Boolean
}

/**
 * Represents a Pregel message to a target vertex. Must be
 * subclassed to contain a payload. Must be annotated with
 * @serializable.
 */
trait Message {
  def targetId: String
}

/**
 * Represents a directed edge between two vertices. Owned by the
 * source vertex, and contains the ID of the target vertex. Must be
 * subclassed to store state along with each edge. Must be annotated
 * with @serializable.
 */
trait Edge {
  def targetId: String
}
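// A minimal usage sketch (hypothetical, not part of this commit), showing how
// the traits above plug into Pregel.run. CountVertex and CountMessage are
// invented names for illustration; the run signature and the combiner triple
// (messageCombiner, defaultCombined, mergeCombined) match Pregel.run above.
/*
@serializable class CountVertex(val id: String, val value: Int, val active: Boolean) extends Vertex
@serializable class CountMessage(val targetId: String, val value: Int) extends Message

def example(sc: SparkContext, verts: RDD[(String, CountVertex)], msgs: RDD[(String, CountMessage)]): RDD[CountVertex] =
  Pregel.run[CountVertex, CountMessage, Int](
    sc, verts, msgs, 2,
    (sum: Int, m: CountMessage) => sum + m.value, // messageCombiner: fold a message into the partial sum
    () => 0,                                      // defaultCombined: identity element for the sum
    (a: Int, b: Int) => a + b                     // mergeCombined: merge partial sums across partitions
  ) { (self: CountVertex, sum: Int, superstep: Int) =>
    // Add the combined incoming value, send nothing, and vote to halt
    (new CountVertex(self.id, self.value + sum, false), Nil)
  }
*/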
package bagel

import spark._
import spark.SparkContext._

import scala.math.min
/*
object ShortestPath {
  def main(args: Array[String]) {
    if (args.length < 4) {
      System.err.println("Usage: ShortestPath <graphFile> <startVertex> " +
                         "<numSplits> <host>")
      System.exit(-1)
    }

    val graphFile = args(0)
    val startVertex = args(1)
    val numSplits = args(2).toInt
    val host = args(3)
    val sc = new SparkContext(host, "ShortestPath")

    // Parse the graph data from a file into two RDDs, vertices and messages
    val lines =
      (sc.textFile(graphFile)
       .filter(!_.matches("^\\s*#.*"))
       .map(line => line.split("\t")))

    val vertices: RDD[(String, Either[SPVertex, SPMessage])] =
      (lines.groupBy(line => line(0))
       .map {
         case (vertexId, lines) => {
           val outEdges = lines.collect {
             case Array(_, targetId, edgeValue) =>
               new SPEdge(targetId, edgeValue.toInt)
           }
           (vertexId, Left[SPVertex, SPMessage](new SPVertex(vertexId, Int.MaxValue, outEdges, true)))
         }
       })

    val messages: RDD[(String, Either[SPVertex, SPMessage])] =
      (lines.filter(_.length == 2)
       .map {
         case Array(vertexId, messageValue) =>
           (vertexId, Right[SPVertex, SPMessage](new SPMessage(vertexId, messageValue.toInt)))
       })

    val graph: RDD[(String, Either[SPVertex, SPMessage])] = vertices ++ messages

    System.err.println("Read " + vertices.count() + " vertices and " +
                       messages.count() + " messages.")

    // Do the computation
    def messageCombiner(minSoFar: Int, message: SPMessage): Int =
      min(minSoFar, message.value)
    val result = Pregel.run(sc, graph, numSplits, messageCombiner, () => Int.MaxValue, min _) {
      (self: SPVertex, messageMinValue: Int, superstep: Int) =>
        val newValue = min(self.value, messageMinValue)
        val outbox =
          if (newValue != self.value)
            self.outEdges.map(edge =>
              new SPMessage(edge.targetId, newValue + edge.value))
          else
            List()
        (new SPVertex(self.id, newValue, self.outEdges, false), outbox)
    }

    // Print the result
    System.err.println("Shortest path from " + startVertex + " to all vertices:")
    val shortest = result.map(vertex =>
      "%s\t%s\n".format(vertex.id, vertex.value match {
        case x if x == Int.MaxValue => "inf"
        case x => x
      })).collect.mkString
    println(shortest)
  }
}

@serializable class SPVertex(val id: String, val value: Int, val outEdges: Seq[SPEdge], val active: Boolean) extends Vertex
@serializable class SPEdge(val targetId: String, val value: Int) extends Edge
@serializable class SPMessage(val targetId: String, val value: Int) extends Message
*/
package bagel

import spark._
import spark.SparkContext._

import scala.collection.mutable.ArrayBuffer
import scala.xml.{XML, NodeSeq}

import java.io.{Externalizable, ObjectInput, ObjectOutput, DataOutputStream, DataInputStream}

import com.esotericsoftware.kryo._
object WikipediaPageRank {
  def main(args: Array[String]) {
    if (args.length < 4) {
      System.err.println("Usage: WikipediaPageRank <inputFile> <threshold> <numSplits> <host> [<noCombiner>]")
      System.exit(-1)
    }

    System.setProperty("spark.serialization", "spark.KryoSerialization")
    System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName)

    val inputFile = args(0)
    val threshold = args(1).toDouble
    val numSplits = args(2).toInt
    val host = args(3)
    val noCombiner = args.length > 4 && args(4).nonEmpty
    val sc = new SparkContext(host, "WikipediaPageRank")

    // Parse the Wikipedia page data into a graph
    val input = sc.textFile(inputFile)

    println("Counting vertices...")
    val numVertices = input.count()
    println("Done counting vertices.")

    println("Parsing input file...")
    val vertices: RDD[(String, PRVertex)] = input.map(line => {
      val fields = line.split("\t")
      val (title, body) = (fields(1), fields(3).replace("\\n", "\n"))
      val links =
        if (body == "\\N")
          NodeSeq.Empty
        else
          try {
            XML.loadString(body) \\ "link" \ "target"
          } catch {
            case e: org.xml.sax.SAXParseException =>
              System.err.println("Article \"" + title + "\" has malformed XML in body:\n" + body)
              NodeSeq.Empty
          }
      // new String(...) makes compact copies so we don't retain references
      // into the parsed line
      val outEdges = ArrayBuffer(links.map(link => new PREdge(new String(link.text))): _*)
      val id = new String(title)
      (id, new PRVertex(id, 1.0 / numVertices, outEdges, true))
    })
    val graph = vertices.groupByKey(numSplits).mapValues(_.head).cache
    println("Done parsing input file.")
    println("Input file had " + graph.count + " vertices.")

    // Do the computation
    val epsilon = 0.01 / numVertices
    val result =
      if (noCombiner) {
        val messages = sc.parallelize(List[(String, PRMessage)]())
        Pregel.run[PRVertex, PRMessage, ArrayBuffer[PRMessage]](
          sc, graph, messages, numSplits, NoCombiner.messageCombiner,
          NoCombiner.defaultCombined, NoCombiner.mergeCombined)(
          NoCombiner.compute(numVertices, epsilon))
      } else {
        val messages = sc.parallelize(List[(String, PRMessage)]())
        Pregel.run[PRVertex, PRMessage, Double](
          sc, graph, messages, numSplits, Combiner.messageCombiner,
          Combiner.defaultCombined, Combiner.mergeCombined)(
          Combiner.compute(numVertices, epsilon))
      }

    // Print the result
    System.err.println("Articles with PageRank >= " + threshold + ":")
    val top = result.filter(_.value >= threshold).map(vertex =>
      "%s\t%s\n".format(vertex.id, vertex.value)).collect.mkString
    println(top)
  }
  object Combiner {
    def messageCombiner(sumSoFar: Double, message: PRMessage): Double =
      sumSoFar + message.value

    def mergeCombined(a: Double, b: Double) = a + b

    def defaultCombined(): Double = 0.0

    def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Double, superstep: Int): (PRVertex, Iterable[PRMessage]) = {
      // Standard PageRank update with damping factor 0.85
      val newValue =
        if (messageSum != 0)
          0.15 / numVertices + 0.85 * messageSum
        else
          self.value

      // Terminate once the value has converged to within epsilon (after at
      // least 10 supersteps), or after 30 supersteps at most
      val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30

      val outbox =
        if (!terminate)
          self.outEdges.map(edge =>
            new PRMessage(edge.targetId, newValue / self.outEdges.size))
        else
          ArrayBuffer[PRMessage]()

      (new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox)
    }
  }
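  // Illustrative trace (message values invented): combineByKey in Pregel.run
  // folds each incoming PRMessage into a per-partition partial sum with
  // messageCombiner, starting from defaultCombined(), then merges partials
  // across partitions with mergeCombined. For messages of value 0.2 and 0.3
  // on one partition and 0.1 on another:
  //   mergeCombined(
  //     messageCombiner(messageCombiner(defaultCombined(), m02), m03), // 0.5
  //     messageCombiner(defaultCombined(), m01))                       // 0.1
  //   == 0.6, which is the messageSum passed to compute.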
  object NoCombiner {
    // Instead of pre-summing, accumulate the raw messages and defer the sum
    // to compute
    def messageCombiner(messagesSoFar: ArrayBuffer[PRMessage], message: PRMessage): ArrayBuffer[PRMessage] =
      messagesSoFar += message

    def mergeCombined(a: ArrayBuffer[PRMessage], b: ArrayBuffer[PRMessage]): ArrayBuffer[PRMessage] =
      a ++= b

    def defaultCombined(): ArrayBuffer[PRMessage] = ArrayBuffer[PRMessage]()

    def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Seq[PRMessage], superstep: Int): (PRVertex, Iterable[PRMessage]) =
      Combiner.compute(numVertices, epsilon)(self, messages.map(_.value).sum, superstep)
  }
}
@serializable class PRVertex() extends Vertex with Externalizable {
  var id: String = _
  var value: Double = _
  var outEdges: ArrayBuffer[PREdge] = _
  var active: Boolean = true

  def this(id: String, value: Double, outEdges: ArrayBuffer[PREdge], active: Boolean) {
    this()
    this.id = id
    this.value = value
    this.outEdges = outEdges
    this.active = active
  }

  // Custom serialization: write edges as bare target IDs to keep the
  // on-the-wire representation compact
  def writeExternal(out: ObjectOutput) {
    out.writeUTF(id)
    out.writeDouble(value)
    out.writeInt(outEdges.length)
    for (e <- outEdges)
      out.writeUTF(e.targetId)
    out.writeBoolean(active)
  }

  def readExternal(in: ObjectInput) {
    id = in.readUTF()
    value = in.readDouble()
    val numEdges = in.readInt()
    outEdges = new ArrayBuffer[PREdge](numEdges)
    for (i <- 0 until numEdges) {
      outEdges += new PREdge(in.readUTF())
    }
    active = in.readBoolean()
  }
}
@serializable class PRMessage() extends Message with Externalizable {
  var targetId: String = _
  var value: Double = _

  def this(targetId: String, value: Double) {
    this()
    this.targetId = targetId
    this.value = value
  }

  def writeExternal(out: ObjectOutput) {
    out.writeUTF(targetId)
    out.writeDouble(value)
  }

  def readExternal(in: ObjectInput) {
    targetId = in.readUTF()
    value = in.readDouble()
  }
}
@serializable class PREdge() extends Edge with Externalizable {
  var targetId: String = _

  def this(targetId: String) {
    this()
    this.targetId = targetId
  }

  def writeExternal(out: ObjectOutput) {
    out.writeUTF(targetId)
  }

  def readExternal(in: ObjectInput) {
    targetId = in.readUTF()
  }
}
class PRKryoRegistrator extends KryoRegistrator {
  def registerClasses(kryo: Kryo) {
    kryo.register(classOf[PRVertex])
    kryo.register(classOf[PRMessage])
    kryo.register(classOf[PREdge])
  }
}
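// Note: this registrator only takes effect because main() above sets
// spark.serialization to spark.KryoSerialization and points
// spark.kryo.registrator at this class. Registering the three Bagel payload
// classes lets Kryo refer to them by ID rather than writing full class names.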
@@ -14,6 +14,8 @@ extends ParentProject(info) with IdeaProject
   lazy val examples =
     project("examples", "Spark Examples", new ExamplesProject(_), core)
 
+  lazy val bagel = project("bagel", "Bagel", core)
+
   class CoreProject(info: ProjectInfo)
   extends DefaultProject(info) with Eclipsify with IdeaProject with DepJar with XmlTestReport
   {}