[题解] 最短路计数 Dijkstra
已知一个有向图,
n
n
n个点
m
m
m条边,现在要询问从
1
1
1到
n
n
n的最短路一共有几条。
不得不说,这道题让我加深了对Dijkstra算法的理解。
我们可以直接在原板子上稍作改进。
刚开始的时候想到的是以拓扑序排序,但是其实对于最短路问题来说,Dijkstra算法更新的顺序就是“拓扑序”。因为我们每次先更新的点
u
u
u的最短路
d
[
u
]
d[u]
d[u]都比后面更新的点
v
v
v的最短路
d
[
v
]
d[v]
d[v]更小。
首先,我们知道Dij算法维护的是已经确定最短路的点的集合。
那么,当我们每次从这个集合之外取出离集合最近的点
t
t
t之后,我们都会用这个点去更新周围的点,并把成功更新的点入队。
那么对于这个点
t
t
t周围的点
v
v
v,我们只讨论当他作为点的前驱时对答案产生的贡献。
设从点
1
1
1到点
t
t
t的最短路条数为
c
n
t
[
t
]
cnt[t]
cnt[t]。
如果
d
[
v
]
>
d
[
t
]
+
w
[
t
]
[
i
]
d[v] > d[t] + w[t][i]
d[v]>d[t]+w[t][i],那么说明之前
t
t
t保存的最短路是错误的,我们用
c
n
t
[
v
]
cnt[v]
cnt[v]来更新它。
否则如果
d
[
v
]
=
=
d
[
t
]
+
w
[
t
]
[
i
]
d[v] == d[t] + w[t][i]
d[v]==d[t]+w[t][i],说明此时
v
v
v多了一条最短路。
这个操作会不会重复更新节点呢?答案是不会。因为对于每个前驱节点
u
u
u,我们只会对每个后继节点
v
v
v更新一次。
再次强调,这里的“前驱”和“后继”指前面的最短路比后面的小。
这样不断更新,最后得到的值一定是最短路。思想就类似数学归纳法,这里就不给出严格证明了。
注意由于这是一道和“路径个数”有关的问题,所以我们必须去重。
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
using namespace std;
const double eps = 1e-10;
const double pi = acos(-1.0);
const int maxn = 2010;
int n,m;
vector<int> g[maxn],w[maxn];
bool vis[maxn];
int app[2010][2010];
int d[maxn],cnt[maxn];
priority_queue<pii,vector<pii>,greater<pii>> q;
/*
思路:
我们在dijkstra的模板上稍加改进。
每次用已经确定最短路的点u更新周围的点v的时候,
如果发现d[v] > d[u] + w[u][i]时,那么之前d[v]记录的一定不是最短路,
现在我们只需要把它覆盖掉就彳亍了。
cnt[v] = cnt[u]
那么,如果我们发现d[v] = d[u] + w[u][i]的时候,
我们就发现了一条现在看来是最短路的另一条路,
cnt[v] += cnt[u]。
!!注意!对于一般的图,我们直接用邻接表即可。但是对于路径统计问题,
我们必须要除掉自环和重边。(否则路径会变多,比如1 2 1 ;1 2 1)
*/
void solve(){
scanf("%d%d",&n,&m);
while(m--){
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
if(!app[a][b] || app[a][b] > c){
app[a][b] = c;
g[a].push_back(b);
w[a].push_back(c);
}
}
memset(d,0x3f,sizeof(d));
d[1] = 0;
cnt[1] = 1;
q.push({0,1});
while(!q.empty()){
int u = q.top().second;
q.pop();
if(vis[u]) continue;
vis[u] = 1;
for(int i = 0; i < g[u].size(); i++){
int v = g[u][i];
if(d[v] == d[u] + w[u][i]){
cnt[v] += cnt[u];
}
else if(d[v] > d[u] + w[u][i]){
d[v] = d[u] + w[u][i];
cnt[v] = cnt[u];
if(!vis[v])q.push({d[v],v});
}
}
}
if(d[n] >= 0x3f3f3f3f) puts("No answer");
else printf("%d %d\n",d[n],cnt[n]);
}
int main()
{
solve();
return 0;
}