之前文章能够求出dijkstra最短距离,但是速度太慢,原因之一是算法把整个二维数组的点到起始点的最短距离都求出来了,但其实没有必要,我们只需要起始点到结束点的最短路径即可。
优化思路:dijkstra算法执行时,若当前点的坐标等于结束点坐标时就停止算法,这样可以避免寻找其他无用点而浪费时间
程序59,60,61行代码可自行修改
#启发函数中的归一化需要根据具体情况修改
import numpy as np
import time
import random
import math
def shuzu(w,h):
nnum = [[random.randint(0, 10) for y in range(h)] for x in range(w)]
nnum = np.array(nnum) / 10
return nnum
def get_neighbors(p):
x,y=p#当前点坐标
#3*3领域范围
x_left=0 if x==0 else x-1
x_right=w if x==w-1 else x+2
y_top=0 if y==0 else y-1
y_bottom= h if y==h-1 else y+2
return [(x,y) for x in range(x_left,x_right) for y in range(y_top,y_bottom)]#范围3*3领域9个点坐标
def neight_cost(p,next_p):
return abs(nnum[next_p[0]][next_p[1]]-nnum[p[0]][p[1]])
def a_star(nnum, seed,end):
process = set() # 已处理点的集合,集合不能重复
cost = {seed: 0.0} # 当前点路径积累的成本代价值
path = {} # 路径
while cost: # cost为空代表所有点都处理了,每个点处理了其对应的cost值会被删掉
p = min(cost, key=cost.get) # 每次取出当前成本代价最小值
neighbors = get_neighbors(p) # 当前成本代价最小值的领域节点
process.add(p) # 保存已处理过的点
for next_p in [x for x in neighbors if x not in process]: # 没有被处理过的领域点坐标
dik_cost= neight_cost(p, next_p) + cost[p]# 当前点与领域的点cost的差值 + 起始点到到当前点累计的cost值
if next_p in cost: # 如果该领域点之前计算过了,则需要判断此时所用的代价小还是之前的代价小,如果现在的代价小则需要更新
if dik_cost < cost[next_p]: # 小的话,把之前记录的代价值去除掉。为了之后的更新
cost.pop(next_p)
else: # 该领域点之前没有计算过 或者 需要更新
cost[next_p] = dik_cost # 该领域所需代价值的更新
process.add(next_p) # 添加到已处理过的点
path[next_p] = p # 把cost最小点作为领域点next_p的前一个点
if (next_p==end):#当前点到达结束点时,提前结束
cost={}#为了跳出循环
cost[p]=0#为了跳出循环
break#为了跳出循环
cost.pop(p) # 已经处理了的点就排除
return path
def small_path_point(seed,end,paths):
path_piont=[]
path_piont.insert(0,end)#把结束点加到路径中
while seed!=end:#直到结束点坐标等于开始点坐标是结束
top_point=paths[end]#更新的top_point为最短路径中某个点的上一个坐标点,即更加靠近种子点
path_piont.append(top_point)#记录路径
end=top_point#更新点坐标
return path_piont
if __name__=='__main__':
start = time.time()
global nnum,w,h
nnum=shuzu(300,400)#创建二维数组
seed=(0,0)#起始点
end=(200,200)#结束点
h=nnum.shape[1]#高
w = nnum.shape[0]#宽
print('地图\n',nnum)#显示地图
print('起始点',seed)
paths = a_star(nnum, seed,end)
print('dijkstra所有路径', paths)
path_piont = small_path_point(seed, end, paths) # 开始点到结束点的最短路径所经过的坐标点
print('起始点:', seed, '到结束点:', end, '的dijkstra最短路径为:', path_piont)
print('一共走了%d步' % len(path_piont))
all_leng = 0
for i in range(len(path_piont) - 1):
leng = nnum[path_piont[i]] + nnum[path_piont[i + 1]]
all_leng = leng + all_leng
print('权重:', all_leng)
end = time.time()
print('总共耗时:', end - start)