题目大意
正确的floyd写法
for k from 1 to n
for i from 1 to n
for j from 1 to n
dis[i][j] <- min(dis[i][j], dis[i][k] + dis[k][j])
错误的floyd写法
for i from 1 to n
for j from 1 to n
for k from 1 to n
dis[i][j] <- min(dis[i][j], dis[i][k] + dis[k][j])
给一个图,求跑错误写法后依旧正确的dist[i][j]的数量
有两种情况是正确的
1:i和j之间存在一条边边权w,w等于用正确做法跑出来的最短路答案(他们之间的一个边就是最短路,dist[i][j]==w)
2:存在一个点Z,且dis[X][Z]和dis[Z][Y]都满足情况1,且Z位于X到Y的一个最短路上
做法:
用dist[N][N]记录正确的跑法,用g[N][N]记录错误的跑法
1 记录好所有数据
2 对所有点分别跑一遍dijkstra,设当前跑的点是P,计算该点到其他所有点的最短路,跑完后获得数组dist[P][N] ,然后从1到n遍历,当dist[p][i]==g[p][i]
时,这两个点就满足上文说的第一种情况
3 将下标为P的邻接表清空,随后在下标为P的邻接表中放入步骤2中的那些满足条件dist[p][i]==g[p][i]
的点及最短路,为判断第二种情况做铺垫。
4 再对所有点跑一边dijkstra.此时;邻接表所储存的边都是满足情况1的边,由这些边跑出来的结果就满足条件2
判断dist数组和g数组中有多少点相同,就是答案
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef vector<int> vi;
typedef pair<int, int >pii;
typedef vector<pii>vpi;
#define mp make_pair
#define pb push_back
const int maxn = 2020;
const int inf = 0x3f3f3f3f;
priority_queue<pii, vector<pii>, greater<pii>>q; //优先队列,小的放前面,用来算dijkstra
int n, m, ans;
int g[maxn][maxn];//记录用floyd跑出的错误答案
int dis[maxn][maxn];
vector<pii> e[maxn];//邻接表
void dij(int p) {//正确floyd的答案
for (int i = 1; i <= n; i++)
dis[p][i] = inf;
dis[p][p] = 0;
for (int i = 0; i < e[p].size(); i++)
q.push(e[p][i]);//与p最近的点在top
while (q.size()) {
pii t = q.top();
q.pop();
int u = t.second;//first 是距离 second是目标点
if (dis[p][u] != inf)
continue;
dis[p][u] = t.first;
for (int i = 0; i < e[u].size(); i++) {
int v = e[u][i].second;
if (dis[p][v] != inf)
continue;
q.push({e[u][i].first + dis[p][u], v});
}
}
e[p].clear();//清除掉
for (int i = 1; i <= n; i++) {
//p和i之间有一条直接相连的路,而且该路是最短的
if (dis[p][i] == g[p][i] && dis[p][i] < inf && p != i)
e[p].push_back({dis[p][i], i});
}
}
void solve(int p) {//错误的floyd答案
for (int i = 0; i < e[p].size(); i++)
q.push({0, e[p][i].second});
//first k的值 second 目标点的编号
while (q.size()) {
pii t = q.top();
q.pop();
int id = t.first;
int u = t.second;
for (int i = 0; i < e[u].size(); i++) {
int v = e[u][i].second;//编号
int w = e[u][i].first;//距离
if (v < id)//如果目标点编号小于k的编号(只能往后更新,不能往前更新)
continue;
//正确答案等于错误答案 或者两点不连通
if (dis[p][v] == g[p][v] || dis[p][v] >= inf)
continue;
//如果正确答案按错误方法做也正确
if (dis[p][v] == dis[p][u] + w) {
q.push({v, v});
e[p].push_back({dis[p][v], v});//p点与v点之间就存在了这么一个正确的边
g[p][v] = dis[p][v];
}
}
}
}
int main() {
cin >> n >> m;
memset(g, 0x3f, sizeof g);
for (int i = 1; i <= m; i++) {
int u, v, w;
cin >> u >> v >> w;
e[u].push_back({w, v});
g[u][v] = w;
}
for (int i = 1; i <= n; i++)
g[i][i] = 0;
for (int i = 1; i <= n; i++)
dij(i);//对每个点求最近距离
for (int i = 1; i <= n; i++)
solve(i);
ans = 0;
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
if (g[i][j] == dis[i][j] || g[i][j] >= inf && dis[i][j] >= inf) {
ans++;
}
}
}
cout << ans << endl;
}