diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
index 0fc1e4df6813cff38f4abbe101b7334d0581c932..377d9d6bd5e72ac18af49fd5b4662d160d7baa84 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -18,11 +18,11 @@ package org.apache.spark.graphx
 
 import scala.reflect.ClassTag
-
 import org.apache.spark.SparkContext._
 import org.apache.spark.SparkException
 import org.apache.spark.graphx.lib._
 import org.apache.spark.rdd.RDD
+import scala.util.Random
 
 /**
  * Contains additional functionality for [[Graph]]. All operations are expressed in terms of the
@@ -137,6 +137,42 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
     }
   } // end of collectNeighbor
 
+  /**
+   * Returns an RDD that contains for each vertex v its local edges,
+   * i.e., the edges that are incident on v, in the user-specified direction.
+   * Warning: note that singleton vertices, i.e., those with no edges in the
+   * given direction, will not be part of the return value.
+   *
+   * @note This function could be highly inefficient on power-law
+   * graphs where high degree vertices may force a large amount of
+   * information to be collected to a single location.
+   *
+   * @param edgeDirection the direction along which to collect
+   * the local edges of vertices
+   *
+   * @return the local edges for each vertex
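+   *
+   * @example A usage sketch (not part of this patch; the graph
+   * `g: Graph[Int, Int]` is hypothetical):
+   * {{{
+   * // For every vertex with at least one out-edge, collect its outgoing edges.
+   * val outEdges: VertexRDD[Array[Edge[Int]]] = g.collectEdges(EdgeDirection.Out)
+   * outEdges.collect().foreach { case (vid, nbrEdges) =>
+   *   println(s"vertex $vid has ${nbrEdges.length} outgoing edge(s)")
+   * }
+   * }}}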
+   */
+  def collectEdges(edgeDirection: EdgeDirection): VertexRDD[Array[Edge[ED]]] = {
+    edgeDirection match {
+      case EdgeDirection.Either =>
+        graph.mapReduceTriplets[Array[Edge[ED]]](
+          edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr))),
+            (edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
+          (a, b) => a ++ b)
+      case EdgeDirection.In =>
+        graph.mapReduceTriplets[Array[Edge[ED]]](
+          edge => Iterator((edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
+          (a, b) => a ++ b)
+      case EdgeDirection.Out =>
+        graph.mapReduceTriplets[Array[Edge[ED]]](
+          edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
+          (a, b) => a ++ b)
+      case EdgeDirection.Both =>
+        throw new SparkException("collectEdges does not support EdgeDirection.Both. Use " +
+          "EdgeDirection.Either instead.")
+    }
+  }
+
   /**
    * Join the vertices with an RDD and then apply a function from the
    * the vertex and RDD entry to a new vertex value. The input table
@@ -209,6 +245,27 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
     graph.mask(preprocess(graph).subgraph(epred, vpred))
   }
 
+  /**
+   * Picks a random vertex from the graph and returns its ID.
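+   *
+   * The vertex is found by repeatedly sampling each vertex independently with
+   * probability 50.0 / numVertices until at least one vertex is selected, and
+   * then choosing uniformly among the selected vertices.
+   *
+   * @example A usage sketch (not part of this patch; `g` is a hypothetical graph):
+   * {{{
+   * val vid: VertexId = g.pickRandomVertex()
+   * }}}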
+   */
+  def pickRandomVertex(): VertexId = {
+    // Use 50.0 rather than the integer 50: integer division would make the
+    // sampling probability 0 for graphs with more than 50 vertices.
+    val probability = 50.0 / graph.numVertices
+    var found = false
+    var retVal: VertexId = null.asInstanceOf[VertexId]
+    while (!found) {
+      val selectedVertices = graph.vertices.flatMap { vidVvals =>
+        if (Random.nextDouble() < probability) { Some(vidVvals._1) }
+        else { None }
+      }
+      if (selectedVertices.count > 0) {
+        found = true
+        val collectedVertices = selectedVertices.collect()
+        retVal = collectedVertices(Random.nextInt(collectedVertices.size))
+      }
+    }
+    retVal
+  }
+
   /**
    * Execute a Pregel-like iterative vertex-parallel abstraction.  The
    * user-defined vertex-program `vprog` is executed in parallel on
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
index bc2ad5677f8063ba646aac0b98fcd39c42021d60..6386306c048fc4a91a158d8ff9667808501c4fe7 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
@@ -42,21 +42,20 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext {
 
   test("collectNeighborIds") {
     withSpark { sc =>
-      val chain = (0 until 100).map(x => (x, (x+1)%100) )
-      val rawEdges = sc.parallelize(chain, 3).map { case (s,d) => (s.toLong, d.toLong) }
-      val graph = Graph.fromEdgeTuples(rawEdges, 1.0).cache()
+      val graph = getCycleGraph(sc, 100)
       val nbrs = graph.collectNeighborIds(EdgeDirection.Either).cache()
-      assert(nbrs.count === chain.size)
+      assert(nbrs.count === 100)
       assert(graph.numVertices === nbrs.count)
       nbrs.collect.foreach { case (vid, nbrs) => assert(nbrs.size === 2) }
-      nbrs.collect.foreach { case (vid, nbrs) =>
-        val s = nbrs.toSet
-        assert(s.contains((vid + 1) % 100))
-        assert(s.contains(if (vid > 0) vid - 1 else 99 ))
+      nbrs.collect.foreach {
+        case (vid, nbrs) =>
+          val s = nbrs.toSet
+          assert(s.contains((vid + 1) % 100))
+          assert(s.contains(if (vid > 0) vid - 1 else 99))
       }
     }
   }
-
+
   test ("filter") {
     withSpark { sc =>
       val n = 5
@@ -80,4 +79,121 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext {
     }
   }
 
+  test("collectEdgesCycleDirectionOut") {
+    withSpark { sc =>
+      val graph = getCycleGraph(sc, 100)
+      val edges = graph.collectEdges(EdgeDirection.Out).cache()
+      assert(edges.count == 100)
+      edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
+      edges.collect.foreach {
+        case (vid, edges) =>
+          val s = edges.toSet
+          val edgeDstIds = s.map(e => e.dstId)
+          assert(edgeDstIds.contains((vid + 1) % 100))
+      }
+    }
+  }
+
+  test("collectEdgesCycleDirectionIn") {
+    withSpark { sc =>
+      val graph = getCycleGraph(sc, 100)
+      val edges = graph.collectEdges(EdgeDirection.In).cache()
+      assert(edges.count == 100)
+      edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
+      edges.collect.foreach {
+        case (vid, edges) =>
+          val s = edges.toSet
+          val edgeSrcIds = s.map(e => e.srcId)
+          assert(edgeSrcIds.contains(if (vid > 0) vid - 1 else 99))
+      }
+    }
+  }
+
+  test("collectEdgesCycleDirectionEither") {
+    withSpark { sc =>
+      val graph = getCycleGraph(sc, 100)
+      val edges = graph.collectEdges(EdgeDirection.Either).cache()
+      assert(edges.count == 100)
+      edges.collect.foreach { case (vid, edges) => assert(edges.size == 2) }
+      edges.collect.foreach {
+        case (vid, edges) =>
+          val s = edges.toSet
+          val edgeIds = s.map(e => if (vid != e.srcId) e.srcId else e.dstId)
+          assert(edgeIds.contains((vid + 1) % 100))
+          assert(edgeIds.contains(if (vid > 0) vid - 1 else 99))
+      }
+    }
+  }
+
+  test("collectEdgesChainDirectionOut") {
+    withSpark { sc =>
+      val graph = getChainGraph(sc, 50)
+      val edges = graph.collectEdges(EdgeDirection.Out).cache()
+      assert(edges.count == 49)
+      edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
+      edges.collect.foreach {
+        case (vid, edges) =>
+          val s = edges.toSet
+          val edgeDstIds = s.map(e => e.dstId)
+          assert(edgeDstIds.contains(vid + 1))
+      }
+    }
+  }
+
+  test("collectEdgesChainDirectionIn") {
+    withSpark { sc =>
+      val graph = getChainGraph(sc, 50)
+      val edges = graph.collectEdges(EdgeDirection.In).cache()
+      // We expect only 49 because collectEdges does not return vertices that do
+      // not have any edges in the specified direction.
+      assert(edges.count == 49)
+      edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
+      edges.collect.foreach {
+        case (vid, edges) =>
+          val s = edges.toSet
+          val edgeSrcIds = s.map(e => e.srcId)
+          assert(edgeSrcIds.contains(vid - 1))
+      }
+    }
+  }
+
+  test("collectEdgesChainDirectionEither") {
+    withSpark { sc =>
+      val graph = getChainGraph(sc, 50)
+      val edges = graph.collectEdges(EdgeDirection.Either).cache()
+      // We expect all 50 vertices here: every vertex of the chain has at least
+      // one incident edge in either direction, so none of them are dropped.
+      assert(edges.count === 50)
+      edges.collect.foreach {
+        case (vid, edges) =>
+          if (vid > 0 && vid < 49) assert(edges.size == 2)
+          else assert(edges.size == 1)
+      }
+      edges.collect.foreach {
+        case (vid, edges) =>
+          val s = edges.toSet
+          val edgeIds = s.map(e => if (vid != e.srcId) e.srcId else e.dstId)
+          if (vid == 0) { assert(edgeIds.contains(1)) }
+          else if (vid == 49) { assert(edgeIds.contains(48)) }
+          else {
+            assert(edgeIds.contains(vid + 1))
+            assert(edgeIds.contains(vid - 1))
+          }
+      }
+    }
+  }
+
+  private def getCycleGraph(sc: SparkContext, numVertices: Int): Graph[Double, Int] = {
+    val cycle = (0 until numVertices).map(x => (x, (x + 1) % numVertices))
+    getGraphFromSeq(sc, cycle)
+  }
+
+  private def getChainGraph(sc: SparkContext, numVertices: Int): Graph[Double, Int] = {
+    val chain = (0 until numVertices - 1).map(x => (x, (x + 1)))
+    getGraphFromSeq(sc, chain)
+  }
+
+  private def getGraphFromSeq(sc: SparkContext, seq: IndexedSeq[(Int, Int)]): Graph[Double, Int] = {
+    val rawEdges = sc.parallelize(seq, 3).map { case (s, d) => (s.toLong, d.toLong) }
+    Graph.fromEdgeTuples(rawEdges, 1.0).cache()
+  }
 }