[python刷题模板] 最短路(Dijkstra/SPFA/Johnson)

一、 算法&数据结构

1. 描述

本文打一个最短路模板。
Dijkstra是nlogn+m处理无负权边情况
SPFA用n*m处理带负权边情况。
Johnson用n*mlogn处理全源最短路。(优于floyd)

2. 复杂度分析

  1. Dijkstra堆优化O(nlogn)。
  2. spfa O(nm)。
  3. Johnson O(mNlgn)

3. 常见应用

  1. dijkstra查询单源路径最短路,注意不能有负权边。
  2. spfa可以用来判断是否存在负环。

4. 常用优化

  1. 优先队列(堆)优化。这里注意没写vis已访问字典,由于最近的一定先访问,因此可以用dist>来代替。

二、 模板代码

1. Dijkstra洛谷模板题P4779 ,存在重边的单元最短路。

例题: P4779 【模板】单源最短路径(标准版)

  • 写了两个模板,一个用字典形式存图,目的是处理节点非编号的题目(比如矩阵搜索的节点用坐标表示)
  • 字典形式被重边折磨了,会覆盖,因此要注意处理。
  • 还有注意本题是有向图。
import collections
import sys
from collections import *
from contextlib import redirect_stdout
from itertools import *
from math import sqrt, inf
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from bisect import *

RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda x: sys.stderr.write(f'{str(x)}\n')


class Dijkstra:
    """
    堆优化版dijkstra单源最短路,求从start到所有点的最短路,边权不能为负;时间复杂度nlogn
    其中INF可以自己定义,用来表示无法到达,默认是inf
    """

    def __init__(self, g, start, n=None, INF=None):
        self.n = len(g) if n is None else n
        self.g = g
        self.start = start
        self.INF = INF if INF is not None else inf

    def dist_by_list_g_0_indexed(self):
        """
        基于g是从0~n-1表示节点的图
        :return: 距离数组代表start点到每个点的最短路
        """
        dis, g = [self.INF] * self.n, self.g  # 初始化距离数据为全inf
        dis[self.start] = 0  # 源到自己距离0
        q = [(0, self.start)]  # 优先队列
        while q:
            c, u = heapq.heappop(q)  # 当前点的最短路
            if c > dis[u]: continue  # 这步巨量优化很重要:u可以从上一层多个点转移而来,队列中将同时存在多个u的情况,但只有c最小的那个有意义,其他跳过。
            for v, w in g[u]:  # 用u松弛它的邻居
                d = c + w
                if d < dis[v]:  # 可以松弛
                    dis[v] = d
                    heapq.heappush(q, (d, v))
        return dis  # 距离数组

    def dist_by_dict_g(self):
        """
        基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
        由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
        :return: 距离字典:{u:dist},注意,如果不在这个字典里,则不可达;外侧查询的话可能需要dis.get(u,inf);其实可以用defaultdict来存,但是效率低一些。
        """
        dis, g = {self.start: 0}, self.g  # 初始化距离数组

        q = [(0, self.start)]
        while q:
            c, u = heapq.heappop(q)
            if c > dis.get(u, inf): continue
            for v, w in g[u].items():
                d = c + w
                if d < dis.get(v, inf):
                    dis[v] = d
                    heapq.heappush(q, (d, v))
        return dis

    def dist_by_default_dict_g(self):
        """优先用defaultdict版本,卡性能再考虑这个
        基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
        由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
        :return: 距离字典:defaultdict({u:dist}),注意,速度比上个函数慢一点,但是不易出错,因为可以初始化不可达为inf
        """
        dis, g = defaultdict(lambda: self.INF), self.g  # 初始化距离字典
        dis[self.start] = 0
        q = [(0, self.start)]
        while q:
            c, u = heapq.heappop(q)
            if c > dis[u]: continue
            for v, w in g[u].items():
                d = c + w
                if d < dis[v]:
                    dis[v] = d
                    heapq.heappush(q, (d, v))
        return dis

    def dist(self):
        """根据g的类型自动判断是不是用下标代替节点,速度快一点"""
        if isinstance(self.g, list):
            return self.dist_by_list_g_0_indexed()
        return self.dist_by_default_dict_g()


def main():
    n, m, s = RI()

    g = collections.defaultdict(dict)
    for _ in range(m):
        u, v, w = RI()
        a, b = u - 1, v - 1
        if b in g[a]:  # 在这wa了很多次,字典形式的图要处理重边
            g[a][b] = min(g[a][b], w)
        else:
            g[a][b] = w

    dis = Dijkstra(g, s - 1).dist()
    ans = [0] * n
    for i in range(n):
        ans[i] = dis[i]
    print(*ans)


# def main():
#     n, m, s = RI()
#     g = [[] for _ in range(n)]
#     for _ in range(m):
#         u, v, w = RI()
#         g[u - 1].append((v - 1, w))
#
#     dis = Dijkstra(g, s - 1).dist()
#     print(*dis)


if __name__ == '__main__':
    # testcase 2个字段分别是input和output
    test_cases = (
        (
            """4 6 1
1 2 2
2 3 2
2 4 1
1 3 5
3 4 3
1 4 4""", """0 2 4 3
"""
        ),
    )
    if os.path.exists('test.test'):
        total_result = 'ok!'
        for i, (in_data, result) in enumerate(test_cases):
            result = result.strip()
            with io.StringIO(in_data.strip()) as buf_in:
                RI = lambda: map(int, buf_in.readline().split())
                RS = lambda: buf_in.readline().strip().split()
                with io.StringIO() as buf_out, redirect_stdout(buf_out):
                    main()
                    output = buf_out.getvalue().strip()
                if output == result:
                    print(f'case{i}, result={result}, output={output}, ---ok!')
                else:
                    print(f'case{i}, result={result}, output={output}, ---WA!---WA!---WA!')
                    total_result = '---WA!---WA!---WA!'
        print('\n', total_result)
    else:
        main()

2. 转化为单源最短路

链接: 882. 细分图中的可到达节点

  • 把cnt+1看做边权,建图
  • 先求出每个点的最短路,然后ans累计在maxMoves步数下能到达的节点数,这样就把原式节点算完了。
  • 然后处理每条边上的cnt,对于(u,v,cnt):
    • 如果u当前剩余步数a = maxMoves-dist[u]>0,则可以从u向v访问a个节点,当然不能超过cnt
    • 同理处理v,如果v当前剩余步数b = maxMoves-dist[v]>0,则可以从v向u访问b个节点,当然不能超过cnt。
    • 最后,a+b也不可以超过cnt,即对于这条边:ans+=max(a+b,cnt)
class Dijkstra:
    """
    堆优化版dijkstra单源最短路,求从start到所有点的最短路,边权不能为负;时间复杂度nlogn
    """
    def __init__(self, g, start, n=None):
        self.n = len(g) if n is None else n
        self.g = g
        self.start = start

    def dist_by_list_g_0_indexed(self):
        """
        基于g是从0~n-1表示节点的图
        :return: 距离数组代表start点到每个点的最短路
        """
        dis, g = [inf] * self.n, self.g  # 初始化距离数据为全inf
        dis[self.start] = 0  # 源到自己距离0
        q = [(0, self.start)]  # 优先队列
        while q:
            c, u = heapq.heappop(q)  # 当前点的最短路
            if c > dis[u]: continue  # 这步巨量优化很重要:u可以从上一层多个点转移而来,队列中将同时存在多个u的情况,但只有c最小的那个有意义,其他跳过。
            for v, w in g[u]:  # 用u松弛它的邻居
                d = c + w
                if d < dis[v]:  # 可以松弛
                    dis[v] = d
                    heapq.heappush(q, (d, v))
        return dis  # 距离数组

    def dist_by_dict_g(self):
        """
        基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
        由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
        :return: 距离字典:{u:dist},注意,如果不在这个字典里,则不可达;外侧查询的话可能需要dis.get(u,inf);其实可以用defaultdict来存,但是效率低一些。
        """
        dis, g = {self.start: 0}, self.g  # 初始化距离数组

        q = [(0, self.start)]
        while q:
            c, u = heapq.heappop(q)
            if c > dis.get(u, inf): continue
            for v, w in g[u].items():
                d = c + w
                if d < dis.get(v, inf):
                    dis[v] = d
                    heapq.heappush(q, (d, v))
        return dis

    def dist_by_default_dict_g(self):
        """优先用defaultdict版本,卡性能再考虑这个
        基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
        由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
        :return: 距离字典:defaultdict({u:dist}),注意,速度比上个函数慢一点,但是不易出错,因为可以初始化不可达为inf
        """
        dis, g = defaultdict(lambda: self.INF), self.g  # 初始化距离字典
        dis[self.start] = 0
        q = [(0, self.start)]
        while q:
            c, u = heapq.heappop(q)
            if c > dis[u]: continue
            for v, w in g[u].items():
                d = c + w
                if d < dis[v]:
                    dis[v] = d
                    heapq.heappush(q, (d, v))
        return dis

    def dist(self):
        """根据g的类型自动判断是不是用下标代替节点,速度快一点"""
        if isinstance(self.g, list):
            return self.dist_by_list_g_0_indexed()
        return self.dist_by_default_dict_g()


class Solution:
    def reachableNodes(self, edges: List[List[int]], maxMoves: int, n: int) -> int: 
       
        g = [[] for _ in range(n)]
        for u,v,cnt in edges:
            g[u].append((v,cnt+1))
            g[v].append((u,cnt+1))
       
        dist = Dijkstra(g,0).dist()
        ans = sum(dist[x]<= maxMoves for x in range(n)) 
        for u, v, cnt in edges:
            a = max(maxMoves -  dist[u],0)
            b = max(maxMoves - dist[v],0)
            ans += min(cnt,a+b)
        return ans

3.矩阵搜索

链接: 2290. 到达角落需要移除障碍物的最小数目

这题用0-1bfs更快

class Dijkstra:
    """
    堆优化版dijkstra单源最短路,求从start到所有点的最短路,边权不能为负;时间复杂度nlogn
    """
    def __init__(self, g, start, n=None):
        self.n = len(g) if n is None else n
        self.g = g
        self.start = start

    def dist_by_list_g_0_indexed(self):
        """
        基于g是从0~n-1表示节点的图
        :return: 距离数组代表start点到每个点的最短路
        """
        dis, g = [inf] * self.n, self.g  # 初始化距离数据为全inf
        dis[self.start] = 0  # 源到自己距离0
        q = [(0, self.start)]  # 优先队列
        while q:
            c, u = heapq.heappop(q)  # 当前点的最短路
            if c > dis[u]: continue  # 这步巨量优化很重要:u可以从上一层多个点转移而来,队列中将同时存在多个u的情况,但只有c最小的那个有意义,其他跳过。
            for v, w in g[u]:  # 用u松弛它的邻居
                d = c + w
                if d < dis[v]:  # 可以松弛
                    dis[v] = d
                    heapq.heappush(q, (d, v))
        return dis  # 距离数组

    def dist_by_dict_g(self):
        """
        基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
        由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
        :return: 距离字典:{u:dist},注意,如果不在这个字典里,则不可达;外侧查询的话可能需要dis.get(u,inf);其实可以用defaultdict来存,但是效率低一些。
        """
        dis, g = {self.start: 0}, self.g  # 初始化距离数组

        q = [(0, self.start)]
        while q:
            c, u = heapq.heappop(q)
            if c > dis.get(u, inf): continue
            for v, w in g[u].items():
                d = c + w
                if d < dis.get(v, inf):
                    dis[v] = d
                    heapq.heappush(q, (d, v))
        return dis

    def dist_by_default_dict_g(self):
        """优先用defaultdict版本,卡性能再考虑这个
        基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
        由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
        :return: 距离字典:defaultdict({u:dist}),注意,速度比上个函数慢一点,但是不易出错,因为可以初始化不可达为inf
        """
        dis, g = defaultdict(lambda: self.INF), self.g  # 初始化距离字典
        dis[self.start] = 0
        q = [(0, self.start)]
        while q:
            c, u = heapq.heappop(q)
            if c > dis[u]: continue
            for v, w in g[u].items():
                d = c + w
                if d < dis[v]:
                    dis[v] = d
                    heapq.heappush(q, (d, v))
        return dis

    def dist(self):
        """根据g的类型自动判断是不是用下标代替节点,速度快一点"""
        if isinstance(self.g, list):
            return self.dist_by_list_g_0_indexed()
        return self.dist_by_default_dict_g()
"""
class Solution:
    def minimumObstacles(self, grid: List[List[int]]) -> int:
        m,n = len(grid),len(grid[0])
        g = defaultdict(dict)
        for i in range(m):
            for j in range(n):
                for a,b in (i+1,j),(i,j+1):
                    if 0<=a<m and 0<=b<n:
                        g[i,j][a,b] = grid[a][b]
                        g[a,b][i,j] = grid[i][j]
        dist = Dijkstra(g,(0,0)).dist()        
        return dist[m-1,n-1]

4.SPFA洛谷模板题

链接: P3371 【模板】单源最短路径(弱化版)

这题没给负权,数据小可以SPFA

import collections
import sys
from collections import *
from contextlib import redirect_stdout
from itertools import *
from math import sqrt, inf
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from bisect import *

RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda x: sys.stderr.write(f'{str(x)}\n')


class Spfa:
    """
    单源最短路,支持负权,复杂度O(m*n)
    """

    def __init__(self, g, start, n=None, INF=None):
        self.n = len(g) if n is None else n
        self.g = g
        self.start = start
        self.INF = INF if INF is not None else inf

    def dist_by_list_g_0_indexed(self):
        """
        基于g是从0~n-1表示节点的图
        :return: 距离数组代表start点到每个点的最短路
        """
        dis, g = [self.INF] * self.n, self.g  # 初始化距离数据为全inf
        dis[self.start] = 0  # 源到自己距离0
        q = deque([(0, self.start)])
        while q:
            c, u = q.popleft()  # 当前点的最短路
            if c > dis[u]: continue
            for v, w in g[u]:  # 用u松弛它的邻居
                d = c + w
                if d < dis[v]:  # 可以松弛
                    dis[v] = d
                    q.append((d, v))
        return dis  # 距离数组

    def dist_by_dict_g(self):
        """优先用defaultdict版本,卡性能再考虑这个
        基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
        由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
        :return: 距离字典:{u:dist},注意,如果不在这个字典里,则不可达;外侧查询的话可能需要dis.get(u,inf);其实可以用defaultdict来存,但是效率低一些。
        """
        dis, g = {self.start: 0}, self.g  # 初始化距离字典
        INF = self.INF
        q = deque([(0, self.start)])
        while q:
            c, u = q.popleft()
            if c > dis.get(u, INF): continue
            for v, w in g[u].items():
                d = c + w
                if d < dis.get(v, INF):
                    dis[v] = d
                    q.append((d, v))
        return dis

    def dist_by_default_dict_g(self):
        """
        基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
        由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
        :return: 距离字典:defaultdict({u:dist}),注意,速度比上个函数慢一点,但是不易出错,因为可以初始化不可达为inf
        """
        dis, g = defaultdict(lambda: self.INF), self.g  # 初始化距离字典
        dis[self.start] = 0
        q = deque([(0, self.start)])
        while q:
            c, u = q.popleft()
            if c > dis[u]: continue

            for v, w in g[u].items():
                d = c + w
                if d < dis[v]:
                    dis[v] = d
                    q.append((d, v))
        return dis

    def dist(self):
        """根据g的类型自动判断是不是用下标代替节点,速度快一点"""
        if isinstance(self.g, list):
            return self.dist_by_list_g_0_indexed()
        return self.dist_by_default_dict_g()


def main():
    n, m, s = RI()

    g = collections.defaultdict(dict)
    for _ in range(m):
        u, v, w = RI()
        a, b = u - 1, v - 1
        if b in g[a]:  # 在这wa了很多次,字典形式的图要处理重边
            g[a][b] = min(g[a][b], w)
        else:
            g[a][b] = w

    dis = Spfa(g, s - 1, INF=2 ** 31 - 1).dist()
    ans = [0] * n
    for i in range(n):
        ans[i] = dis[i]
    print(*ans)


# def main():
#     n, m, s = RI()
#     g = [[] for _ in range(n)]
#     for _ in range(m):
#         u, v, w = RI()
#         g[u - 1].append((v - 1, w))
#
#     dis = Spfa(g, s - 1).dist()
#     for i, v in enumerate(dis):
#         if v == inf:
#             dis[i] = 2 ** 31 - 1
#     print(*dis)


if __name__ == '__main__':
    # testcase 2个字段分别是input和output
    test_cases = (
        (
            """4 6 1
1 2 2
2 3 2
2 4 1
1 3 5
3 4 3
1 4 4""", """0 2 4 3
"""
        ),
    )
    if os.path.exists('test.test'):
        total_result = 'ok!'
        for i, (in_data, result) in enumerate(test_cases):
            result = result.strip()
            with io.StringIO(in_data.strip()) as buf_in:
                RI = lambda: map(int, buf_in.readline().split())
                RS = lambda: buf_in.readline().strip().split()
                with io.StringIO() as buf_out, redirect_stdout(buf_out):
                    main()
                    output = buf_out.getvalue().strip()
                if output == result:
                    print(f'case{i}, result={result}, output={output}, ---ok!')
                else:
                    print(f'case{i}, result={result}, output={output}, ---WA!---WA!---WA!')
                    total_result = '---WA!---WA!---WA!'
        print('\n', total_result)
    else:
        main()

5.SPFA判断负环洛谷模板题

链接: P3385 【模板】负环)

这题加边方式更搞笑,详细看题目

import collections
import sys
from collections import *
from contextlib import redirect_stdout
from itertools import *
from math import sqrt, inf
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from bisect import *

RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda x: sys.stderr.write(f'{str(x)}\n')


class Spfa:
    """
    单源最短路,支持负权,复杂度O(m*n)
    """

    def __init__(self, g, start, n=None, INF=None):
        self.n = len(g) if n is None else n
        self.g = g
        self.start = start
        self.INF = INF if INF is not None else inf

    def has_negative_circle(self):
        """
        判断是否存在负环,即权为负的环,因为如果存在负环,路径每经过一次这个环就减小,永远出不来。
        判断入队的点n次以上即出不来了,因为bellman-ford可知最多松弛n-1次就应该有答案
        :return:
        """
        dis, g, n = [self.INF] * self.n, self.g, self.n  # 初始化距离数据为全inf
        dis[self.start] = 0  # 源到自己距离0
        q = deque([(0, self.start)])
        cnt = [0] * n
        while q:
            c, u = q.popleft()  # 当前点的最短路
            cnt[u] += 1
            if cnt[u] >= n:
                return True  # 有个点入队了n次以上,说明永远也结束不了,存在负环
            if c > dis[u]: continue
            for v, w in g[u]:  # 用u松弛它的邻居
                d = c + w
                if d < dis[v]:  # 可以松弛

                    dis[v] = d
                    q.append((d, v))
        return False  # 可以结束,不存在负环

def main():
    T, = RI()
    for _ in range(T):
        n, m = RI()
        g = [[] for _ in range(n)]
        for _ in range(m):
            u, v, w = RI()
            if w >= 0:
                g[v - 1].append((u - 1, w))
            g[u - 1].append((v - 1, w))

        if Spfa(g, 0).has_negative_circle():
            print('YES')
        else:
            print('NO')


if __name__ == '__main__':
    # testcase 2个字段分别是input和output
    test_cases = (
        (
            """2
3 4
1 2 2
1 3 4
2 3 1
3 1 -3
3 3
1 2 3
2 3 4
3 1 -8""", """NO
YES
"""
        ),
    )
    if os.path.exists('test.test'):
        total_result = 'ok!'
        for i, (in_data, result) in enumerate(test_cases):
            result = result.strip()
            with io.StringIO(in_data.strip()) as buf_in:
                RI = lambda: map(int, buf_in.readline().split())
                RS = lambda: buf_in.readline().strip().split()
                with io.StringIO() as buf_out, redirect_stdout(buf_out):
                    main()
                    output = buf_out.getvalue().strip()
                if output == result:
                    print(f'case{i}, result={result}, output={output}, ---ok!')
                else:
                    print(f'case{i}, result={result}, output={output}, ---WA!---WA!---WA!')
                    total_result = '---WA!---WA!---WA!'
        print('\n', total_result)
    else:
        main()

6.Johnson带负环的全员最短路洛谷模板题

链接: P5905 【模板】Johnson 全源最短路

这题加边方式更搞笑,详细看题目

import collections
import sys
from collections import *
from contextlib import redirect_stdout
from itertools import *
from math import sqrt, inf
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from bisect import *

RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda x: sys.stderr.write(f'{str(x)}\n')


class Spfa:
    """
    单源最短路,支持负权,复杂度O(m*n)
    """

    def __init__(self, g, start, n=None, INF=None):
        self.n = len(g) if n is None else n
        self.g = g
        self.start = start
        self.INF = INF if INF is not None else inf

    def has_negative_circle(self):
        """
        判断是否存在负环,即权为负的环,因为如果存在负环,路径每经过一次这个环就减小,永远出不来。
        判断入队的点n次以上即出不来了,因为bellman-ford可知最多松弛n-1次就应该有答案
        :return:
        """
        dis, g, n = [self.INF] * self.n, self.g, self.n  # 初始化距离数据为全inf
        dis[self.start] = 0  # 源到自己距离0
        q = deque([(0, self.start)])
        cnt = [0] * n
        while q:
            c, u = q.popleft()  # 当前点的最短路
            cnt[u] += 1
            if cnt[u] >= n:
                return True  # 有个点入队了n次以上,说明永远也结束不了,存在负环
            if c > dis[u]: continue
            for v, w in g[u]:  # 用u松弛它的邻居
                d = c + w
                if d < dis[v]:  # 可以松弛

                    dis[v] = d
                    q.append((d, v))
        return False  # 可以结束,不存在负环

    def safe_dist_by_list_g_0_indexed(self):
        """
        :return: 如果有负环返回空,否则正常返回距离数组
        """
        dis, g, n = [self.INF] * self.n, self.g, self.n  # 初始化距离数据为全inf
        dis[self.start] = 0  # 源到自己距离0
        q = deque([(0, self.start)])
        cnt = [0] * n
        while q:
            c, u = q.popleft()  # 当前点的最短路
            cnt[u] += 1
            if cnt[u] >= n:
                return []  # 有个点入队了n次以上,说明永远也结束不了,存在负环
            if c > dis[u]: continue
            for v, w in g[u]:  # 用u松弛它的邻居
                d = c + w
                if d < dis[v]:  # 可以松弛

                    dis[v] = d
                    q.append((d, v))
        return dis  # 可以结束,不存在负环

    def unsafe_dist_by_list_g_0_indexed(self):
        """
        基于g是从0~n-1表示节点的图
        :return: 距离数组代表start点到每个点的最短路
        """
        dis, g = [self.INF] * self.n, self.g  # 初始化距离数据为全inf
        dis[self.start] = 0  # 源到自己距离0
        q = deque([(0, self.start)])
        while q:
            c, u = q.popleft()  # 当前点的最短路
            if c > dis[u]: continue
            for v, w in g[u]:  # 用u松弛它的邻居
                d = c + w
                if d < dis[v]:  # 可以松弛
                    dis[v] = d
                    q.append((d, v))
        return dis  # 距离数组

    def unsafe_dist_by_dict_g(self):
        """优先用defaultdict版本,卡性能再考虑这个
        基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
        由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
        :return: 距离字典:{u:dist},注意,如果不在这个字典里,则不可达;外侧查询的话可能需要dis.get(u,inf);其实可以用defaultdict来存,但是效率低一些。
        """
        dis, g = {self.start: 0}, self.g  # 初始化距离字典
        INF = self.INF
        q = deque([(0, self.start)])
        while q:
            c, u = q.popleft()
            if c > dis.get(u, INF): continue
            for v, w in g[u].items():
                d = c + w
                if d < dis.get(v, INF):
                    dis[v] = d
                    q.append((d, v))
        return dis

    def unsafe_dist_by_default_dict_g(self):
        """
        基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
        由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
        :return: 距离字典:defaultdict({u:dist}),注意,速度比上个函数慢一点,但是不易出错,因为可以初始化不可达为inf
        """
        dis, g = defaultdict(lambda: self.INF), self.g  # 初始化距离字典
        dis[self.start] = 0
        q = deque([(0, self.start)])
        while q:
            c, u = q.popleft()
            if c > dis[u]: continue

            for v, w in g[u].items():
                d = c + w
                if d < dis[v]:
                    dis[v] = d
                    q.append((d, v))
        return dis

    def unsafe_dist(self):
        """根据g的类型自动判断是不是用下标代替节点,速度快一点"""
        if isinstance(self.g, list):
            return self.unsafe_dist_by_list_g_0_indexed()
        return self.unsafe_dist_by_default_dict_g()


class Dijkstra:
    """
    堆优化版dijkstra单源最短路,求从start到所有点的最短路,边权不能为负;时间复杂度nlogn
    其中INF可以自己定义,用来表示无法到达,默认是inf
    """

    def __init__(self, g, start, n=None, INF=None):
        self.n = len(g) if n is None else n
        self.g = g
        self.start = start
        self.INF = INF if INF is not None else inf

    def dist_by_list_g_0_indexed(self):
        """
        基于g是从0~n-1表示节点的图
        :return: 距离数组代表start点到每个点的最短路
        """
        dis, g = [self.INF] * self.n, self.g  # 初始化距离数据为全inf
        dis[self.start] = 0  # 源到自己距离0
        q = [(0, self.start)]  # 优先队列
        while q:
            c, u = heapq.heappop(q)  # 当前点的最短路
            if c > dis[u]: continue  # 这步巨量优化很重要:u可以从上一层多个点转移而来,队列中将同时存在多个u的情况,但只有c最小的那个有意义,其他跳过。
            for v, w in g[u]:  # 用u松弛它的邻居
                d = c + w
                if d < dis[v]:  # 可以松弛
                    dis[v] = d
                    heapq.heappush(q, (d, v))
        return dis  # 距离数组

    def dist_by_dict_g(self):
        """
        基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
        由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
        :return: 距离字典:{u:dist},注意,如果不在这个字典里,则不可达;外侧查询的话可能需要dis.get(u,inf);其实可以用defaultdict来存,但是效率低一些。
        """
        dis, g = {self.start: 0}, self.g  # 初始化距离数组

        q = [(0, self.start)]
        while q:
            c, u = heapq.heappop(q)
            if c > dis.get(u, inf): continue
            for v, w in g[u].items():
                d = c + w
                if d < dis.get(v, inf):
                    dis[v] = d
                    heapq.heappush(q, (d, v))
        return dis

    def dist_by_default_dict_g(self):
        """优先用defaultdict版本,卡性能再考虑这个
        基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃
        由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。
        :return: 距离字典:defaultdict({u:dist}),注意,速度比上个函数慢一点,但是不易出错,因为可以初始化不可达为inf
        """
        dis, g = defaultdict(lambda: self.INF), self.g  # 初始化距离字典
        dis[self.start] = 0
        q = [(0, self.start)]
        while q:
            c, u = heapq.heappop(q)
            if c > dis[u]: continue
            for v, w in g[u].items():
                d = c + w
                if d < dis[v]:
                    dis[v] = d
                    heapq.heappush(q, (d, v))
        return dis

    def dist(self):
        """根据g的类型自动判断是不是用下标代替节点,速度快一点"""
        if isinstance(self.g, list):
            return self.dist_by_list_g_0_indexed()
        return self.dist_by_default_dict_g()


class Johnson:
    """
    支持负权的全源最短路,复杂度M*NlogN
    建立超级源点n,连接每个点,边权为0,然后SPFA对n求最短路,记为h(O(N*M))
    修正原图的(u,v,w)边权为w+h[u]-h[v],这里类似前缀和/差分。用势能考虑
    然后以每个点为起点,求n次最短路Dijkstra。最后把求得的最短路再修正回来(减去势能)
    不支持INF参数,因为可能计算的最短路恰好等于INF,再修正就有问题;
    - 这里只有safe版本,即顺便检查是否有负环;如果没有负边,那可以手写n次Dijkstra。
    - 这么看难道普通Dijkstra也可以处理负边问题吗?是的,但是要先用SPFA预处理,那还不如直接SPFA
    """
    def __init__(self, g, n=None):
        self.n = len(g) if n is None else n
        self.g = g

    def safe_dist(self):
        """
        如果存在负环返回空数组;否则返回二维数组dist[i][j]代表i到j的最短路
        :return:
        """
        g, n = self.g + [[]], self.n

        for u in range(n):  # 建立超级源点n,到达所有点,w=0
            g[n].append((u, 0))
        h = Spfa(g, n).safe_dist_by_list_g_0_indexed()  # 对n点求最短路,如果有负环就返回
        if not h: return h  # 存在负环,别聊了
        g.pop()  # 删除超级源点

        for u in range(n):  # 把图中边权修正为w+h[u]-h[v]
            p, g[u] = g[u], []
            for v, w in p:
                g[u].append((v, w + h[u] - h[v]))
        ans = []
        for u in range(n):
            dis = Dijkstra(g, u).dist()
            ans.append([d + h[v] - h[u] for v, d in enumerate(dis)])  # 把最短路修正回来
        return ans


def main():
    n, m = RI()
    g = [[] for _ in range(n )]
    for _ in range(m):
        u, v, w = RI()
        g[u - 1].append((v - 1, w))
    INF = 10 ** 9
    dises = Johnson(g).safe_dist()
    # DEBUG(dises)
    if not dises: return print(-1)
    ans = []
    for u, dis in enumerate(dises):
        ans.append(sum(j * d if d < inf else j*INF for j, d in enumerate(dis,start=1)))
    print(*ans, sep='\n')


if __name__ == '__main__':
    # testcase 2个字段分别是input和output
    test_cases = (
        (
            """5 7
1 2 4
1 4 10
2 3 7
4 5 3
4 2 -2
3 4 -3
5 3 4""", """128
1000000072
999999978
1000000026
1000000014
"""
        ),
        (
            """5 5
1 2 4
3 4 9
3 4 -3
4 5 3
5 3 -2""", """-1
"""
        ),
    )
    if os.path.exists('test.test'):
        total_result = 'ok!'
        for i, (in_data, result) in enumerate(test_cases):
            result = result.strip()
            with io.StringIO(in_data.strip()) as buf_in:
                RI = lambda: map(int, buf_in.readline().split())
                RS = lambda: buf_in.readline().strip().split()
                with io.StringIO() as buf_out, redirect_stdout(buf_out):
                    main()
                    output = buf_out.getvalue().strip()
                if output == result:
                    print(f'case{i}, result={result}, output={output}, ---ok!')
                else:
                    print(f'case{i}, result={result}, output={output}, ---WA!---WA!---WA!')
                    total_result = '---WA!---WA!---WA!'
        print('\n', total_result)
    else:
        main()

三、其他

四、更多例题

五、参考链接

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值