浅谈图论——迪杰斯特拉算法(leetcode例题,C++演示)

本文介绍了图论中的迪杰斯特拉算法,包括算法原理、LeetCode上的网络延迟时间问题实例解析以及C++代码实现。通过实例展示了如何将单源最短路径问题转化为求解所有节点信号传播时间的问题。
摘要由CSDN通过智能技术生成

浅谈图论——迪杰斯特拉算法(leetcode例题,C++演示)

一、谈一谈图论

如果你想问这个世界上什么算法是最牛逼的?博主是回答不上来的。但是,如果你问博主什么数据结构是最牛逼?博主个人认为图是最牛逼的数据结构。因为很多的问题,都可以用图这种数据结构来表示。链表、树这种数据结构博主认为可以看成一种特殊的图。所以,博主今天就想探讨一下图论的经典算法——迪杰斯特拉算法,如果觉得有用的小伙伴可以点一个赞,爱学习的你们真棒!

二、什么是迪杰斯特拉算法?

我们不要被迪杰斯特拉算法的名字吓到了,其实迪杰斯特拉的算法并不复杂。很多时候,传统的并不复杂的算法往往在解决现实问题的时候有这良好的效果。首先我们要搞清楚迪杰斯特拉算法解决的单源最短路径的问题。单源最短路径问题是找到从图中的一个固定顶点(称为源点)到其他所有顶点的最短路径

迪杰斯特拉算法步骤如下:

  • 初始化: 将源点到其他顶点的距离初始化为无穷大,源点到自身的距离初始化为0。同时,维护一个集合 S,其中包含已确定最短路径的顶点,初始时 S 为空。

  • 选择距离最小的顶点: 从未确定最短路径的顶点中选择一个到源点距离最小的顶点,并将其标记为已确定最短路径。将该顶点加入集合 S 中。

  • 更新距离: 对于新确定的顶点,遍历其所有相邻的顶点,更新这些顶点到源点的最短路径估计值。如果通过新确定的顶点可以获得更短的路径,则更新相应顶点到源点的距离值。

  • 重复步骤2和步骤3: 重复选择距离最小的顶点和更新距离的步骤,直到所有顶点的最短路径都被确定。

迪杰斯特拉算法图示如下(转自知乎@鹅厂程序小哥)
在这里插入图片描述

原文链接:https://zhuanlan.zhihu.com/p/346558578

我们看出,迪杰斯特拉算法本质是一个贪心算法,也就是多个局部最优累计到全局最优。因为我们每次找的都是未选取的点通过已选取的点到源点的最短路径,或许还存在其他情况的更短路径。所以,每次得到一个新的最短的路径时,我们都需要将这个距离与现有的距离取小值作为最短的距离

三、典型例题讲解(leetcode)

为了更好的了解该算法,我们通过力扣的一道题目具体的实现迪杰斯特拉算法,题目链接

n 个网络节点,标记为 1n

给你一个列表 times,表示信号经过 有向 边的传递时间。 times[i] = (ui, vi, wi),其中 ui 是源节点,vi 是目标节点, wi 是一个信号从源节点传递到目标节点的时间。

现在,从某个节点 K 发出一个信号。需要多久才能使所有节点都收到信号?如果不能使所有节点收到信号,返回 -1

示例 1:

img

输入:times = [[2,1,1],[2,3,1],[3,4,1]], n = 4, k = 2
输出:2

示例 2:

输入:times = [[1,2,1]], n = 2, k = 1
输出:1

示例 3:

输入:times = [[1,2,1]], n = 2, k = 2
输出:-1

提示:

  • 1 <= k <= n <= 100
  • 1 <= times.length <= 6000
  • times[i].length == 3
  • 1 <= ui, vi <= n
  • ui != vi
  • 0 <= wi <= 100
  • 所有 (ui, vi) 对都 互不相同(即,不含重复边)

我们分析一下这个题目。首先。这是一个有向图问题,它希望我们找到从某一个源点发送信号到所有节点都收到信号的时间。那么,我们就可以把问题转化成单源最短路径问题。我们只需要取某一源点到其他节点的最短路径中的最大值,就可以解决这个例题。这是博主的思路,相当于用迪杰斯特拉算法找到源点到其他所有节点的最短路径,再进行比较,如果有更好的思路欢迎交流。

四、博主代码详解(C++)

接下来,我们要分析代码的实现了。这是博主自己写的代码,不是很成熟,AC是没问题的。首先奉上所有的代码,我们接下来对每一块进行分析。

class Solution {
public:
    //寻找距离
    vector<int> find_distance(vector<vector<int>>& times, vector<int>& flag,
                              int n, int k) {
        vector<int> distance(n, 10000);
        distance[k - 1] = 0;
        flag[k - 1] = 1;
        for (int i = 0; i < times.size(); i++) {
            if (times[i][0] == k) {
                distance[times[i][1] - 1] = times[i][2];
            }
        }
        return distance;
    }
    //更新距离
    vector<int> update_distance(vector<vector<int>>& times, vector<int>& flag,
                                vector<int>& distance, int n, int k) {
        vector<int> temp = find_distance(times, flag, n, k);
        int m;
        for (int i = 0; i < n; i++) {
            if (temp[i] < 10000 && flag[i]!=1) {
                m = distance[k-1] + temp[i];
                distance[i] = fmin(m, distance[i]);
            }
        }
        return distance;
    }
    //寻找下一个节点
    int find_next(vector<int> distance, vector<int>& flag, int n) {
        int min_index;
        for (int i = 0; i < n; i++) {
            if (flag[i] == 1)
                distance[i] = 10000;
        }
        min_index = min_element(distance.begin(), distance.end()) - distance.begin();
        if (distance[min_index] == 10000)
            return -1;
        return min_index + 1;
    }
    //实现的主逻辑函数
    int networkDelayTime(vector<vector<int>>& times, int n, int k) {
        int min_index;
        vector<int> distance;
        vector<int> flag(n, 0);
        distance = find_distance(times, flag, n, k);
        min_index = find_next(distance, flag, n);
        while (min_index != -1) {
            distance = update_distance(times, flag, distance, n, min_index);
            min_index = find_next(distance, flag, n);
        }
        auto it = find(flag.begin(), flag.end(), 0);
        if (it != flag.end())
            return -1;
        else
            return *max_element(distance.begin(), distance.end());
    }
};

这是整体的代码结构,我们接下会分析每个函数的作用。首先,我先声明一下flag相当于是一个n维全局变量,用于判断源点是否找到了对所有点的最短路径,0为未找到,1为找到

1、find_distance函数

vector<int> find_distance(vector<vector<int>>& times, vector<int>& flag,
                              int n, int k) {
        vector<int> distance(n, 10000);
        distance[k - 1] = 0;
        flag[k - 1] = 1;
        for (int i = 0; i < times.size(); i++) {
            if (times[i][0] == k) {
                distance[times[i][1] - 1] = times[i][2];
            }
        }
        return distance;
 }

首先,这个函数的作用是找到某一点到其他点的距离,如果是非相邻的节点,则距离为10000(在本题中相当于无限大)。我们要知道该函数输入四个变量,注意变量k是点的标号,不是索引。其次,该函数返回一个n维向量distance,具体的实现方法就是遍历。

注意,只要我们使用find_distance函数作用于某一个点,说明这个点已经存在一条到源点的路径,所以要执行flag[k - 1] = 1,讲到后面会明白的。

2、update_distance函数

vector<int> update_distance(vector<vector<int>>& times, vector<int>& flag,
                                vector<int>& distance, int n, int k) {
        vector<int> temp = find_distance(times, flag, n, k);
        int m;
        for (int i = 0; i < n; i++) {
            if (temp[i] < 10000 && flag[i]!=1) {
                m = distance[k-1] + temp[i];
                distance[i] = fmin(m, distance[i]);
            }
        }
        return distance;
}

update_distance函数传入五个变量,其中此distance非彼distance,这也相当于一个n维全局变量,是源点到其他点的距离。这个函数的作用就是通过调用find_distance函数作用于某一已经和源点之间存在最短路径的点,找到它和相邻点的距离,来更新源点到这个相邻点的距离。

distance[i] = fmin(m, distance[i])这一句很重要,我们需要取新得到的距离和原来的距离的较小值,防止局部最优影响到全局最优

3、find_next函数

int find_next(vector<int> distance, vector<int>& flag, int n) {
        int min_index;
        for (int i = 0; i < n; i++) {
            if (flag[i] == 1)
                distance[i] = 10000;
        }
        min_index = min_element(distance.begin(), distance.end()) - distance.begin();
        if (distance[min_index] == 10000)
            return -1;
        return min_index + 1;
}

find_next函数的作用就是寻找update_distance函数所作用的下一个点。它接受四个变量,其中distance和flag都相当于一个全局变量。这里的distance是一个形参,不是实参,方便节省空间。具体的实现方法就是把所有已经遍历的点的距离置10000,再返回最小距离对应的点的标号(不是索引,所以要加1)。

4、networkDelayTime函数

int networkDelayTime(vector<vector<int>>& times, int n, int k) {
        int min_index;
        vector<int> distance;
        vector<int> flag(n, 0);
        distance = find_distance(times, flag, n, k);
        min_index = find_next(distance, flag, n);
        while (min_index != -1) {
            distance = update_distance(times, flag, distance, n, min_index);
            min_index = find_next(distance, flag, n);
        }
        auto it = find(flag.begin(), flag.end(), 0);
        if (it != flag.end())
            return -1;
        else
            return *max_element(distance.begin(), distance.end());
}

这是整个程序的主函数,首先我们使用find_distance函数作用于源点,初始化diatance向量。然后,我们初始下一个要找的点的标号,注意这里的index不是索引,是标号

接着我们循环更新distance直到找不到min_index。如果flag向量含有0值,说明有的点到达不了,返回1。否则,返回distance中的最小值。这样,问题就解决了。

五、官方题解代码详解(C++)

博主的代码写的太长了,水平还不到家,仅供参考,我们来看官方题解。

class Solution {
public:
    int networkDelayTime(vector<vector<int>> &times, int n, int k) {
        const int inf = INT_MAX / 2;
        vector<vector<int>> g(n, vector<int>(n, inf));
        for (auto &t : times) {
            int x = t[0] - 1, y = t[1] - 1;
            g[x][y] = t[2];
        }

        vector<int> dist(n, inf);
        dist[k - 1] = 0;
        vector<int> used(n);
        for (int i = 0; i < n; ++i) {
            int x = -1;
            for (int y = 0; y < n; ++y) {
                if (!used[y] && (x == -1 || dist[y] < dist[x])) {
                    x = y;
                }
            }
            used[x] = true;
            for (int y = 0; y < n; ++y) {
                dist[y] = min(dist[y], dist[x] + g[x][y]);
            }
        }

        int ans = *max_element(dist.begin(), dist.end());
        return ans == inf ? -1 : ans;
    }
};
作者:力扣官方题解
链接:https://leetcode.cn/problems/network-delay-time/solutions/909575/wang-luo-yan-chi-shi-jian-by-leetcode-so-6phc/
来源:力扣(LeetCode)
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

官方题解首先把times转化维一个矩阵储存。官方题解的巧妙在于for循环的第一次肯定选取是源点,相当于对源点到其他点的距离做了初始化。

for (int y = 0; y < n; ++y) {
      if (!used[y] && (x == -1 || dist[y] < dist[x])) {
      x = y;
      }
}

其中,这一段代码是寻找下一个要处理的点。

for (int y = 0; y < n; ++y) {
     dist[y] = min(dist[y], dist[x] + g[x][y]);
}

这一段代码是更新距离。

循环总共找n次,保证了源点能到达的点都能找到。

for (int y = 0; y < n; ++y) {
      if (!used[y] && (x == -1 || dist[y] < dist[x])) {
      x = y;
      }
}

其中,这一段代码是寻找下一个要处理的点。

for (int y = 0; y < n; ++y) {
     dist[y] = min(dist[y], dist[x] + g[x][y]);
}

这一段代码是更新距离。

循环总共找n次,保证了源点能到达的点都能找到。

  • 20
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值