GraphX 自带的 ShortestPaths(org.apache.spark.graphx.lib.ShortestPaths)源码如下:
package org.apache.spark.graphx.lib
import scala.reflect.ClassTag
import org.apache.spark.graphx._
/**
 * Landmark shortest paths measured in hops: for every vertex, the number of edges on the
 * shortest path (following edge direction) from that vertex to each landmark it can reach.
 */
object ShortestPaths {
  /** Maps the vertex id of a landmark to the hop count needed to reach it. */
  type SPMap = Map[VertexId, Int]

  /** Builds an [[SPMap]] from (landmark, distance) pairs. */
  private def makeMap(entries: (VertexId, Int)*): SPMap = entries.toMap

  /** Adds one hop to every distance in `spmap`. */
  private def incrementMap(spmap: SPMap): SPMap =
    for ((landmark, hops) <- spmap) yield landmark -> (hops + 1)

  /** Union of two maps, keeping the smaller distance for landmarks present in both. */
  private def addMaps(spmap1: SPMap, spmap2: SPMap): SPMap =
    spmap2.foldLeft(spmap1) { case (acc, (landmark, hops)) =>
      acc.updated(landmark, math.min(acc.getOrElse(landmark, Int.MaxValue), hops))
    }

  /**
   * Computes, for every vertex, the hop distance to each landmark it can reach.
   *
   * @tparam ED the edge attribute type (ignored by the computation)
   *
   * @param graph the graph to analyse
   * @param landmarks ids of the landmark vertices; a distance is computed to each of them
   *
   * @return a graph with the same edges whose vertex attribute maps every reachable landmark
   *         to its hop distance
   */
  def run[VD, ED: ClassTag](graph: Graph[VD, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = {
    // Landmarks start at distance 0 from themselves; every other vertex knows nothing yet.
    val seeded = graph.mapVertices { (vid, _) =>
      if (landmarks contains vid) makeMap(vid -> 0) else makeMap()
    }

    // A vertex folds incoming distance maps into what it already knows.
    val vertexProgram = (_: VertexId, known: SPMap, incoming: SPMap) => addMaps(known, incoming)

    // Edge src -> dst: src learns dst's distances plus one hop, but only if that changes src's map.
    def sendMessage(edge: EdgeTriplet[SPMap, _]): Iterator[(VertexId, SPMap)] = {
      val offer = incrementMap(edge.dstAttr)
      val improves = addMaps(offer, edge.srcAttr) != edge.srcAttr
      if (improves) Iterator.single(edge.srcId -> offer) else Iterator.empty
    }

    Pregel(seeded, makeMap())(vertexProgram, sendMessage, addMaps)
  }
}
关于单源最短路径,我们可以调用 ShortestPaths.run(graph, landmarks) 得到graph中各顶点到landmarks的“距离”,但这个“距离”只是“跳数”。换句话说,只有在graph中每条边的权重都为1时,才能保证结果的正确性,而现实情况往往不满足这个条件。那么问题来了,我们该如何做呢?学过图论的朋友都知道,对于边权非负的图,Dijkstra算法可以解决这个问题。遗憾的是,GraphX目前(Spark 2.0.2)并未提供这样的API,所以基于GraphX实现Dijkstra算法变得很有必要。
Dijkstra(单源最短路径)
/**
 * Single-source shortest paths with Dijkstra's algorithm, driven round by round from the driver.
 *
 * Every round launches Spark jobs: a `reduce` that picks the unvisited vertex with the smallest
 * tentative distance, and an `aggregateMessages` pass that relaxes that vertex's out-edges.
 * Edges are followed from src to dst only, so an undirected graph must contain both directions
 * of every edge. Edge weights are assumed to be non-negative (Dijkstra's precondition).
 *
 * @param g      graph whose edge attribute is the edge weight
 * @param origin id of the source vertex
 * @return a graph with the edges of `g` whose vertex attribute is
 *         (original attribute, distance from `origin`, predecessor on a shortest path);
 *         `origin` gets (attr, 0.0, -1L) and unreachable vertices get (attr, Double.MaxValue, -1L)
 */
def dijkstra[VD: ClassTag](g: Graph[VD, Double], origin: VertexId): Graph[(VD, Double, Long), Double] = {
  // Vertex state: (visited, tentative distance from origin, predecessor id or -1).
  // The literals 0.0 / -1L keep the tuple typed as (Boolean, Double, Long).
  var g2 = g.mapVertices((vid, _) => (false, if (vid == origin) 0.0 else Double.MaxValue, -1L))
  // Until the first update, g2's edge RDD may be shared with the caller's `g`; never unpersist it.
  var ownsEdges = false
  val vertexCount = g.vertices.count()
  var round = 0L
  var done = false
  while (!done && round < vertexCount) {
    round += 1
    // The closest vertex that has not been finalised yet.
    val (currentVertexId, (_, currentDist, _)) =
      g2.vertices.filter(!_._2._1).reduce((a, b) => if (a._2._2 < b._2._2) a else b)
    if (currentDist == Double.MaxValue) {
      // Only unreachable vertices remain; with non-negative weights no further round can change
      // any distance or predecessor, so stop instead of launching more Spark jobs.
      done = true
    } else {
      // Offer each unvisited neighbour a path through currentVertexId; keep the cheapest offer.
      val newDistances = g2.aggregateMessages[(Double, Long)](
        triplet => if (triplet.srcId == currentVertexId && !triplet.dstAttr._1) {
          triplet.sendToDst((triplet.srcAttr._2 + triplet.attr, triplet.srcId))
        },
        (x, y) => if (x._1 < y._1) x else y,
        TripletFields.All
      )
      // Mark currentVertexId visited and apply offers that beat the tentative distance.
      val next = g2.outerJoinVertices(newDistances) {
        case (vid, vd, Some(offer)) =>
          (vd._1 || vid == currentVertexId, math.min(vd._2, offer._1),
            if (vd._2 <= offer._1) vd._3 else offer._2)
        case (vid, vd, None) => (vd._1 || vid == currentVertexId, vd._2, vd._3)
      }.cache()
      // Materialise the new state before releasing the old one, so later rounds read cached data
      // instead of recomputing every earlier round from scratch.
      next.vertices.count()
      next.edges.count()
      g2.unpersistVertices(blocking = false)
      if (ownsEdges) g2.edges.unpersist(blocking = false)
      g2 = next
      ownsEdges = true
    }
  }
  // Attach (distance, predecessor) to the caller's vertex attributes.
  g.outerJoinVertices(g2.vertices) { (_, attr, state) =>
    state.fold((attr, Double.MaxValue, -1L)) { case (_, dist, pred) => (attr, dist, pred) }
  }
}
Prim(最小生成树)
知道Dijkstra算法的人也一定知道Prim算法(下面代码中的函数名沿用了 prime)。
/**
 * Minimum spanning tree with Prim's algorithm (the function keeps its historical name `prime`).
 *
 * Every round launches Spark jobs: a `reduce` that picks the unvisited vertex with the smallest
 * connection cost, and an `aggregateMessages` pass that offers that vertex's out-edges to its
 * unvisited neighbours. Edges are followed from src to dst only, so the graph must contain both
 * directions of every undirected edge. If the graph is disconnected, each further component is
 * started from a vertex whose cost stays `Double.MaxValue` with parent `-1L` (a spanning forest).
 *
 * @param g      graph whose edge attribute is the edge weight
 * @param origin id of the vertex the tree is grown from
 * @return a graph with the edges of `g` whose vertex attribute is
 *         (original attribute, weight of the tree edge that joins the vertex, parent vertex id);
 *         `origin` gets (attr, 0.0, -1L)
 */
def prime[VD: ClassTag](g: Graph[VD, Double], origin: VertexId): Graph[(VD, Double, Long), Double] = {
  // Vertex state: (visited, cost of joining the tree, parent id or -1).
  // The literals 0.0 / -1L keep the tuple typed as (Boolean, Double, Long).
  var g2 = g.mapVertices((vid, _) => (false, if (vid == origin) 0.0 else Double.MaxValue, -1L))
  // Until the first update, g2's edge RDD may be shared with the caller's `g`; never unpersist it.
  var ownsEdges = false
  for (_ <- 1L to g.vertices.count()) {
    // The cheapest vertex that has not joined the tree yet.
    val currentVertexId =
      g2.vertices.filter(!_._2._1).reduce((a, b) => if (a._2._2 < b._2._2) a else b)._1
    // Offer each unvisited neighbour the edge to currentVertexId; keep the cheapest offer.
    val newCosts = g2.aggregateMessages[(Double, Long)](
      triplet => if (triplet.srcId == currentVertexId && !triplet.dstAttr._1) {
        triplet.sendToDst((triplet.attr, triplet.srcId))
      },
      (x, y) => if (x._1 < y._1) x else y,
      // Only the destination's visited flag is read; the source attribute need not be shipped.
      TripletFields.Dst
    )
    // Mark currentVertexId visited and apply offers that beat the current joining cost.
    val next = g2.outerJoinVertices(newCosts) {
      case (vid, vd, Some(offer)) =>
        (vd._1 || vid == currentVertexId, math.min(vd._2, offer._1),
          if (vd._2 <= offer._1) vd._3 else offer._2)
      case (vid, vd, None) => (vd._1 || vid == currentVertexId, vd._2, vd._3)
    }.cache()
    // Materialise the new state before releasing the old one, so later rounds read cached data
    // instead of recomputing every earlier round from scratch.
    next.vertices.count()
    next.edges.count()
    g2.unpersistVertices(blocking = false)
    if (ownsEdges) g2.edges.unpersist(blocking = false)
    g2 = next
    ownsEdges = true
  }
  // Attach (cost, parent) to the caller's vertex attributes.
  g.outerJoinVertices(g2.vertices) { (_, attr, state) =>
    state.fold((attr, Double.MaxValue, -1L)) { case (_, cost, parent) => (attr, cost, parent) }
  }
}
FloydWarshall(多源最短路径;注意:下面的实现并不是经典的 Floyd-Warshall 三重循环,而是基于消息传递的迭代松弛,思路类似分布式 Bellman-Ford)
/**
 * All-pairs shortest paths on a weighted graph.
 *
 * Despite its name this is not the classic Floyd–Warshall triple loop. Each vertex keeps a map
 * from every vertex it can reach to the best distance found so far. In each round, every edge
 * src -> dst offers src the distances known at dst plus the edge weight (distributed
 * Bellman–Ford relaxation). Rounds repeat until no edge can improve any map. The result
 * therefore holds minimal distances rather than the first path found, and the loop also
 * terminates on disconnected graphs.
 *
 * Edges are followed from src to dst only. For an undirected graph, include both directions of
 * every edge; the resulting distance matrix is then symmetric. The weights must not form a
 * negative cycle, otherwise improvements never stop and the loop does not terminate.
 *
 * @param g graph whose edge attribute is the edge weight
 * @return a graph with the edges of `g` whose vertex attribute maps each vertex reachable from
 *         it (including itself, at 0.0) to the shortest-path distance
 */
def floydWarshall[VD: ClassTag](g: Graph[VD, Double]): Graph[Map[VertexId, Double], Double] = {
  // Union of two distance maps, keeping the smaller distance per target vertex.
  def mergeMaps(a: Map[VertexId, Double], b: Map[VertexId, Double]): Map[VertexId, Double] =
    (a.keySet ++ b.keySet).map { k =>
      (k, math.min(a.getOrElse(k, Double.MaxValue), b.getOrElse(k, Double.MaxValue)))
    }.toMap

  // One relaxation round. An offer is sent only if it would improve src's map, so an empty
  // result means the distances have converged. Both endpoint maps are read, hence All.
  def relax(graph: Graph[Map[VertexId, Double], Double]) =
    graph.aggregateMessages[Map[VertexId, Double]](
      triplet => {
        val known = triplet.srcAttr
        val viaDst = triplet.dstAttr.map { case (vid, distance) => (vid, triplet.attr + distance) }
        if (viaDst.exists { case (vid, d) => d < known.getOrElse(vid, Double.MaxValue) }) {
          triplet.sendToSrc(viaDst)
        }
      },
      (a, b) => mergeMaps(a, b),
      TripletFields.All
    )

  // Every vertex initially knows only itself, at distance 0.
  var g2 = g.mapVertices((vid, _) => Map(vid -> 0.0))
  // Until the first update, g2's edge RDD may be shared with the caller's `g`; never unpersist it.
  var ownsEdges = false
  var updates = relax(g2).cache()
  while (!updates.isEmpty()) {
    // Vertices that received no offer keep their map unchanged.
    val next = g2.outerJoinVertices(updates) { (_, oldAttr, offer) =>
      offer.fold(oldAttr)(mergeMaps(oldAttr, _))
    }.cache()
    // Materialise the new state before releasing the old one, so later rounds read cached data
    // instead of recomputing every earlier round from scratch.
    next.vertices.count()
    next.edges.count()
    updates.unpersist(blocking = false)
    g2.unpersistVertices(blocking = false)
    if (ownsEdges) g2.edges.unpersist(blocking = false)
    g2 = next
    ownsEdges = true
    updates = relax(g2).cache()
  }
  updates.unpersist(blocking = false)
  g2
}
纸上得来终觉浅,绝知此事要躬行。下面开始实战、实战、实战,重要的事情说三遍!!!
// Sample data: a 7-vertex weighted graph used to exercise every algorithm above (spark-shell `sc`).
val myVertices = sc.makeRDD(Array((1L, "A"), (2L, "B"), (3L, "C"), (4L, "D"), (5L, "E"), (6L, "F"), (7L, "G")))
val initialEdges = sc.makeRDD(Array(Edge(1L, 2L, 7.0), Edge(1L, 4L, 5.0),
  Edge(2L, 3L, 8.0), Edge(2L, 4L, 9.0), Edge(2L, 5L, 7.0),
  Edge(3L, 5L, 5.0),
  Edge(4L, 5L, 15.0), Edge(4L, 6L, 6.0),
  Edge(5L, 6L, 8.0), Edge(5L, 7L, 9.0),
  Edge(6L, 7L, 11.0)))
// Drop self-loops, add the reverse of every edge (directed -> undirected), then remove duplicates.
val myEdges = initialEdges
  .filter(e => e.srcId != e.dstId)
  .flatMap(e => Array(e, Edge(e.dstId, e.srcId, e.attr)))
  .distinct()
val myGraph = Graph(myVertices, myEdges).cache()
// collect() brings the rows to the driver, so the output appears in the driver console
// (printing inside executors would send it to executor logs on a cluster).
println(ShortestPaths.run(myGraph, Seq(3L)).vertices.collect().mkString(" | "))
println(dijkstra(myGraph, 3L).vertices.collect().mkString(" | "))
println(prime(myGraph, 3L).vertices.collect().mkString(" | "))
floydWarshall(myGraph).vertices.collect().foreach(println)
输出依次如下:
ShortestPaths:
(1,Map(3 -> 2)) | (2,Map(3 -> 1)) | (3,Map(3 -> 0)) | (4,Map(3 -> 2)) | (5,Map(3 -> 1)) | (6,Map(3 -> 2)) | (7,Map(3 -> 2))
Dijkstra:
(1,(A,15.0,2)) | (2,(B,8.0,3)) | (3,(C,0.0,-1)) | (4,(D,17.0,2)) | (5,(E,5.0,3)) | (6,(F,13.0,5)) | (7,(G,14.0,5))
Prime:
(1,(A,7.0,2)) | (2,(B,7.0,5)) | (3,(C,0.0,-1)) | (4,(D,5.0,1)) | (5,(E,5.0,3)) | (6,(F,6.0,4)) | (7,(G,9.0,5))
FloydWarshall:
(4,Map(5 -> 14.0, 1 -> 5.0, 6 -> 6.0, 2 -> 9.0, 7 -> 17.0, 3 -> 17.0, 4 -> 0.0))
(2,Map(5 -> 7.0, 1 -> 7.0, 6 -> 15.0, 2 -> 0.0, 7 -> 16.0, 3 -> 8.0, 4 -> 9.0))
(7,Map(5 -> 9.0, 1 -> 22.0, 6 -> 11.0, 2 -> 16.0, 7 -> 0.0, 3 -> 14.0, 4 -> 17.0))
(5,Map(5 -> 0.0, 1 -> 14.0, 6 -> 8.0, 2 -> 7.0, 7 -> 9.0, 3 -> 5.0, 4 -> 14.0))
(3,Map(5 -> 5.0, 1 -> 15.0, 6 -> 13.0, 2 -> 8.0, 7 -> 14.0, 3 -> 0.0, 4 -> 17.0))
(1,Map(5 -> 14.0, 1 -> 0.0, 6 -> 11.0, 2 -> 7.0, 7 -> 22.0, 3 -> 15.0, 4 -> 5.0))
(6,Map(5 -> 8.0, 1 -> 11.0, 6 -> 0.0, 2 -> 15.0, 7 -> 11.0, 3 -> 13.0, 4 -> 6.0))