Dijkstra算法是经典的用来求解最短通路的算法,笔者刚好在离散数学中学到了这个知识点,顺着思路写了一下代码,供大家参考借鉴
首先是把“路”定义出来,这里我用了一个class类
class Road():
def __init__(self,val=0,pre=0,behind=0):
self.val=val
self.pre=pre
self.behind=behind
pre是路的前端,behind是路的后端,val是路的权
这样我们就可以把输入都转化为路了:
input_=input()
list=input_.split(' ')
n,m,s,t=int(list[0]),int(list[1]),int(list[2]),int(list[3])#第一行输入的是节点数,边数,开始和终点
node_list=[] #所有路的集合
for i in range(m):
word=input()
list=word.split(' ')
node=Road()
node.pre,node.behind,node.val=int(list[0]),int(list[1]),int(list[2])
node_list.append(node)
然后我们准备开始寻找最短路径
这里要准备一个stack,用于存放目前为止能够一步走的所有的路的集合
例如在这个图中,我们的stack就用来存放v-a,v-b,v-c这三条路,这样的话我们就可以直接找到一步走的最短路径:
stack = []
for item in node_list:
if item.pre == s:
stack.append(item)
然后进入找最短通路的部分
根据Dijkstra算法,每次只需要寻找能够一步走的最短路径,也就是我们的stack里面的最短路径,然后把走到的节点和原来的节点看作一个节点继续寻找,直到走到终点。
def findroad(stack,node_list):
min_= stack[0].val
behind_= stack[0].behind
node_= stack[0]
for item in stack:
if min_ > item.val:
min_ = item.val #找到最短路
behind_ = item.behind
node_=item
stack.remove(node_)
road.append(node_) #加入到走过的路的集合中
new_stack = stack
if behind_ == t:
return min_
for item in node_list:
if item.pre==behind_:
tmp_node = Road()
tmp_node.val = item.val + min_
tmp_node.pre = item.pre
tmp_node.behind = item.behind
new_stack.append(tmp_node) #加入到最新的一步走里面
return findroad(new_stack,node_list)
每当找到最短路,就把它从我们的stack里删除,并加入到我们所有已经走过的路的集合中(road),同时,在所有的node_list中寻找新的一步路径,也就是从我们新走到的这个节点后还能够一步走到哪里,加入stack中,二者合并为new_stack传入下一个函数迭代,注意新的节点需要把权改变,也就是val+min_。
上图是我们迭代一次后的结果,可以看到每次我们的stack都是能够一步走的路,并且road记录了我们所有走过的路。
当迭代结束,我们返回最终的val,也就是最短路径
然后我们只需要在road里倒推回起点就可以找到我们经过的最短路径的节点了:
road=[]
min_ = findroad(stack,node_list)
res = t
ans = [t]
while res!= s:
for item in road:
if item.behind == res:
ans.append(item.pre)
res = item.pre
break
然后我们把结果倒序输出就完成了
for item in ans[::-1]:
print(item,end=' ')
print('')
print(min_)
最终代码:
class Road():
def __init__(self,val=0,pre=0,behind=0,flag=0):
self.val=val
self.pre=pre
self.behind=behind
self.flag=flag
def findroad(stack,node_list):
min_= stack[0].val
behind_= stack[0].behind
node_= stack[0]
for item in stack:
if min_ > item.val:
min_ = item.val #找到最短路
behind_ = item.behind
node_=item
stack.remove(node_)
road.append(node_)
new_stack = stack
if behind_ == t:
return min_
for item in node_list:
if item.pre==behind_:
tmp_node = Road()
tmp_node.val = item.val + min_
tmp_node.pre = item.pre
tmp_node.behind = item.behind
new_stack.append(tmp_node) #加入到最新的一步走里面
return findroad(new_stack,node_list)
input_=input()
list=input_.split(' ')
n,m,s,t=int(list[0]),int(list[1]),int(list[2]),int(list[3])#第一行输入的是节点数,边数,开始和终点
node_list=[] #所有路的集合
for i in range(m):
word=input()
list=word.split(' ')
node=Road()
node.pre,node.behind,node.val=int(list[0]),int(list[1]),int(list[2])
node_list.append(node)
stack = []
for item in node_list:
if item.pre == s:
stack.append(item)
road=[]
min_ = findroad(stack,node_list)
res = t
ans = [t]
while res!= s:
for item in road:
if item.behind == res:
ans.append(item.pre)
res = item.pre
break
for item in ans[::-1]:
print(item,end=' ')
print('')
print(min_)
代码略显冗杂,也请大家多多批评指正!