Skip to content
Snippets Groups Projects
Commit c18fa3eb authored by Ankur Dave's avatar Ankur Dave
Browse files

Package combiner functions into a trait

parent 1c8ca0eb
No related branches found
No related tags found
No related merge requests found
......@@ -2,7 +2,7 @@ package bagel
import spark._
import spark.SparkContext._
import scala.collection.mutable.HashMap
import scala.collection.mutable.ArrayBuffer
object Pregel extends Logging {
......@@ -24,9 +24,7 @@ object Pregel extends Logging {
sc: SparkContext,
verts: RDD[(String, V)],
msgs: RDD[(String, M)],
createCombiner: M => C,
mergeMsg: (C, M) => C,
mergeCombiners: (C, C) => C,
combiner: Combiner[M, C],
numSplits: Int,
superstep: Int = 0
)(compute: (V, Option[C], Int) => (V, Iterable[M])): RDD[V] = {
......@@ -35,7 +33,7 @@ object Pregel extends Logging {
val startTime = System.currentTimeMillis
// Bring together vertices and messages
val combinedMsgs = msgs.combineByKey(createCombiner, mergeMsg, mergeCombiners, numSplits)
val combinedMsgs = msgs.combineByKey(combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, numSplits)
val grouped = verts.groupWith(combinedMsgs)
// Run compute on each vertex
......@@ -72,17 +70,24 @@ object Pregel extends Logging {
val newMsgs = processed.flatMap {
case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
}
run(sc, newVerts, newMsgs, createCombiner, mergeMsg, mergeCombiners, numSplits, superstep + 1)(compute)
run(sc, newVerts, newMsgs, combiner, numSplits, superstep + 1)(compute)
}
}
}
trait Combiner[M, C] {
def createCombiner(msg: M): C
def mergeMsg(combiner: C, msg: M): C
def mergeCombiners(a: C, b: C): C
}
def defaultCreateCombiner[M <: Message](msg: M): ArrayBuffer[M] = ArrayBuffer(msg)
def defaultMergeMsg[M <: Message](combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] =
@serializable class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] {
def createCombiner(msg: M): ArrayBuffer[M] =
ArrayBuffer(msg)
def mergeMsg(combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] =
combiner += msg
def defaultMergeCombiners[M <: Message](a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] =
def mergeCombiners(a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] =
a ++= b
def defaultCompute[V <: Vertex, M <: Message](self: V, msgs: Option[ArrayBuffer[M]], superstep: Int): (V, Iterable[M]) =
(self, List())
}
/**
......
......@@ -49,12 +49,7 @@ object ShortestPath {
messages.count()+" messages.")
// Do the computation
def createCombiner(message: SPMessage): Int = message.value
def mergeMsg(combiner: Int, message: SPMessage): Int =
min(combiner, message.value)
def mergeCombiners(a: Int, b: Int): Int = min(a, b)
val result = Pregel.run(sc, vertices, messages, createCombiner, mergeMsg, mergeCombiners, numSplits) {
val result = Pregel.run(sc, vertices, messages, MinCombiner, numSplits) {
(self: SPVertex, messageMinValue: Option[Int], superstep: Int) =>
val newValue = messageMinValue match {
case Some(minVal) => min(self.value, minVal)
......@@ -82,6 +77,15 @@ object ShortestPath {
}
}
object MinCombiner extends Combiner[SPMessage, Int] {
def createCombiner(msg: SPMessage): Int =
msg.value
def mergeMsg(combiner: Int, msg: SPMessage): Int =
min(combiner, msg.value)
def mergeCombiners(a: Int, b: Int): Int =
min(a, b)
}
@serializable class SPVertex(val id: String, val value: Int, val outEdges: Seq[SPEdge], val active: Boolean) extends Vertex
@serializable class SPEdge(val targetId: String, val value: Int) extends Edge
@serializable class SPMessage(val targetId: String, val value: Int) extends Message
......@@ -60,9 +60,9 @@ object WikipediaPageRank {
val messages = sc.parallelize(List[(String, PRMessage)]())
val result =
if (noCombiner) {
Pregel.run[PRVertex, PRMessage, ArrayBuffer[PRMessage]](sc, vertices, messages, NoCombiner.createCombiner, NoCombiner.mergeMsg, NoCombiner.mergeCombiners, numSplits)(NoCombiner.compute(numVertices, epsilon))
Pregel.run(sc, vertices, messages, PRNoCombiner, numSplits)(PRNoCombiner.compute(numVertices, epsilon))
} else {
Pregel.run[PRVertex, PRMessage, Double](sc, vertices, messages, Combiner.createCombiner, Combiner.mergeMsg, Combiner.mergeCombiners, numSplits)(Combiner.compute(numVertices, epsilon))
Pregel.run(sc, vertices, messages, PRCombiner, numSplits)(PRCombiner.compute(numVertices, epsilon))
}
// Print the result
......@@ -71,53 +71,44 @@ object WikipediaPageRank {
"%s\t%s\n".format(vertex.id, vertex.value)).collect.mkString
println(top)
}
}
object Combiner {
def createCombiner(message: PRMessage): Double = message.value
def mergeMsg(combiner: Double, message: PRMessage): Double =
combiner + message.value
def mergeCombiners(a: Double, b: Double) = a + b
def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = {
val newValue = messageSum match {
case Some(msgSum) if msgSum != 0 =>
0.15 / numVertices + 0.85 * msgSum
case _ => self.value
}
val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30
val outbox =
if (!terminate)
self.outEdges.map(edge =>
new PRMessage(edge.targetId, newValue / self.outEdges.size))
else
ArrayBuffer[PRMessage]()
(new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox)
object PRCombiner extends Combiner[PRMessage, Double] {
def createCombiner(msg: PRMessage): Double =
msg.value
def mergeMsg(combiner: Double, msg: PRMessage): Double =
combiner + msg.value
def mergeCombiners(a: Double, b: Double): Double =
a + b
def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = {
val newValue = messageSum match {
case Some(msgSum) if msgSum != 0 =>
0.15 / numVertices + 0.85 * msgSum
case _ => self.value
}
}
object NoCombiner {
def createCombiner(message: PRMessage): ArrayBuffer[PRMessage] =
ArrayBuffer(message)
val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30
def mergeMsg(combiner: ArrayBuffer[PRMessage], message: PRMessage): ArrayBuffer[PRMessage] =
combiner += message
val outbox =
if (!terminate)
self.outEdges.map(edge =>
new PRMessage(edge.targetId, newValue / self.outEdges.size))
else
ArrayBuffer[PRMessage]()
def mergeCombiners(a: ArrayBuffer[PRMessage], b: ArrayBuffer[PRMessage]): ArrayBuffer[PRMessage] =
a ++= b
def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) =
Combiner.compute(numVertices, epsilon)(self, messages match {
case Some(msgs) => Some(msgs.map(_.value).sum)
case None => None
}, superstep)
(new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox)
}
}
object PRNoCombiner extends DefaultCombiner[PRMessage] {
def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) =
PRCombiner.compute(numVertices, epsilon)(self, messages match {
case Some(msgs) => Some(msgs.map(_.value).sum)
case None => None
}, superstep)
}
@serializable class PRVertex() extends Vertex {
var id: String = _
var value: Double = _
......
......@@ -20,10 +20,7 @@ class BagelSuite extends FunSuite with Assertions {
val msgs = sc.parallelize(Array[(String, TestMessage)]())
val numSupersteps = 5
val result =
Pregel.run(sc, verts, msgs,
Pregel.defaultCreateCombiner[TestMessage],
Pregel.defaultMergeMsg[TestMessage],
Pregel.defaultMergeCombiners[TestMessage], 1) {
Pregel.run(sc, verts, msgs, new DefaultCombiner[TestMessage], 1) {
(self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) =>
(new TestVertex(self.id, superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
}
......@@ -37,10 +34,7 @@ class BagelSuite extends FunSuite with Assertions {
val msgs = sc.parallelize(Array("a" -> new TestMessage("a")))
val numSupersteps = 5
val result =
Pregel.run(sc, verts, msgs,
Pregel.defaultCreateCombiner[TestMessage],
Pregel.defaultMergeMsg[TestMessage],
Pregel.defaultMergeCombiners[TestMessage], 1) {
Pregel.run(sc, verts, msgs, new DefaultCombiner[TestMessage], 1) {
(self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) =>
val msgsOut =
msgs match {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment