使用图算法实现地铁线路规划

1、概述

最近在知乎上看到一个利用Dijkstra算法进行地铁线路规划的贴子,其思路让人受益非浅,在此也感谢作者的分享 (原贴链接
将算法运用到日常中经常接触到的事物上面,想必对知识的会有更深的理解。 本文在参考上文的基础上,加入自已的理解,利用python语言实现对武汉地铁的线路规划并记录。

2、地铁数据处理

2.1、地铁数据的获取

对地铁线路进行建模,首先要获取地铁站点数据。
上文知乎作者采用的是武汉本地宝的数据,利用BeautifulSoup库对请求的html内容进行解析得到。 本文利用高德地图平台获取。
打开高德地图后,选择地铁视图,页面上便会显示国内各个城市的地铁线路图。 在Chrome浏览器中用F12打开开发者视图后,再切换到武汉地铁选项卡,右侧network选项卡中会有请求武汉站点数据的http请求。复制一下就ok了。 链接如下:

http://map.amap.com/service/subway?_1608683463948&srhdata=4201_drw_wuhan.json

通过链接可以看到,返回的是json数据。 此种方式可以方便快捷的获取地铁的站点数据。 通过分析站点json数据,可以看到其数据中还包含了经纬度等信息,从而避免了通过站点名称获取站点经纬度信息的过程。
具体获取代码如下:

def getSubwayStationInfo():
    '''
    通过高德获取武汉地铁站点数据
    '''
    # 使用request库进行请求站点数据
    metroUrl = "http://map.amap.com/service/subway?_1608683463948&srhdata=4201_drw_wuhan.json"
    resp = requests.get(metroUrl)

    if resp.status_code != 200:
        print("请求武汉地铁站点数据失败,请重试。")
        print(resp.status_code)
        print(resp.text)
        return list()
        
	# 利用json模块将文本转化为dict对象
    metroInfo = json.loads(resp.text)   # 注意,json还有load方法, 使用报错。需使用loads方法。 那么loads与load方法的区别在哪里呢?

	# 对比json数据,将其中有用的数据取出来
    allStationInfo = list()
    for m in metroInfo['l']:
        # m['st']是站名列表
        for station in m['st']:
            line = list()
            # 几号线
            line.append(m['kn'])
            # 名称    
            line.append(station['n'])
            # 经纬度
            lo, la = station['sl'].split(",")
            # 经度
            line.append(la)
            # 纬度
            line.append(lo)

            # 每个站点信息一行
            allStationInfo.append(line)
    
    return allStationInfo
2.2、地铁站点数据的保存

由于站点信息更新频率相对会很慢,故可以尝试将站点数据保存下来,避免每次进行网络请求数据。本文利用xlwt和xlrd来读写excel进行数据的写入和读取。

具体代码如下:

import xlwt     #  写excel模块
import xlrd     #  读excel模块
def data2Excel(colName, contentList, fileName="test.xls"):
    '''
    将内容写入excel文件
    fileName 文件名
    conName 列名
    contentList, 每行均为一个list
    '''
    # 创建Excel工作薄
    myWorkbook = xlwt.Workbook()
    # 添加sheet表格
    mySheet = myWorkbook.add_sheet("sheet1")

    # 写入表头信息
    for i in range(len(colName)):
        mySheet.write(0, i, colName[i])

    # 数据行从索引1开始(表头占一行)
    rowIndex = 1
    # print(contentList)
    for d in contentList:
        # 将每一行信息的列数据写入对应的单元格
        for i in range(len(d)):
            mySheet.write(rowIndex, i, str(d[i]))

        rowIndex += 1

    # 保存
    myWorkbook.save(fileName)

def readExcel(filePath, readHead=False):
    myWorkbook = xlrd.open_workbook(filePath)
    # 根据sheet索引找到sheet
    defaultSheet = myWorkbook.sheet_by_index(0)

    # 行数和列数
    rows = defaultSheet.nrows
    cols = defaultSheet.ncols

    rowStart = 0 if readHead else 1

    res = list()
    # 第一行是表头,去掉
    for i in range(rowStart, rows):
        rowList = defaultSheet.row_values(i)
        res.append(rowList)

    return res

部份保存数据如图
站点数据

3、构建图模型

在获取到了地铁站点图后,则可以对数据进行构建图模型。为了方便计算,此处采用邻接表的方式,并且在邻接表中带上距离信息。
代码如下:

def getGraph():
    '''
    通过地铁站信息来构建地铁图
    allStationInfo 为 getSubwayStationInfo 方法返回的结果。 
    '''
    # 如果已经存在图数据,则取已有的。地铁站点较多,此方式能加快线路规划速度。 
    pickleFile = os.path.join(os.getcwd(), "metro", "metro_graph.pkl")
    if os.path.exists(pickleFile):
        with open(pickleFile, 'rb') as f:
        	# 如果存在,则直接加载图对象,类型为dict。 
            graph = pickle.load(f)
            return graph

    allStationInfo = getSubwayStationInfo()
    # 当通过某个key获取相应value时,如果不存在key,则返回一个默认的空的dict对象,避免报错及对key做相应的判断等。 
    graph = defaultdict(dict)
    for i in range(len(allStationInfo) - 1):
    	# 每次处理相邻的两个站点,如i为0时,处理 0,1两个站点。 i为1时处理 1和2两个站点,如此类推。 
        j = i + 1
        # 同一条线, 不是同一条线的不处理。 (不同线路的换乘站点,可通过key在dict中取出,从面添加到相应点的邻接表中)
        if allStationInfo[i][0] == allStationInfo[j][0]:
            aLoc = (allStationInfo[i][2:])
            bLoc = (allStationInfo[j][2:])
            # 通过经纬度计算两个站点间的距离
            dis = calcDis(aLoc, bLoc)

            aName = allStationInfo[i][1]
            bName = allStationInfo[j][1]

            # 相同站点的站名相同, 故只用处理同一条线的上的节点即可。 
            # 相当于是邻接表的方式构建图, 只不过邻接表中用的是list,而此处用的是dict.
            graph[aName][bName] = dis
            graph[bName][aName] = dis

    # pprint(graph)
    with open('metro_graph.pkl','wb') as f:
        # 保存到文件
        pickle.dump(graph, f)

    return graph

最终构建出来的图如下:
站点graph
从上图中可以看到与汉口北相邻的站点为滠口新城,与黄浦路相邻的站点则有三阳路,头道街,徐家棚,赵家条。
有了图模型后,我们便可以利用相应的算法来规划路径了。

2、BFS方法

我们熟知,bfs经常用来搜索最短路径,但这里的最短路径指的经过的顶点数最少。 此处,如果我们也只考虑经过最少的站点数最少,那么可以使用bfs来规划最短路径。
代码如下:

def getBFSPath(src, target):
    '''
    从src出发,到达目地的target的最短路径。 (这里的最短路径指的是通过的站点数最少)
    '''
    graph = getGraph()

    queue = list()
    visited = list()
    # 用来输出相应的路径
    edgeTo = dict()
    queue.append(src)
    visited.append(src)

    while queue:
        node = queue.pop(0)   # 弹出队首元素, 这里需要注意的是pop()是弹出队尾元素。
        # 所有相邻的点, adjNode 是dict类型,数据如 {'广埠屯': 1.1260840952908053, '宝通寺': 1.251311885982076}
        adjNode = graph[node]

        # 一次性遍历邻接的所有节点
        for n in adjNode:
            if not n in visited:
                visited.append(n)
                edgeTo[n] = node   # 从node节点可以到n节点,在搜索到结果后可以通过终点倒推到起点
                queue.append(n)

                # 如果找到可以停止
                if n == target:
                    # 同时退出while循环(下面的break只能退出for循环)
                    queue.clear()
                    break

	# 输出路径
    path = list()
    dest = target
    while dest != src:
        path.append(dest)
        dest = edgeTo[dest]
    path.append(src)
    # 这样计算出来的路径是倒过来的,也即从目的地倒推到出发点。 
    path.reverse()

    print("从 【{}】 到 【{}】 的最短路径(经过的站点数最少)如下:(使用广度优先遍历)".format(src, target))
    print(path)

输入起点螃蟹岬,终点梅苑小区测试如下:
在这里插入图片描述
对比地图查看,使用如上线路走过的站点数确实最少。

3、DFS方法

dfs方法一般用于搜索点是否可达,此处可用dfs搜索输出所有的路径,以及通过对比输出站点数最少路径。

3.1 、输出所有可行路径
def getAllPathByDFS(src, target):
    '''
    通过dfs获取所有可行的路径
    '''
    MAX_DIS = 10000000
    graph = getGraph()

    # 建立站名与数字索引的映射
    # 名称 -> 索引
    nameIndexGraph = dict()
    index = 0
    for k in graph:
        nameIndexGraph[k] = index
        index += 1

    # 索引 -> 名称
    indexNameGraph = dict()
    index = 0
    for k in graph:
        indexNameGraph[index] = k
        index += 1
    # 以上建立一个双向索引, 主要是方便通过索引计算路径,然后通过索引找到名称输出

    # 节点个数
    nodeNum = len(graph)
    # 节点是否已访问
    visited = [False for x in range(nodeNum)]
    path = list()
    minPathLen = MAX_DIS
    ans = list()

    def dfs(node):
        '''
        node是字符类, 表示站名
        '''
        if node == target:
        	# 找到一个可行路径
            print(path)
            return 

        nodeIndex = nameIndexGraph[node]
        adjNode = graph[node]
        for n in adjNode:
            nIndex = nameIndexGraph[n]
            if not visited[nIndex]:
                visited[nIndex] = True
                path.append(n)
                dfs(n)
                # 撤销访问
                path.remove(n)
                visited[nIndex] = False 


    dfs(src)
    # 由于地铁中有回路,故使用dfs搜索所有的路径会有大量的结果,此方法也较耗时。 
3.2、输出最短路径(站点数最少)

由于上面的输出所有的路径在实际中并不常用,故对上面的算法改进,只求经过最少站点路径。代码如下:

def getDFSShortestPath(src, target):
    '''
    通过dfs获取所有最短的路径, 本质是搜索所有的路径,然后取最短的。 
    '''
    MAX_DIS = 10000000
    graph = getGraph()

    # 建立站名与数字索引的映射
    # 名称 -> 索引
    nameIndexGraph = dict()
    index = 0
    for k in graph:
        nameIndexGraph[k] = index
        index += 1

    # 索引 -> 名称
    indexNameGraph = dict()
    index = 0
    for k in graph:
        indexNameGraph[index] = k
        index += 1
    # 以上建立一个双向索引

    # 节点个数
    nodeNum = len(graph)
    # 节点是否已访问
    visited = [False for x in range(nodeNum)]
    path = list()
    minPathLen = MAX_DIS
    ans = list()

    def dfs(node):
        '''
        node是字符类, 表示站名
        '''
        nonlocal minPathLen
        nonlocal ans
        if node == target and len(path) < minPathLen:
            minPathLen = len(path)
            # 返回一个新列表,直接赋值无法传到外层
            ans = list(path)
            return 

        nodeIndex = nameIndexGraph[node]
        adjNode = graph[node]
        for n in adjNode:
            nIndex = nameIndexGraph[n]
            if not visited[nIndex]:
                visited[nIndex] = True
                path.append(n)
                dfs(n)
                # 撤销访问
                path.remove(n)
                visited[nIndex] = False 


    dfs(src)
    ans.insert(0, src)
    print("从 【{}】 到 【{}】 的最短路径(经过的站点数最少)如下:(使用深度优先遍历)".format(src, target))
    print(ans)

此方法本质还是对所有搜索结果进行筛选,并没有实质的改进。故依然很耗时。
result
通过对比发现,依然可以得到最短的路径。

4、带权图的路径规划

以上练习只考虑了站点数,而没有将实际的距离考虑在内。 如果将站点间的距离考虑在内的话,则会将问题抽象为求带权无向图的最短路径问题。 此类问题的解法一般有弗洛伊德算法和Dijkstra算法。

4.1、弗洛伊德算法

关于弗洛伊德算法,主要思想是从点 i 到 j,有没有一个点 k 使得 i 到 j 的距离更短,然后不断的穷举点k,从而得到点 i 到 点 j 的最短距离。 其次,弗洛伊德算法不能计算有负权回路的最短路径(如果有负权回路的话,每走一次,距离会减小,故永远找不到最小)。
使用弗洛伊德算法计算代码如下:

def getFloydPath(src, target):
    '''
    弗洛伊德求多源最短路径
    '''
    MAX_DIS = 10000000
    graph = getGraph()
    # pprint(graph)
    nodeNum = len(graph)

    # 建立站名与数字索引的映射
    # 名称 -> 索引
    nameIndexGraph = dict()
    index = 0
    for k in graph:
        nameIndexGraph[k] = index
        index += 1

    # 索引 -> 名称
    indexNameGraph = dict()
    index = 0
    for k in graph:
        indexNameGraph[index] = k
        index += 1
    # 以上建立一个双向索引

    # 距离列表 及 初始化
    # 不同站点的距离初始化为最大,相同的站点为0
    dis = [[MAX_DIS if x != y else 0 for x in range(nodeNum)] for y in range(nodeNum)]
    # 根据图读入已知数据
    for k in graph:
        stationIndex = nameIndexGraph[k]
        for m in graph[k]:
            tmpIndex = nameIndexGraph[m]
            # 双向距离是一样的
            dis[stationIndex][tmpIndex] = graph[k][m]
            dis[tmpIndex][stationIndex] = graph[k][m]

    # 存储路径
    edgeTo = dict()
    path = [[-1 for x in range(nodeNum)] for y in range(nodeNum)]
    for i in range(nodeNum):
        for j in range(nodeNum):
            # 意义:从点i到j的最短路径要经过的点为i (初始化为j也是可以的,但是后面的赋值及路径就要修改一下了)
            path[i][j] = i


    # 弗洛伊德算法核心, 中间枚举点k要放在最外层。切记。
    # 从 i 点到 j 点,经过 k 点最近
    for k in range(nodeNum):
        for i in range(nodeNum):
            for j in range(nodeNum):
                # print("======={}, {}".format(dis[i][j], dis[i][k] + dis[k][j]))
                if dis[i][j] > (dis[i][k] + dis[k][j]):
                    dis[i][j] = dis[i][k] + dis[k][j]
                    path[i][j] = path[k][j]
                    
    
    srcStationIndex = nameIndexGraph[src]
    targetStationIndex = nameIndexGraph[target]
    minLenPath = dis[srcStationIndex][targetStationIndex]
    pathIndex = list()
    while targetStationIndex != srcStationIndex:
        pathIndex.append(targetStationIndex)
        targetStationIndex = path[srcStationIndex][targetStationIndex]

    pathIndex.append(srcStationIndex)
    pathIndex.reverse()

    pathName = [indexNameGraph[i] for i in pathIndex]
    print("从 【{}】 到 【{}】 的最短路径(站间距总和最小)为 {}, 路径如下:(使用Floyd算法)".format(src, target, minLenPath))
    print(pathName)

运行结果如下:
floyd算法result
此处计算的距离之和均为两个站点的直线距离,而不是地铁线路走过的实际距离。
在知乎上看到一个笑话,为啥Dijkstra发明不了弗洛伊德算法,因为Dijkstra名字中是ijk,而不是kij。通过这个可以辅助记忆k是最外层循环。
弗洛伊德算法思想比较简单,但是由于用了三层循环,时间复杂度比较高。在顶点比较多的情况下,则可以使用下面的Dijkstra算法。

4.2、Dijkstra算法

Dijkstra算法可以求解单源最短路径,即只能求某个点到其他各点的最短距离,而不能像弗洛伊德算法一样,求任意两点间的最短距离。
关于Dijkstra算法的解释参与此篇博客,写的非常好。
通过Dijkstra算法计算线路的代码如下:

def getDijkstraPath(src, target):
    '''
    dijkstra算法求单源最短路径, 参考 https://blog.csdn.net/heroacool/article/details/51014824
    '''
    MAX_DIS = 10000000
    graph = getGraph()
    nodeNum = len(graph)

    # 建立站名与数字索引的映射
    # 名称 -> 索引
    nameIndexGraph = dict()
    index = 0
    for k in graph:
        nameIndexGraph[k] = index
        index += 1

    # 索引 -> 名称
    indexNameGraph = dict()
    index = 0
    for k in graph:
        indexNameGraph[index] = k
        index += 1
    # 以上建立一个双向索引

    # 起始点标号
    srcIndex = nameIndexGraph[src]
    # 起始点到其余点的最短距离是否已获取
    flag = [False for x in range(nodeNum)]
    # 起始点到其余点的距离
    dist = [MAX_DIS for x in range(nodeNum)]

    # 根据已知数据初始化
    for k in graph:
        stationIndex = nameIndexGraph[k]
        if stationIndex == srcIndex:
            for m in graph[k]:
                tmpIndex = nameIndexGraph[m]
                # 初始点到tmpIndex的距离为dist[tmpIndex]
                dist[tmpIndex] = graph[k][m]

    # 对顶点自己初始化
    flag[srcIndex] = True
    dist[srcIndex] = 0

    # 记录前驱节点,用来还原路径
    prev = [-1 for x in range(nodeNum)]

    # 外层只要nodeNum - 1次即可。 (每次求到一个点最短距离,只需nodeNum - 1次即可)
    for i in range(1, nodeNum):
        
        minDis = MAX_DIS
        for j in range(nodeNum):
            # 找最小路径
            # 即在未获取最短路径的点中,找到离起始点最近的点
            if not flag[j] and dist[j] < minDis:
                minDis = dist[j]
                k = j

        # 到顶点k的最短距离已经获取
        flag[k] = True

        for j in range(nodeNum):
            # 已经获取到 j 点的最小距离,则跳过
            if flag[j]:
                continue

            # 从 k 点到 j 点的距离
            if k != j:
                kName = indexNameGraph[k]
                jName = indexNameGraph[j]
                # k 与 j之间的距离
                if jName in graph[kName]:
                    disKJ = graph[kName][jName]

                    if (minDis + disKJ) < dist[j]:
                        dist[j] = minDis + disKJ
                        # 点 j 的前驱节点是k
                        prev[j] = k
                else:
                    # 不存在则为无穷大
                    disKJ = MAX_DIS
                

    targetIndex = nameIndexGraph[target]
    pathIndex = list()
    
    t = targetIndex
    while t != -1:
        pathIndex.append(t)
        t = prev[t]
    pathIndex.append(srcIndex)

    pathIndex.reverse()
    pathName = [indexNameGraph[i] for i in pathIndex]
    print("从 【{}】 到 【{}】 的最短路径(站间距总和最小)为 {} ,路径如下:(使用Dijkstra算法)".format(src, target, dist[targetIndex]))
    print(pathName)

计算结果如下:
在这里插入图片描述
从上图可以看出,计算结果和弗洛伊德算法一样。 但在实际运行过程中可以发现,Dijkstra算法比弗洛伊德算法快多了。
由于Dijkstra算法在计算过程中需要在剩余未访问的点中求解最短的距离,故可以使用优先队列来优化时间复杂度。

5、完整线路规划

以上算法只实现了具体的地铁站点间的规划,如果起始点是任意地点而不是地铁站,那么需要找出距离最近的地铁站。
此处采用的方法比较暴力,即先通过高德api求出起始点的经纬度,然后遍历地铁站,找出最近的地铁站,从而使用上面的算法找最优线路规划。
获取最近地铁站代码如下:

def getNearSubWay(keyword, city):
    '''
    获取当前位置最近的地铁站
    '''
    res = list()
    if len(res) == 0:
        retryCount = 3
        while retryCount > 0:
            res = getLocation(keyword, city)
            if (len(res) != 0):
                break
            retryCount -= 1
    
    print("当前位置 {} - {} 的经纬度是:{}-{}".format(city, keyword, res[0], res[1]))

    if (len(res)) == 0:
        print('重试3次了,请求失败,return')
        return ""
    else:
        # 从excel中读取站点经纬度信息,遍历查找最近的站点信息
        subwayExcelFilePath = os.path.join(os.getcwd(), "metro", "subway.xls")
        stationInfo = readExcel(subwayExcelFilePath)
        minDis = sys.maxsize
        for info in stationInfo:
            tmp = calcDis((info[2], info[3]), (res[1], res[0]))
            if tmp < minDis:
                minDis = tmp
                minStation = info

        print("距离最近的地铁站是 【{}】, 距离 {} km".format(minStation[1], minDis))
        return minStation[1]

其中,getLocation方法中使用了高德地图的api,代码如下:

def getLocation(keyword, city):
    # 从高德平台注册后获取的
    keynum = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
    #获得经纬度
    user_agent = 'Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10_6_8; en-us) AppleWebKit/534.50 (KHTML, like Gecko) Version/5.1 Safari/534.50'
    headers = {'User-Agent': user_agent}
    # url中的参数可以参看高德平台说明
    url = 'http://restapi.amap.com/v3/place/text?key=' + keynum + '&keywords=' + keyword + '&types=&city=' + city + '&children=1&offset=1&page=1&extensions=all'
    resp = requests.get(url, headers=headers)
    if resp.status_code == 200:
        resp.encoding='utf-8'
        resp = json.loads(resp.text)
        result = resp['pois'][0]['location'].split(',')
        return result[0], result[1]
    else:
        return list()

距离的计算使用的是geopy库的方法,如下:

def calcDis(a, b):
    '''
    计算两点间的距离,a, b均是包含经纬度的元组 (纬度,经度)
    '''
    if len(a) != 2 or len(b) != 2:
        print("请检查参数是否包含完整的经纬度信息")
        return 0
    return geodesic(a, b).km

输入测试数据后,运行结果如下:
在这里插入图片描述
得到了站点信息后,使用上面的广度搜索、深度搜索 或 弗洛伊德算法、Dijkstra算法便可以得到相应的线路规划。

6、总结

本文使用高德数据来构建图模型,从而练习图有关的算法。
首先使用了图中常用的两种搜索算法,广度优先搜索算法、深度优先搜索算法。这两者是主要区别是使用的数据结构不同,广度优先搜索使用的是队列,而深度优先搜索使用的栈。
其次,使用了两种最短路径算法。一种是多源最短路径算法(弗洛伊德算法), 此算法可以求得任意两点间的最短距离。 (此种算法要注意三层循环最外层是k,也即中间节点)。一种是单源最短路径算法(Dijkstra算法),此算法可求起始点到其余各点的最短距离。

  • 6
    点赞
  • 68
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值