(Python版)Dijkstra算法(包含朴素算法和堆优化)

最近要给俱乐部的成员做板子,所以专门写了一些代码方便理解背诵。代码注释很详细,并且对时间复杂度进行了详细的推导,希望大家看完后会有所收获。

# -*- coding : utf-8 -*-
"""
@author: 2023数据赋能俱乐部
@Description: 迪杰斯特拉算法
@Date: 2024-4-4 15:36
"""
import random
from heapq import heappop, heappush
from math import inf
from typing import List

# 本代码可以直接提交力扣2642测评
# https://leetcode.cn/problems/design-graph-with-shortest-path-calculator


class Graph:
    def __init__(self, n: int, edges: List[List[int]]):  # edges为有向边三元组(x, y, 权重)
        a = [[] for _ in range(n)]
        self.a = a; self.n = n

        # 邻接表
        for x, y, w in edges:
            a[x].append((w, y))

        # 邻接矩阵
        self.g = [[inf] * n for _ in range(n)]
        for x, y, w in edges:
            self.g[x][y] = w  # 添加一条边

    def addEdge(self, e: List[int]) -> None:
        x, y, w = e

        # 邻接表
        self.a[x].append((w, y))

        # 邻接矩阵
        self.g[x][y] = w

    def dijkstra_pq(self, node1: int, node2: int):  # 返回最短路径,找不到返回-1
        # m为边数,n为顶点数
        # dijkstra堆优化版 时间复杂度 O(mlogm)
        n = self.n
        edges = [(0, node1)]
        # 堆优化版可以不使用vis数组
        vis = [0] * n  # 此行代码可注释,仅用于断言证明
        dis = [inf] * n; dis[node1] = 0
        xs = set()  # 此行代码可注释,仅用于断言证明
        # 以上代码复杂度O(n)
        # 以下循环分为2部分计算复杂度
        while edges:  # 注意:这个循环最多执行m次(可自行断言证明)
            # Part ①
            d, x = heappop(edges)  # 最多调用m次,每次logm
            if x == node2:
                return dis[node2]
            """ if dis[x] < d 和if vis[x] 完全等价,但vis需要O(n)初始化 """
            """ 两者均可使用 """
            # 断言证明以上结论:
            assert ((dis[x] < d) == vis[x]); vis[x] = 1  # 此行代码可注释,仅用于断言证明
            if dis[x] < d:  # 注意:这里不能取等!!必须是严格小于
                continue
            # Part ②
            """ 到这里,以下代码最多执行n次,超过n小于等于m的情况被上面continue掉了(可自行断言证明)。"""
            # 且每次的x都不一样
            assert x not in xs; xs.add(x)  # 断言证明:每次的x都不一样。可注释
            for w, y in self.a[x]:
                if dis[x] + w < dis[y]:
                    dis[y] = dis[x] + w
                    # 极端情况:所有边都是x的出边,以下代码直接执行m次
                    # 所以,堆大小最大为m,这就是heap单次操作O(logm)的原因
                    heappush(edges, (dis[y], y))  # 该语句总共最多执行m次,单次O(logm)
                    # 解释以上语句为什么总共最多执行m次:
                    # 此for循环每次x都不同,最多n次循环,所以最坏会遍历所有顶点的所有边,即为m。
        # 以上while循环复杂度:
        # Part①: O(mlogm)
        # Part②: O(n) + O(mlogm)  前面的O(n)来源于for里面的if语句,后者来源于入堆
        # 所以,总复杂度为:
        # 初始化复杂度+while循环复杂度
        # = O(n) + O(mlogm) + O(n) + O(mlogm)
        # = O(n) + O(mlogm) ③
        # 特别地,一般有:m >= n - 1 (无孤立结点的最少边数)
        # 所以,可认为 n < mlogm
        # 那么③可化简为:
        # O(mlogm)
        # 这便是dijkstra在堆优化下的时间复杂度。

        return -1

    def dijkstra_bf(self, node1: int, node2: int):  # 返回最短路径,找不到返回-1
        # m为边数,n为顶点数
        # dijkstra朴素版 时间复杂度 O(n^2)
        n = self.n
        # 朴素版必须使用vis数组
        vis = [0] * n
        # 朴素版和堆优化都需要dis数组,和初始化出发点dis为0的操作。
        dis = [inf] * n; dis[node1] = 0
        # 无脑n次循环(每次循环会确定一个最短路径的结点)
        for _ in range(n):
            # 先找到dis里最小且未访问的
            min_d = inf; x = 0
            for i in range(len(dis)):
                if not vis[i] and dis[i] < min_d:
                    min_d = dis[i]
                    x = i
            # 如果是目标结点,提前返回
            if x == node2:
                return dis[x] if dis[x] != inf else -1
            # 标记为已访问(每一次访问都会找到一个最短路径的结点)
            vis[x] = 1
            # 用最小的做中转,尝试更新他的邻居结点
            for y in range(n):
                # 不要用min 速度会变慢非常多
                if dis[x] + self.g[x][y] < dis[y]:
                    dis[y] = dis[x] + self.g[x][y]
        # 没找到
        return -1

    def shortestPath(self, node1: int, node2: int):
        c = random.randint(0, 1)
        # 随机调用测试
        if c == 0:
            return self.dijkstra_bf(node1, node2)
        else:
            return self.dijkstra_pq(node1, node2)


if __name__ == '__main__':
    # 图的顶点数和边
    n = 5
    edges = [
        (0, 1, 10),
        (0, 2, 3),
        (1, 2, 1),
        (1, 3, 2),
        (2, 1, 4),
        (2, 3, 8),
        (2, 4, 2),
        (3, 4, 7),
        (4, 3, 9)
    ]

    # 创建图对象
    graph = Graph(n, edges)

    # 要测试的起点和终点
    node1 = 0
    node2 = 3

    # 调用 shortestPath 方法计算最短路径
    shortest_distance = graph.shortestPath(node1, node2)

    print(f"The shortest distance from node {node1} to node {node2} is {shortest_distance}.")

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值