[python刷题模板] 最短路(Dijkstra/SPFA/Johnson
一、 算法&数据结构
1. 描述
本文打一个最短路模板。
Dijkstra是nlogn+m处理无负权边情况
SPFA用n*m处理带负权边情况。
Johnson用n*mlogn处理全源最短路。(优于floyd)
2. 复杂度分析
- Dijkstra堆优化O(nlogn)。
- spfa O(nm)。
- Johnson O(mNlgn)
3. 常见应用
- dijkstra查询单源路径最短路,注意不能有负权边。
- spfa可以用来判断是否存在负环。
4. 常用优化
- 优先队列(堆)优化。这里注意没写vis已访问字典,由于最近的一定先访问,因此可以用dist>来代替。
二、 模板代码
1. Dijkstra洛谷模板题P4779 ,存在重边的单元最短路。
- 写了两个模板,一个用字典形式存图,目的是处理节点非编号的题目(比如矩阵搜索的节点用坐标表示)
- 字典形式被重边折磨了,会覆盖,因此要注意处理。
- 还有注意本题是有向图。
import collections
import sys
from collections import *
from contextlib import redirect_stdout
from itertools import *
from math import sqrt, inf
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from bisect import *
RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda x: sys.stderr.write(f'{str(x)}\n')
class Dijkstra:
"""
堆优化版dijkstra单源最短路,求从start到所有点的最短路,边权不能为负;时间复杂度nlogn
其中INF可以自己定义,用来表示无法到达,默认是inf
"""
def __init__(self, g, start, n=None, INF=None):
self.n = len(g) if n is None else n
self.g = g
self.start = start
self.INF = INF if INF is not None else inf
def dist_by_list_g_0_indexed(self):
"""
基于g是从0~n-1表示节点的图
:return: 距离数组代表start点到每个点的最短路
"""
dis, g = [self.INF] * self.n, self.g # 初始化距离数据为全inf
dis[self.start] = 0 # 源到自己距离0
q = [(0, self.start)] # 优先队列
while q:
c, u = heapq.heappop(q) # 当前点的最短路
if c > dis[u]: continue # 这步巨量优化很重要:u可以从上一层多个点转移而来,队列中将同时存在多个u的情况,但只有c最小的那个有意义,其他跳过。
for v, w in g[u]: # 用u松弛它的邻居
d = c + w
if d < dis[v]: # 可以松弛
dis[v] = d
heapq.heappush(q, (d, v))
return dis # 距离数组
def dist_by_dict_g(self):
"""
基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
:return: 距离字典:{u:dist},注意,如果不在这个字典里,则不可达;外侧查询的话可能需要dis.get(u,inf);其实可以用defaultdict来存,但是效率低一些。
"""
dis, g = {self.start: 0}, self.g # 初始化距离数组
q = [(0, self.start)]
while q:
c, u = heapq.heappop(q)
if c > dis.get(u, inf): continue
for v, w in g[u].items():
d = c + w
if d < dis.get(v, inf):
dis[v] = d
heapq.heappush(q, (d, v))
return dis
def dist_by_default_dict_g(self):
"""优先用defaultdict版本,卡性能再考虑这个
基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
:return: 距离字典:defaultdict({u:dist}),注意,速度比上个函数慢一点,但是不易出错,因为可以初始化不可达为inf
"""
dis, g = defaultdict(lambda: self.INF), self.g # 初始化距离字典
dis[self.start] = 0
q = [(0, self.start)]
while q:
c, u = heapq.heappop(q)
if c > dis[u]: continue
for v, w in g[u].items():
d = c + w
if d < dis[v]:
dis[v] = d
heapq.heappush(q, (d, v))
return dis
def dist(self):
"""根据g的类型自动判断是不是用下标代替节点,速度快一点"""
if isinstance(self.g, list):
return self.dist_by_list_g_0_indexed()
return self.dist_by_default_dict_g()
def main():
n, m, s = RI()
g = collections.defaultdict(dict)
for _ in range(m):
u, v, w = RI()
a, b = u - 1, v - 1
if b in g[a]: # 在这wa了很多次,字典形式的图要处理重边
g[a][b] = min(g[a][b], w)
else:
g[a][b] = w
dis = Dijkstra(g, s - 1).dist()
ans = [0] * n
for i in range(n):
ans[i] = dis[i]
print(*ans)
# def main():
# n, m, s = RI()
# g = [[] for _ in range(n)]
# for _ in range(m):
# u, v, w = RI()
# g[u - 1].append((v - 1, w))
#
# dis = Dijkstra(g, s - 1).dist()
# print(*dis)
if __name__ == '__main__':
# testcase 2个字段分别是input和output
test_cases = (
(
"""4 6 1
1 2 2
2 3 2
2 4 1
1 3 5
3 4 3
1 4 4""", """0 2 4 3
"""
),
)
if os.path.exists('test.test'):
total_result = 'ok!'
for i, (in_data, result) in enumerate(test_cases):
result = result.strip()
with io.StringIO(in_data.strip()) as buf_in:
RI = lambda: map(int, buf_in.readline().split())
RS = lambda: buf_in.readline().strip().split()
with io.StringIO() as buf_out, redirect_stdout(buf_out):
main()
output = buf_out.getvalue().strip()
if output == result:
print(f'case{i}, result={result}, output={output}, ---ok!')
else:
print(f'case{i}, result={result}, output={output}, ---WA!---WA!---WA!')
total_result = '---WA!---WA!---WA!'
print('\n', total_result)
else:
main()
2. 转化为单源最短路
链接: 882. 细分图中的可到达节点
- 把cnt+1看做边权,建图
- 先求出每个点的最短路,然后ans累计在maxMoves步数下能到达的节点数,这样就把原式节点算完了。
- 然后处理每条边上的cnt,对于(u,v,cnt):
- 如果u当前剩余步数a = maxMoves-dist[u]>0,则可以从u向v访问a个节点,当然不能超过cnt
- 同理处理v,如果v当前剩余步数b = maxMoves-dist[v]>0,则可以从v向u访问b个节点,当然不能超过cnt。
- 最后,a+b也不可以超过cnt,即对于这条边:ans+=max(a+b,cnt)
class Dijkstra:
"""
堆优化版dijkstra单源最短路,求从start到所有点的最短路,边权不能为负;时间复杂度nlogn
"""
def __init__(self, g, start, n=None):
self.n = len(g) if n is None else n
self.g = g
self.start = start
def dist_by_list_g_0_indexed(self):
"""
基于g是从0~n-1表示节点的图
:return: 距离数组代表start点到每个点的最短路
"""
dis, g = [inf] * self.n, self.g # 初始化距离数据为全inf
dis[self.start] = 0 # 源到自己距离0
q = [(0, self.start)] # 优先队列
while q:
c, u = heapq.heappop(q) # 当前点的最短路
if c > dis[u]: continue # 这步巨量优化很重要:u可以从上一层多个点转移而来,队列中将同时存在多个u的情况,但只有c最小的那个有意义,其他跳过。
for v, w in g[u]: # 用u松弛它的邻居
d = c + w
if d < dis[v]: # 可以松弛
dis[v] = d
heapq.heappush(q, (d, v))
return dis # 距离数组
def dist_by_dict_g(self):
"""
基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
:return: 距离字典:{u:dist},注意,如果不在这个字典里,则不可达;外侧查询的话可能需要dis.get(u,inf);其实可以用defaultdict来存,但是效率低一些。
"""
dis, g = {self.start: 0}, self.g # 初始化距离数组
q = [(0, self.start)]
while q:
c, u = heapq.heappop(q)
if c > dis.get(u, inf): continue
for v, w in g[u].items():
d = c + w
if d < dis.get(v, inf):
dis[v] = d
heapq.heappush(q, (d, v))
return dis
def dist_by_default_dict_g(self):
"""优先用defaultdict版本,卡性能再考虑这个
基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
:return: 距离字典:defaultdict({u:dist}),注意,速度比上个函数慢一点,但是不易出错,因为可以初始化不可达为inf
"""
dis, g = defaultdict(lambda: self.INF), self.g # 初始化距离字典
dis[self.start] = 0
q = [(0, self.start)]
while q:
c, u = heapq.heappop(q)
if c > dis[u]: continue
for v, w in g[u].items():
d = c + w
if d < dis[v]:
dis[v] = d
heapq.heappush(q, (d, v))
return dis
def dist(self):
"""根据g的类型自动判断是不是用下标代替节点,速度快一点"""
if isinstance(self.g, list):
return self.dist_by_list_g_0_indexed()
return self.dist_by_default_dict_g()
class Solution:
def reachableNodes(self, edges: List[List[int]], maxMoves: int, n: int) -> int:
g = [[] for _ in range(n)]
for u,v,cnt in edges:
g[u].append((v,cnt+1))
g[v].append((u,cnt+1))
dist = Dijkstra(g,0).dist()
ans = sum(dist[x]<= maxMoves for x in range(n))
for u, v, cnt in edges:
a = max(maxMoves - dist[u],0)
b = max(maxMoves - dist[v],0)
ans += min(cnt,a+b)
return ans
3.矩阵搜索
这题用0-1bfs更快
class Dijkstra:
"""
堆优化版dijkstra单源最短路,求从start到所有点的最短路,边权不能为负;时间复杂度nlogn
"""
def __init__(self, g, start, n=None):
self.n = len(g) if n is None else n
self.g = g
self.start = start
def dist_by_list_g_0_indexed(self):
"""
基于g是从0~n-1表示节点的图
:return: 距离数组代表start点到每个点的最短路
"""
dis, g = [inf] * self.n, self.g # 初始化距离数据为全inf
dis[self.start] = 0 # 源到自己距离0
q = [(0, self.start)] # 优先队列
while q:
c, u = heapq.heappop(q) # 当前点的最短路
if c > dis[u]: continue # 这步巨量优化很重要:u可以从上一层多个点转移而来,队列中将同时存在多个u的情况,但只有c最小的那个有意义,其他跳过。
for v, w in g[u]: # 用u松弛它的邻居
d = c + w
if d < dis[v]: # 可以松弛
dis[v] = d
heapq.heappush(q, (d, v))
return dis # 距离数组
def dist_by_dict_g(self):
"""
基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
:return: 距离字典:{u:dist},注意,如果不在这个字典里,则不可达;外侧查询的话可能需要dis.get(u,inf);其实可以用defaultdict来存,但是效率低一些。
"""
dis, g = {self.start: 0}, self.g # 初始化距离数组
q = [(0, self.start)]
while q:
c, u = heapq.heappop(q)
if c > dis.get(u, inf): continue
for v, w in g[u].items():
d = c + w
if d < dis.get(v, inf):
dis[v] = d
heapq.heappush(q, (d, v))
return dis
def dist_by_default_dict_g(self):
"""优先用defaultdict版本,卡性能再考虑这个
基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
:return: 距离字典:defaultdict({u:dist}),注意,速度比上个函数慢一点,但是不易出错,因为可以初始化不可达为inf
"""
dis, g = defaultdict(lambda: self.INF), self.g # 初始化距离字典
dis[self.start] = 0
q = [(0, self.start)]
while q:
c, u = heapq.heappop(q)
if c > dis[u]: continue
for v, w in g[u].items():
d = c + w
if d < dis[v]:
dis[v] = d
heapq.heappush(q, (d, v))
return dis
def dist(self):
"""根据g的类型自动判断是不是用下标代替节点,速度快一点"""
if isinstance(self.g, list):
return self.dist_by_list_g_0_indexed()
return self.dist_by_default_dict_g()
"""
class Solution:
def minimumObstacles(self, grid: List[List[int]]) -> int:
m,n = len(grid),len(grid[0])
g = defaultdict(dict)
for i in range(m):
for j in range(n):
for a,b in (i+1,j),(i,j+1):
if 0<=a<m and 0<=b<n:
g[i,j][a,b] = grid[a][b]
g[a,b][i,j] = grid[i][j]
dist = Dijkstra(g,(0,0)).dist()
return dist[m-1,n-1]
4.SPFA洛谷模板题
这题没给负权,数据小可以SPFA
import collections
import sys
from collections import *
from contextlib import redirect_stdout
from itertools import *
from math import sqrt, inf
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from bisect import *
RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda x: sys.stderr.write(f'{str(x)}\n')
class Spfa:
"""
单源最短路,支持负权,复杂度O(m*n)
"""
def __init__(self, g, start, n=None, INF=None):
self.n = len(g) if n is None else n
self.g = g
self.start = start
self.INF = INF if INF is not None else inf
def dist_by_list_g_0_indexed(self):
"""
基于g是从0~n-1表示节点的图
:return: 距离数组代表start点到每个点的最短路
"""
dis, g = [self.INF] * self.n, self.g # 初始化距离数据为全inf
dis[self.start] = 0 # 源到自己距离0
q = deque([(0, self.start)])
while q:
c, u = q.popleft() # 当前点的最短路
if c > dis[u]: continue
for v, w in g[u]: # 用u松弛它的邻居
d = c + w
if d < dis[v]: # 可以松弛
dis[v] = d
q.append((d, v))
return dis # 距离数组
def dist_by_dict_g(self):
"""优先用defaultdict版本,卡性能再考虑这个
基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
:return: 距离字典:{u:dist},注意,如果不在这个字典里,则不可达;外侧查询的话可能需要dis.get(u,inf);其实可以用defaultdict来存,但是效率低一些。
"""
dis, g = {self.start: 0}, self.g # 初始化距离字典
INF = self.INF
q = deque([(0, self.start)])
while q:
c, u = q.popleft()
if c > dis.get(u, INF): continue
for v, w in g[u].items():
d = c + w
if d < dis.get(v, INF):
dis[v] = d
q.append((d, v))
return dis
def dist_by_default_dict_g(self):
"""
基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
:return: 距离字典:defaultdict({u:dist}),注意,速度比上个函数慢一点,但是不易出错,因为可以初始化不可达为inf
"""
dis, g = defaultdict(lambda: self.INF), self.g # 初始化距离字典
dis[self.start] = 0
q = deque([(0, self.start)])
while q:
c, u = q.popleft()
if c > dis[u]: continue
for v, w in g[u].items():
d = c + w
if d < dis[v]:
dis[v] = d
q.append((d, v))
return dis
def dist(self):
"""根据g的类型自动判断是不是用下标代替节点,速度快一点"""
if isinstance(self.g, list):
return self.dist_by_list_g_0_indexed()
return self.dist_by_default_dict_g()
def main():
n, m, s = RI()
g = collections.defaultdict(dict)
for _ in range(m):
u, v, w = RI()
a, b = u - 1, v - 1
if b in g[a]: # 在这wa了很多次,字典形式的图要处理重边
g[a][b] = min(g[a][b], w)
else:
g[a][b] = w
dis = Spfa(g, s - 1, INF=2 ** 31 - 1).dist()
ans = [0] * n
for i in range(n):
ans[i] = dis[i]
print(*ans)
# def main():
# n, m, s = RI()
# g = [[] for _ in range(n)]
# for _ in range(m):
# u, v, w = RI()
# g[u - 1].append((v - 1, w))
#
# dis = Spfa(g, s - 1).dist()
# for i, v in enumerate(dis):
# if v == inf:
# dis[i] = 2 ** 31 - 1
# print(*dis)
if __name__ == '__main__':
# testcase 2个字段分别是input和output
test_cases = (
(
"""4 6 1
1 2 2
2 3 2
2 4 1
1 3 5
3 4 3
1 4 4""", """0 2 4 3
"""
),
)
if os.path.exists('test.test'):
total_result = 'ok!'
for i, (in_data, result) in enumerate(test_cases):
result = result.strip()
with io.StringIO(in_data.strip()) as buf_in:
RI = lambda: map(int, buf_in.readline().split())
RS = lambda: buf_in.readline().strip().split()
with io.StringIO() as buf_out, redirect_stdout(buf_out):
main()
output = buf_out.getvalue().strip()
if output == result:
print(f'case{i}, result={result}, output={output}, ---ok!')
else:
print(f'case{i}, result={result}, output={output}, ---WA!---WA!---WA!')
total_result = '---WA!---WA!---WA!'
print('\n', total_result)
else:
main()
5.SPFA判断负环洛谷模板题
链接: P3385 【模板】负环)
这题加边方式更搞笑,详细看题目
import collections
import sys
from collections import *
from contextlib import redirect_stdout
from itertools import *
from math import sqrt, inf
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from bisect import *
RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda x: sys.stderr.write(f'{str(x)}\n')
class Spfa:
"""
单源最短路,支持负权,复杂度O(m*n)
"""
def __init__(self, g, start, n=None, INF=None):
self.n = len(g) if n is None else n
self.g = g
self.start = start
self.INF = INF if INF is not None else inf
def has_negative_circle(self):
"""
判断是否存在负环,即权为负的环,因为如果存在负环,路径每经过一次这个环就减小,永远出不来。
判断入队的点n次以上即出不来了,因为bellman-ford可知最多松弛n-1次就应该有答案
:return:
"""
dis, g, n = [self.INF] * self.n, self.g, self.n # 初始化距离数据为全inf
dis[self.start] = 0 # 源到自己距离0
q = deque([(0, self.start)])
cnt = [0] * n
while q:
c, u = q.popleft() # 当前点的最短路
cnt[u] += 1
if cnt[u] >= n:
return True # 有个点入队了n次以上,说明永远也结束不了,存在负环
if c > dis[u]: continue
for v, w in g[u]: # 用u松弛它的邻居
d = c + w
if d < dis[v]: # 可以松弛
dis[v] = d
q.append((d, v))
return False # 可以结束,不存在负环
def main():
T, = RI()
for _ in range(T):
n, m = RI()
g = [[] for _ in range(n)]
for _ in range(m):
u, v, w = RI()
if w >= 0:
g[v - 1].append((u - 1, w))
g[u - 1].append((v - 1, w))
if Spfa(g, 0).has_negative_circle():
print('YES')
else:
print('NO')
if __name__ == '__main__':
# testcase 2个字段分别是input和output
test_cases = (
(
"""2
3 4
1 2 2
1 3 4
2 3 1
3 1 -3
3 3
1 2 3
2 3 4
3 1 -8""", """NO
YES
"""
),
)
if os.path.exists('test.test'):
total_result = 'ok!'
for i, (in_data, result) in enumerate(test_cases):
result = result.strip()
with io.StringIO(in_data.strip()) as buf_in:
RI = lambda: map(int, buf_in.readline().split())
RS = lambda: buf_in.readline().strip().split()
with io.StringIO() as buf_out, redirect_stdout(buf_out):
main()
output = buf_out.getvalue().strip()
if output == result:
print(f'case{i}, result={result}, output={output}, ---ok!')
else:
print(f'case{i}, result={result}, output={output}, ---WA!---WA!---WA!')
total_result = '---WA!---WA!---WA!'
print('\n', total_result)
else:
main()
6.Johnson带负环的全员最短路洛谷模板题
这题加边方式更搞笑,详细看题目
import collections
import sys
from collections import *
from contextlib import redirect_stdout
from itertools import *
from math import sqrt, inf
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from bisect import *
RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda x: sys.stderr.write(f'{str(x)}\n')
class Spfa:
"""
单源最短路,支持负权,复杂度O(m*n)
"""
def __init__(self, g, start, n=None, INF=None):
self.n = len(g) if n is None else n
self.g = g
self.start = start
self.INF = INF if INF is not None else inf
def has_negative_circle(self):
"""
判断是否存在负环,即权为负的环,因为如果存在负环,路径每经过一次这个环就减小,永远出不来。
判断入队的点n次以上即出不来了,因为bellman-ford可知最多松弛n-1次就应该有答案
:return:
"""
dis, g, n = [self.INF] * self.n, self.g, self.n # 初始化距离数据为全inf
dis[self.start] = 0 # 源到自己距离0
q = deque([(0, self.start)])
cnt = [0] * n
while q:
c, u = q.popleft() # 当前点的最短路
cnt[u] += 1
if cnt[u] >= n:
return True # 有个点入队了n次以上,说明永远也结束不了,存在负环
if c > dis[u]: continue
for v, w in g[u]: # 用u松弛它的邻居
d = c + w
if d < dis[v]: # 可以松弛
dis[v] = d
q.append((d, v))
return False # 可以结束,不存在负环
def safe_dist_by_list_g_0_indexed(self):
"""
:return: 如果有负环返回空,否则正常返回距离数组
"""
dis, g, n = [self.INF] * self.n, self.g, self.n # 初始化距离数据为全inf
dis[self.start] = 0 # 源到自己距离0
q = deque([(0, self.start)])
cnt = [0] * n
while q:
c, u = q.popleft() # 当前点的最短路
cnt[u] += 1
if cnt[u] >= n:
return [] # 有个点入队了n次以上,说明永远也结束不了,存在负环
if c > dis[u]: continue
for v, w in g[u]: # 用u松弛它的邻居
d = c + w
if d < dis[v]: # 可以松弛
dis[v] = d
q.append((d, v))
return dis # 可以结束,不存在负环
def unsafe_dist_by_list_g_0_indexed(self):
"""
基于g是从0~n-1表示节点的图
:return: 距离数组代表start点到每个点的最短路
"""
dis, g = [self.INF] * self.n, self.g # 初始化距离数据为全inf
dis[self.start] = 0 # 源到自己距离0
q = deque([(0, self.start)])
while q:
c, u = q.popleft() # 当前点的最短路
if c > dis[u]: continue
for v, w in g[u]: # 用u松弛它的邻居
d = c + w
if d < dis[v]: # 可以松弛
dis[v] = d
q.append((d, v))
return dis # 距离数组
def unsafe_dist_by_dict_g(self):
"""优先用defaultdict版本,卡性能再考虑这个
基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
:return: 距离字典:{u:dist},注意,如果不在这个字典里,则不可达;外侧查询的话可能需要dis.get(u,inf);其实可以用defaultdict来存,但是效率低一些。
"""
dis, g = {self.start: 0}, self.g # 初始化距离字典
INF = self.INF
q = deque([(0, self.start)])
while q:
c, u = q.popleft()
if c > dis.get(u, INF): continue
for v, w in g[u].items():
d = c + w
if d < dis.get(v, INF):
dis[v] = d
q.append((d, v))
return dis
def unsafe_dist_by_default_dict_g(self):
"""
基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
:return: 距离字典:defaultdict({u:dist}),注意,速度比上个函数慢一点,但是不易出错,因为可以初始化不可达为inf
"""
dis, g = defaultdict(lambda: self.INF), self.g # 初始化距离字典
dis[self.start] = 0
q = deque([(0, self.start)])
while q:
c, u = q.popleft()
if c > dis[u]: continue
for v, w in g[u].items():
d = c + w
if d < dis[v]:
dis[v] = d
q.append((d, v))
return dis
def unsafe_dist(self):
"""根据g的类型自动判断是不是用下标代替节点,速度快一点"""
if isinstance(self.g, list):
return self.unsafe_dist_by_list_g_0_indexed()
return self.unsafe_dist_by_default_dict_g()
class Dijkstra:
"""
堆优化版dijkstra单源最短路,求从start到所有点的最短路,边权不能为负;时间复杂度nlogn
其中INF可以自己定义,用来表示无法到达,默认是inf
"""
def __init__(self, g, start, n=None, INF=None):
self.n = len(g) if n is None else n
self.g = g
self.start = start
self.INF = INF if INF is not None else inf
def dist_by_list_g_0_indexed(self):
"""
基于g是从0~n-1表示节点的图
:return: 距离数组代表start点到每个点的最短路
"""
dis, g = [self.INF] * self.n, self.g # 初始化距离数据为全inf
dis[self.start] = 0 # 源到自己距离0
q = [(0, self.start)] # 优先队列
while q:
c, u = heapq.heappop(q) # 当前点的最短路
if c > dis[u]: continue # 这步巨量优化很重要:u可以从上一层多个点转移而来,队列中将同时存在多个u的情况,但只有c最小的那个有意义,其他跳过。
for v, w in g[u]: # 用u松弛它的邻居
d = c + w
if d < dis[v]: # 可以松弛
dis[v] = d
heapq.heappush(q, (d, v))
return dis # 距离数组
def dist_by_dict_g(self):
"""
基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
:return: 距离字典:{u:dist},注意,如果不在这个字典里,则不可达;外侧查询的话可能需要dis.get(u,inf);其实可以用defaultdict来存,但是效率低一些。
"""
dis, g = {self.start: 0}, self.g # 初始化距离数组
q = [(0, self.start)]
while q:
c, u = heapq.heappop(q)
if c > dis.get(u, inf): continue
for v, w in g[u].items():
d = c + w
if d < dis.get(v, inf):
dis[v] = d
heapq.heappush(q, (d, v))
return dis
def dist_by_default_dict_g(self):
"""优先用defaultdict版本,卡性能再考虑这个
基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
:return: 距离字典:defaultdict({u:dist}),注意,速度比上个函数慢一点,但是不易出错,因为可以初始化不可达为inf
"""
dis, g = defaultdict(lambda: self.INF), self.g # 初始化距离字典
dis[self.start] = 0
q = [(0, self.start)]
while q:
c, u = heapq.heappop(q)
if c > dis[u]: continue
for v, w in g[u].items():
d = c + w
if d < dis[v]:
dis[v] = d
heapq.heappush(q, (d, v))
return dis
def dist(self):
"""根据g的类型自动判断是不是用下标代替节点,速度快一点"""
if isinstance(self.g, list):
return self.dist_by_list_g_0_indexed()
return self.dist_by_default_dict_g()
class Johnson:
"""
支持负权的全源最短路,复杂度M*NlogN
建立超级源点n,连接每个点,边权为0,然后SPFA对n求最短路,记为h(O(N*M))
修正原图的(u,v,w)边权为w+h[u]-h[v],这里类似前缀和/差分。用势能考虑
然后以每个点为起点,求n次最短路Dijkstra。最后把求得的最短路再修正回来(减去势能)
不支持INF参数,因为可能计算的最短路恰好等于INF,再修正就有问题;
- 这里只有safe版本,即顺便检查是否有负环;如果没有负边,那可以手写n次Dijkstra。
- 这么看难道普通Dijkstra也可以处理负边问题吗?是的,但是要先用SPFA预处理,那还不如直接SPFA
"""
def __init__(self, g, n=None):
self.n = len(g) if n is None else n
self.g = g
def safe_dist(self):
"""
如果存在负环返回空数组;否则返回二维数组dist[i][j]代表i到j的最短路
:return:
"""
g, n = self.g + [[]], self.n
for u in range(n): # 建立超级源点n,到达所有点,w=0
g[n].append((u, 0))
h = Spfa(g, n).safe_dist_by_list_g_0_indexed() # 对n点求最短路,如果有负环就返回
if not h: return h # 存在负环,别聊了
g.pop() # 删除超级源点
for u in range(n): # 把图中边权修正为w+h[u]-h[v]
p, g[u] = g[u], []
for v, w in p:
g[u].append((v, w + h[u] - h[v]))
ans = []
for u in range(n):
dis = Dijkstra(g, u).dist()
ans.append([d + h[v] - h[u] for v, d in enumerate(dis)]) # 把最短路修正回来
return ans
def main():
n, m = RI()
g = [[] for _ in range(n )]
for _ in range(m):
u, v, w = RI()
g[u - 1].append((v - 1, w))
INF = 10 ** 9
dises = Johnson(g).safe_dist()
# DEBUG(dises)
if not dises: return print(-1)
ans = []
for u, dis in enumerate(dises):
ans.append(sum(j * d if d < inf else j*INF for j, d in enumerate(dis,start=1)))
print(*ans, sep='\n')
if __name__ == '__main__':
# testcase 2个字段分别是input和output
test_cases = (
(
"""5 7
1 2 4
1 4 10
2 3 7
4 5 3
4 2 -2
3 4 -3
5 3 4""", """128
1000000072
999999978
1000000026
1000000014
"""
),
(
"""5 5
1 2 4
3 4 9
3 4 -3
4 5 3
5 3 -2""", """-1
"""
),
)
if os.path.exists('test.test'):
total_result = 'ok!'
for i, (in_data, result) in enumerate(test_cases):
result = result.strip()
with io.StringIO(in_data.strip()) as buf_in:
RI = lambda: map(int, buf_in.readline().split())
RS = lambda: buf_in.readline().strip().split()
with io.StringIO() as buf_out, redirect_stdout(buf_out):
main()
output = buf_out.getvalue().strip()
if output == result:
print(f'case{i}, result={result}, output={output}, ---ok!')
else:
print(f'case{i}, result={result}, output={output}, ---WA!---WA!---WA!')
total_result = '---WA!---WA!---WA!'
print('\n', total_result)
else:
main()