diff --git a/bagel/src/main/scala/bagel/Pregel.scala b/bagel/src/main/scala/bagel/Pregel.scala
index 67bc582fd1659f0b41f6ec3636af05aade08ade7..e3b6d0c70ac9e57f7e6e353ba263232457ca2129 100644
--- a/bagel/src/main/scala/bagel/Pregel.scala
+++ b/bagel/src/main/scala/bagel/Pregel.scala
@@ -2,7 +2,7 @@ package bagel
 
 import spark._
 import spark.SparkContext._
-import scala.collection.mutable.HashMap
+
 import scala.collection.mutable.ArrayBuffer
 
 object Pregel extends Logging {
@@ -24,9 +24,7 @@ object Pregel extends Logging {
     sc: SparkContext,
     verts: RDD[(String, V)],
     msgs: RDD[(String, M)],
-    createCombiner: M => C,
-    mergeMsg: (C, M) => C,
-    mergeCombiners: (C, C) => C,
+    combiner: Combiner[M, C],
     numSplits: Int,
     superstep: Int = 0
   )(compute: (V, Option[C], Int) => (V, Iterable[M])): RDD[V] = {
@@ -35,7 +33,7 @@ object Pregel extends Logging {
     val startTime = System.currentTimeMillis
 
     // Bring together vertices and messages
-    val combinedMsgs = msgs.combineByKey(createCombiner, mergeMsg, mergeCombiners, numSplits)
+    val combinedMsgs = msgs.combineByKey(combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, numSplits)
     val grouped = verts.groupWith(combinedMsgs)
 
     // Run compute on each vertex
@@ -72,17 +70,24 @@ object Pregel extends Logging {
       val newMsgs = processed.flatMap {
         case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
       }
-      run(sc, newVerts, newMsgs, createCombiner, mergeMsg, mergeCombiners, numSplits, superstep + 1)(compute)
+      run(sc, newVerts, newMsgs, combiner, numSplits, superstep + 1)(compute)
     }
   }
+}
+
+trait Combiner[M, C] {
+  def createCombiner(msg: M): C
+  def mergeMsg(combiner: C, msg: M): C
+  def mergeCombiners(a: C, b: C): C
+}
 
-  def defaultCreateCombiner[M <: Message](msg: M): ArrayBuffer[M] = ArrayBuffer(msg)
-  def defaultMergeMsg[M <: Message](combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] =
+@serializable class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] {
+  def createCombiner(msg: M): ArrayBuffer[M] =
+    ArrayBuffer(msg)
+  def mergeMsg(combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] =
     combiner += msg
-  def defaultMergeCombiners[M <: Message](a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] =
+  def mergeCombiners(a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] =
     a ++= b
-  def defaultCompute[V <: Vertex, M <: Message](self: V, msgs: Option[ArrayBuffer[M]], superstep: Int): (V, Iterable[M]) =
-    (self, List())
 }
 
 /**
diff --git a/bagel/src/main/scala/bagel/ShortestPath.scala b/bagel/src/main/scala/bagel/ShortestPath.scala
index 6699f58a31c76c806c00dd231326d64ab93b9ea9..3fd2f393348f0bf667b2c3c93835724adc056708 100644
--- a/bagel/src/main/scala/bagel/ShortestPath.scala
+++ b/bagel/src/main/scala/bagel/ShortestPath.scala
@@ -49,12 +49,7 @@ object ShortestPath {
                        messages.count()+" messages.")
 
     // Do the computation
-    def createCombiner(message: SPMessage): Int = message.value
-    def mergeMsg(combiner: Int, message: SPMessage): Int =
-      min(combiner, message.value)
-    def mergeCombiners(a: Int, b: Int): Int = min(a, b)
-
-    val result = Pregel.run(sc, vertices, messages, createCombiner, mergeMsg, mergeCombiners, numSplits) {
+    val result = Pregel.run(sc, vertices, messages, MinCombiner, numSplits) {
       (self: SPVertex, messageMinValue: Option[Int], superstep: Int) =>
         val newValue = messageMinValue match {
           case Some(minVal) => min(self.value, minVal)
@@ -82,6 +77,15 @@ object ShortestPath {
   }
 }
 
+object MinCombiner extends Combiner[SPMessage, Int] {
+  def createCombiner(msg: SPMessage): Int =
+    msg.value
+  def mergeMsg(combiner: Int, msg: SPMessage): Int =
+    min(combiner, msg.value)
+  def mergeCombiners(a: Int, b: Int): Int =
+    min(a, b)
+}
+
 @serializable class SPVertex(val id: String, val value: Int, val outEdges: Seq[SPEdge], val active: Boolean) extends Vertex
 @serializable class SPEdge(val targetId: String, val value: Int) extends Edge
 @serializable class SPMessage(val targetId: String, val value: Int) extends Message
diff --git a/bagel/src/main/scala/bagel/WikipediaPageRank.scala b/bagel/src/main/scala/bagel/WikipediaPageRank.scala
index f6aeff0782bef3a0a642db882ec84dd25162d909..994cea8ec3951ddceb784caa994edb6f3c7725bf 100644
--- a/bagel/src/main/scala/bagel/WikipediaPageRank.scala
+++ b/bagel/src/main/scala/bagel/WikipediaPageRank.scala
@@ -60,9 +60,9 @@ object WikipediaPageRank {
     val messages = sc.parallelize(List[(String, PRMessage)]())
     val result =
       if (noCombiner) {
-        Pregel.run[PRVertex, PRMessage, ArrayBuffer[PRMessage]](sc, vertices, messages, NoCombiner.createCombiner, NoCombiner.mergeMsg, NoCombiner.mergeCombiners, numSplits)(NoCombiner.compute(numVertices, epsilon)) 
+        Pregel.run(sc, vertices, messages, PRNoCombiner, numSplits)(PRNoCombiner.compute(numVertices, epsilon)) 
       } else {
-        Pregel.run[PRVertex, PRMessage, Double](sc, vertices, messages, Combiner.createCombiner, Combiner.mergeMsg, Combiner.mergeCombiners, numSplits)(Combiner.compute(numVertices, epsilon))
+        Pregel.run(sc, vertices, messages, PRCombiner, numSplits)(PRCombiner.compute(numVertices, epsilon))
       }
 
     // Print the result
@@ -71,53 +71,44 @@ object WikipediaPageRank {
       "%s\t%s\n".format(vertex.id, vertex.value)).collect.mkString
     println(top)
   }
+}
 
-  object Combiner {
-    def createCombiner(message: PRMessage): Double = message.value
-
-    def mergeMsg(combiner: Double, message: PRMessage): Double =
-      combiner + message.value
-
-    def mergeCombiners(a: Double, b: Double) = a + b
-
-    def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = {
-      val newValue = messageSum match {
-        case Some(msgSum) if msgSum != 0 =>
-          0.15 / numVertices + 0.85 * msgSum
-        case _ => self.value
-      }
-
-      val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30
-
-      val outbox =
-        if (!terminate)
-          self.outEdges.map(edge =>
-            new PRMessage(edge.targetId, newValue / self.outEdges.size))
-        else
-          ArrayBuffer[PRMessage]()
-
-      (new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox)
+object PRCombiner extends Combiner[PRMessage, Double] {
+  def createCombiner(msg: PRMessage): Double =
+    msg.value
+  def mergeMsg(combiner: Double, msg: PRMessage): Double =
+    combiner + msg.value
+  def mergeCombiners(a: Double, b: Double): Double =
+    a + b
+
+  def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = {
+    val newValue = messageSum match {
+      case Some(msgSum) if msgSum != 0 =>
+        0.15 / numVertices + 0.85 * msgSum
+      case _ => self.value
     }
-  }
 
-  object NoCombiner {
-    def createCombiner(message: PRMessage): ArrayBuffer[PRMessage] =
-      ArrayBuffer(message)
+    val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30
 
-    def mergeMsg(combiner: ArrayBuffer[PRMessage], message: PRMessage): ArrayBuffer[PRMessage] =
-      combiner += message
+    val outbox =
+      if (!terminate)
+        self.outEdges.map(edge =>
+          new PRMessage(edge.targetId, newValue / self.outEdges.size))
+      else
+        ArrayBuffer[PRMessage]()
 
-    def mergeCombiners(a: ArrayBuffer[PRMessage], b: ArrayBuffer[PRMessage]): ArrayBuffer[PRMessage] =
-      a ++= b
-
-    def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) =
-      Combiner.compute(numVertices, epsilon)(self, messages match {
-        case Some(msgs) => Some(msgs.map(_.value).sum)
-        case None => None
-      }, superstep)
+    (new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox)
   }
 }
 
+object PRNoCombiner extends DefaultCombiner[PRMessage] {
+  def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) =
+    PRCombiner.compute(numVertices, epsilon)(self, messages match {
+      case Some(msgs) => Some(msgs.map(_.value).sum)
+      case None => None
+    }, superstep)
+}
+
 @serializable class PRVertex() extends Vertex {
   var id: String = _
   var value: Double = _
diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala
index 72aecb7fd8aea7aec0682e124200020a21fa99e7..29f5f0c35827519c9175e654ef39283639345bf8 100644
--- a/bagel/src/test/scala/bagel/BagelSuite.scala
+++ b/bagel/src/test/scala/bagel/BagelSuite.scala
@@ -20,10 +20,7 @@ class BagelSuite extends FunSuite with Assertions {
     val msgs = sc.parallelize(Array[(String, TestMessage)]())
     val numSupersteps = 5
     val result =
-      Pregel.run(sc, verts, msgs,
-                 Pregel.defaultCreateCombiner[TestMessage],
-                 Pregel.defaultMergeMsg[TestMessage],
-                 Pregel.defaultMergeCombiners[TestMessage], 1) {
+      Pregel.run(sc, verts, msgs, new DefaultCombiner[TestMessage], 1) {
       (self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) =>
         (new TestVertex(self.id, superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
     }
@@ -37,10 +34,7 @@ class BagelSuite extends FunSuite with Assertions {
     val msgs = sc.parallelize(Array("a" -> new TestMessage("a")))
     val numSupersteps = 5
     val result =
-      Pregel.run(sc, verts, msgs,
-                 Pregel.defaultCreateCombiner[TestMessage],
-                 Pregel.defaultMergeMsg[TestMessage],
-                 Pregel.defaultMergeCombiners[TestMessage], 1) {
+      Pregel.run(sc, verts, msgs, new DefaultCombiner[TestMessage], 1) {
       (self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) =>
         val msgsOut =
           msgs match {