最近要给俱乐部的成员做板子,所以专门写了一些代码方便理解背诵。代码注释很详细,并且对时间复杂度进行了详细的推导,希望大家看完后会有所收获。
# -*- coding : utf-8 -*-
"""
@author: 2023数据赋能俱乐部
@Description: 迪杰斯特拉算法
@Date: 2024-4-4 15:36
"""
import random
from heapq import heappop, heappush
from math import inf
from typing import List
# 本代码可以直接提交力扣2642测评
# https://leetcode.cn/problems/design-graph-with-shortest-path-calculator
class Graph:
def __init__(self, n: int, edges: List[List[int]]): # edges为有向边三元组(x, y, 权重)
a = [[] for _ in range(n)]
self.a = a; self.n = n
# 邻接表
for x, y, w in edges:
a[x].append((w, y))
# 邻接矩阵
self.g = [[inf] * n for _ in range(n)]
for x, y, w in edges:
self.g[x][y] = w # 添加一条边
def addEdge(self, e: List[int]) -> None:
x, y, w = e
# 邻接表
self.a[x].append((w, y))
# 邻接矩阵
self.g[x][y] = w
def dijkstra_pq(self, node1: int, node2: int): # 返回最短路径,找不到返回-1
# m为边数,n为顶点数
# dijkstra堆优化版 时间复杂度 O(mlogm)
n = self.n
edges = [(0, node1)]
# 堆优化版可以不使用vis数组
vis = [0] * n # 此行代码可注释,仅用于断言证明
dis = [inf] * n; dis[node1] = 0
xs = set() # 此行代码可注释,仅用于断言证明
# 以上代码复杂度O(n)
# 以下循环分为2部分计算复杂度
while edges: # 注意:这个循环最多执行m次(可自行断言证明)
# Part ①
d, x = heappop(edges) # 最多调用m次,每次logm
if x == node2:
return dis[node2]
""" if dis[x] < d 和if vis[x] 完全等价,但vis需要O(n)初始化 """
""" 两者均可使用 """
# 断言证明以上结论:
assert ((dis[x] < d) == vis[x]); vis[x] = 1 # 此行代码可注释,仅用于断言证明
if dis[x] < d: # 注意:这里不能取等!!必须是严格小于
continue
# Part ②
""" 到这里,以下代码最多执行n次,超过n小于等于m的情况被上面continue掉了(可自行断言证明)。"""
# 且每次的x都不一样
assert x not in xs; xs.add(x) # 断言证明:每次的x都不一样。可注释
for w, y in self.a[x]:
if dis[x] + w < dis[y]:
dis[y] = dis[x] + w
# 极端情况:所有边都是x的出边,以下代码直接执行m次
# 所以,堆大小最大为m,这就是heap单次操作O(logm)的原因
heappush(edges, (dis[y], y)) # 该语句总共最多执行m次,单次O(logm)
# 解释以上语句为什么总共最多执行m次:
# 此for循环每次x都不同,最多n次循环,所以最坏会遍历所有顶点的所有边,即为m。
# 以上while循环复杂度:
# Part①: O(mlogm)
# Part②: O(n) + O(mlogm) 前面的O(n)来源于for里面的if语句,后者来源于入堆
# 所以,总复杂度为:
# 初始化复杂度+while循环复杂度
# = O(n) + O(mlogm) + O(n) + O(mlogm)
# = O(n) + O(mlogm) ③
# 特别地,一般有:m >= n - 1 (无孤立结点的最少边数)
# 所以,可认为 n < mlogm
# 那么③可化简为:
# O(mlogm)
# 这便是dijkstra在堆优化下的时间复杂度。
return -1
def dijkstra_bf(self, node1: int, node2: int): # 返回最短路径,找不到返回-1
# m为边数,n为顶点数
# dijkstra朴素版 时间复杂度 O(n^2)
n = self.n
# 朴素版必须使用vis数组
vis = [0] * n
# 朴素版和堆优化都需要dis数组,和初始化出发点dis为0的操作。
dis = [inf] * n; dis[node1] = 0
# 无脑n次循环(每次循环会确定一个最短路径的结点)
for _ in range(n):
# 先找到dis里最小且未访问的
min_d = inf; x = 0
for i in range(len(dis)):
if not vis[i] and dis[i] < min_d:
min_d = dis[i]
x = i
# 如果是目标结点,提前返回
if x == node2:
return dis[x] if dis[x] != inf else -1
# 标记为已访问(每一次访问都会找到一个最短路径的结点)
vis[x] = 1
# 用最小的做中转,尝试更新他的邻居结点
for y in range(n):
# 不要用min 速度会变慢非常多
if dis[x] + self.g[x][y] < dis[y]:
dis[y] = dis[x] + self.g[x][y]
# 没找到
return -1
def shortestPath(self, node1: int, node2: int):
c = random.randint(0, 1)
# 随机调用测试
if c == 0:
return self.dijkstra_bf(node1, node2)
else:
return self.dijkstra_pq(node1, node2)
if __name__ == '__main__':
# 图的顶点数和边
n = 5
edges = [
(0, 1, 10),
(0, 2, 3),
(1, 2, 1),
(1, 3, 2),
(2, 1, 4),
(2, 3, 8),
(2, 4, 2),
(3, 4, 7),
(4, 3, 9)
]
# 创建图对象
graph = Graph(n, edges)
# 要测试的起点和终点
node1 = 0
node2 = 3
# 调用 shortestPath 方法计算最短路径
shortest_distance = graph.shortestPath(node1, node2)
print(f"The shortest distance from node {node1} to node {node2} is {shortest_distance}.")