BFS - 最短路计数 - 洛谷 P1144
给出一个 N 个顶点 M 条边的无向无权图,顶点编号为 1 到 N。
问从顶点 1 开始,到其他每个点的最短路有几条。
输入格式
第一行包含 2 个正整数 N,M,为图的顶点数与边数。
接下来 M 行,每行两个正整数 x,y,表示有一条顶点 x 连向顶点 y 的边,请注意可能有自环与重边。
输出格式
输出 N 行,每行一个非负整数,第 i 行输出从顶点 1 到顶点 i 有多少条不同的最短路,由于答案有可能会很大,你只需要输出对 100003 取模后的结果即可。
如果无法到达顶点 i 则输出 0。
数据范围
1 ≤ N ≤ 1 0 5 , 1 ≤ M ≤ 2 × 1 0 5 1≤N≤10^5, 1≤M≤2×10^5 1≤N≤105,1≤M≤2×105
输入样例:
5 7
1 2
1 3
2 4
3 4
2 3
4 5
4 5
输出样例:
1
1
1
2
4
分析:
单 源 最 短 路 , 要 统 计 源 点 到 其 它 所 有 点 的 最 短 路 径 的 条 数 。 单源最短路,要统计源点到其它所有点的最短路径的条数。 单源最短路,要统计源点到其它所有点的最短路径的条数。
距 离 数 组 d i s [ i ] 表 示 源 点 到 i 的 最 短 距 离 , 计 数 数 组 c n t [ i ] 表 示 从 源 点 到 i 点 的 最 短 路 径 条 数 。 距离数组dis[i]表示源点到i的最短距离,计数数组cnt[i]表示从源点到i点的最短路径条数。 距离数组dis[i]表示源点到i的最短距离,计数数组cnt[i]表示从源点到i点的最短路径条数。
因 为 是 无 向 无 权 图 , 可 以 用 b f s 来 求 最 短 路 径 , 时 间 复 杂 度 是 线 性 的 。 因为是无向无权图,可以用bfs来求最短路径,时间复杂度是线性的。 因为是无向无权图,可以用bfs来求最短路径,时间复杂度是线性的。
记 当 前 点 的 编 号 为 t , 与 之 相 连 通 的 点 是 j , 分 两 种 情 况 : 记当前点的编号为t,与之相连通的点是j,分两种情况: 记当前点的编号为t,与之相连通的点是j,分两种情况:
① 、 经 过 点 t 再 到 j 的 距 离 要 更 短 : d i s [ j ] > d i s [ t ] + 1 , 则 ①、经过点t再到j的距离要更短:dis[j]>dis[t]+1,则 ①、经过点t再到j的距离要更短:dis[j]>dis[t]+1,则
更 新 到 j 的 最 短 路 径 : d i s [ j ] = d i s [ t ] + 1 , c n t [ j ] = c n t [ t ] 。 \qquad 更新到j的最短路径:dis[j]=dis[t]+1,cnt[j]=cnt[t]。 更新到j的最短路径:dis[j]=dis[t]+1,cnt[j]=cnt[t]。
② 、 经 过 点 t 再 到 j 的 距 离 与 经 过 其 他 点 到 j 的 最 短 距 离 相 同 : d i s [ j ] = d i s [ t ] + 1 , 则 ②、经过点t再到j的距离与经过其他点到j的最短距离相同:dis[j]=dis[t]+1,则 ②、经过点t再到j的距离与经过其他点到j的最短距离相同:dis[j]=dis[t]+1,则
将 经 过 点 t 再 到 j 的 路 径 条 数 累 加 到 经 过 其 他 点 到 j 的 最 短 路 径 上 : c n t [ j ] + = c n t [ t ] 。 \qquad 将经过点t再到j的路径条数累加到经过其他点到j的最短路径上:cnt[j]+=cnt[t]。 将经过点t再到j的路径条数累加到经过其他点到j的最短路径上:cnt[j]+=cnt[t]。
注意:
无 向 边 注 意 数 组 要 开 两 倍 。 无向边注意数组要开两倍。 无向边注意数组要开两倍。
代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
using namespace std;
const int N=1e5+10, M=4e5+10, mod=100003;
int n,m;
int e[M],ne[M],h[N],idx;
int dis[N],cnt[N];
void add(int a,int b)
{
e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
void bfs()
{
queue<int> Q;
memset(dis,0x3f,sizeof dis);
dis[1]=0;
cnt[1]=1;
Q.push(1);
while(Q.size())
{
int t=Q.front();
Q.pop();
for(int i=h[t];~i;i=ne[i])
{
int j=e[i];
if(dis[j]>dis[t]+1)
{
dis[j]=dis[t]+1;
cnt[j]=cnt[t];
Q.push(j);
}
else if(dis[j]==dis[t]+1)
{
cnt[j]=(cnt[j]+cnt[t])%mod;
}
}
}
}
int main()
{
scanf("%d%d",&n,&m);
memset(h,-1,sizeof h);
int a,b;
while(m--)
{
scanf("%d%d",&a,&b);
add(a,b),add(b,a);
}
bfs();
for(int i=1;i<=n;i++) printf("%d\n",cnt[i]);
return 0;
}