Spark Pregel参数说明
Pregel是个强大的基于图的迭代算法,也是Spark中的一个迭代应用aggregateMessage的典型案例,用它可以在图中方便的迭代计算,如最短路径、关键路径、n度关系等。然而对于之前对图计算接触不多的童鞋来说,这个api还算是一个比较重量组的接口,不太容易理解。 Spark中的Pregel定义如下:
def pregel[A: ClassTag](
initialMsg: A,
maxIterations: Int = Int.MaxValue,
activeDirection: EdgeDirection = EdgeDirection.Either)(
vprog: (VertexId, VD, A) => VD,
sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
mergeMsg: (A, A) => A)
: Graph[VD, ED] = {
Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg)
}
各个参数的意义详细解释如下:
initialMsg: 初始化消息,这个初始消息会被用来初始化图中的每个节点的属性,在pregel进行调用时,会首先在图上使用mapVertices来根据initialMsg的值更新每个节点的值,至于如何更新,则由vprog参数而定,vprog函数就接收了initialMsg消息做为参数来更新对应节点的值
maxIterations: 最大迭代次数
activeDirection: 表示边的活跃方向,什么是活跃方向呢,首先要解释一下活跃消息与活跃顶点的概念,活跃节点是指在某一轮迭代中,pregel会以sendMsg和mergeMsg为参数来调用graph的aggregateMessage方法后收到消息的节点,活跃消息就是这轮迭代中所有被收成功收到的消息。这样一来,有的边的src节点是活跃节点,有的dst节点是活跃节点,而有的边两端节点都是活跃节点。如果activeDirection参数指定为“EdgeDirection.Out”,则在下一轮迭代时,只有接收消息的出边(src—>dst)才会执行sendMsg函数,也就是说,sendMsg回调函数会过滤掉”dst—>src”的edgeTriplet上下文参数
vprog: 节点变换函数,在初始时,以及每轮迭代后,pregel会根据上一轮使用的msg和这里的vprod函数在图上调用joinVertices方法变化每个收到消息的节点,注意这个函数除初始时外,都是仅在接收到消息的节点上运行,这一点可以从源码中看到,源码中用的是joinVertices(message)(vprog),因此,没有收到消息的节点在join之后就滤掉了
sendMsg: 消息发送函数,该函数的运行参数是一个代表边的上下文,pregel在调用aggregateMessages时,会将EdgeContext转换成EdgeTriplet对象(ctx.toEdgeTriplet)来使用,用户需要通过Iterator[(VertexId,A)]指定发送哪些消息,发给那些节点,发送的内容是什么,因为在一条边上可以发送多个消息,如sendToDst,如sendToSrc,所以这里是个Iterator,每一个元素是一个tuple,其中的vertexId表示要接收此消息的节点的id,它只能是该边上的srcId或dstId,而A就是要发送的内容,因此如果是需要由src发送一条消息A给dst,则有:Iterator((dstId,A)),如果什么消息也不发送,则可以返回一个空的Iterator:Iterator.empty
mergeMsg: 邻居节点收到多条消息时的合并逻辑,注意它区别于vprog函数,mergeMsg仅能合并消息内容,但合并后并不会更新到节点中去,而vprog函数可以根据收到的消息(就是mergeMsg产生的结果)更新节点属性。
代码示例:
最短路径实现
package BooksCode
import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.GraphGenerators
import org.graphstream.graph.implementations.{AbstractEdge, SingleGraph, SingleNode}
object ShortestPath_Pregel {
def main(args: Array[String]): Unit = {
Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
val conf = new SparkConf().setAppName("ShortestPath_Pregel").setMaster("local")
val sc = new SparkContext(conf)
val graph:Graph[Long,Double] = GraphGenerators.logNormalGraph(sc,numVertices = 10).mapEdges(e =>e.attr.toDouble)
val sourceId = 5L
val initialGraph = graph.mapVertices((id,_)=>if(id==sourceId) 0.0 else Double.PositiveInfinity)
// println(initialGraph.vertices.collect.mkString("\n"))
println(initialGraph.edges.distinct().collect.mkString("\n"))
println("################################################")
val sssp = initialGraph.pregel(Double.PositiveInfinity)(
// verte program
(id,dist,newDisst) =>{
println((id,dist,newDisst))
math.min(dist,newDisst)} ,
//Send Message
triplet => {if (triplet.srcAttr + triplet.attr < triplet.dstAttr){
Iterator((triplet.dstId,triplet.srcAttr + triplet.attr))
}
else {
Iterator.empty
}
},
(a,b) => math.min(a,b) //Merge message
)
//println(sssp.vertices.collect.mkString("\n"))
}
}