可达性统计
题目描述
核心思路
设从点
x
x
x出发能到达的所有点的集合为f(x)
,设
x
x
x连接的所有边为
y
1
,
y
2
,
⋯
,
y
k
y_1,y_2,\cdots,y_k
y1,y2,⋯,yk,则有如下式子:
f ( x ) = x f(x)=x f(x)=x ∪ \cup ∪ ( ⋃ 1 ≤ i ≤ k f ( y i ) ) (\bigcup \limits _{1\leq i\leq k}f(y_i)) (1≤i≤k⋃f(yi))
即:从点 x x x出发能到达的所有点就是从它连接的所有点出发能到达的所有点加上 x x x本身
考虑拓扑排序。在一张图的拓扑序中,对每条边 ( x , y ) (x,y) (x,y)总有: x x x在 y y y之前。根据上述的这个式子,我们发现要想求出 f ( x ) f(x) f(x),必须要先求出 f ( y i ) f(y_i) f(yi),而由于 y i y_i yi是 x x x之后的节点,因此我们可以按照拓扑排序的倒序来进行计算。
我们考虑将每个点 f ( x ) f(x) f(x)记作一个 N N N位二进制数,其中第 i i i位为1,表示可以到达 i i i;第 i i i位为0,表示不可以到达 i i i。那么 f ( x ) f(x) f(x)中1的数量就是从 x x x出发能到达的所有点的数量。
我们来考虑一下数据规模,在最坏情况下,拓扑排序是一条链,对于第一个节点, f ( 1 ) = n f(1)=n f(1)=n,对于第二个节点, f ( 2 ) = n − 1 f(2)=n-1 f(2)=n−1, ⋯ \cdots ⋯,对于第 n n n个节点, f ( n ) = 1 f(n)=1 f(n)=1。那么总数为 f ( 1 ) + f ( 2 ) + ⋯ + f ( n ) = n + ( n − 1 ) + ⋯ + 1 = n ( n − 1 ) 2 f(1)+f(2)+\cdots+f(n)=n+(n-1)+\cdots+1=\dfrac {n(n-1)}{2} f(1)+f(2)+⋯+f(n)=n+(n−1)+⋯+1=2n(n−1), n n n最大取到 30000 30000 30000,所以最坏的总数约为 4.5 4.5 4.5亿,如果用二维数组来存储的话,那么内存空间就会爆炸。因此需要把 N N N位的二进制数压缩到一个int中,这需要借助STL中的bitset。这样空间复杂度将减少为原来的 1 32 \dfrac {1}{32} 321,此时变为 14 , 062 , 500 14,062,500 14,062,500是可以接受的。
bitset<N>f
这里的
f
f
f其实是一维数组;
bitset<N>f[N]
,这里的
f
f
f其实是二维数组
代码
#include<iostream>
#include<cstring>
#include<bitset>
#include<algorithm>
using namespace std;
const int N=30010,M=30010;
int n,m;
int h[N],e[M],ne[M],idx;
int d[N]; //记录每个点的入队
int q[N]; //记录拓扑序列
bitset<N>f[N];
//从点a向点b连一条有向边
void add(int a,int b)
{
e[idx]=b;
ne[idx]=h[a];
h[a]=idx++;
}
void topsort()
{
int hh=0,tt=-1;
for(int i=1;i<=n;i++)
{
if(!d[i])
{
q[++tt]=i;
}
}
while(hh<=tt)
{
int t=q[hh++];
for(int i=h[t];~i;i=ne[i])
{
int j=e[i];
if(--d[j]==0)
q[++tt]=j;
}
}
}
int main()
{
memset(h,-1,sizeof h);
scanf("%d%d",&n,&m);
while(m--)
{
int a,b;
scanf("%d%d",&a,&b);
add(a,b);
d[b]++;
}
//求出拓扑序列
topsort();
//对拓扑序列进行倒序遍历
for(int i=n-1;i>=0;i--)
{
int x=q[i];
//从x这个点出发可以到达x自身 因此f[x][x]=1
f[x][x]=1;
//遍历x的所有邻接点
//f[x]|=f(yi)
for(int j=h[x];~j;j=ne[j])
{
int y=e[j]; //取出x邻接点的编号y
f[x]|=f[y];
}
}
//f[i].count()表示f[i]中1的个数
//其实也就是从i出发能到达的所有点的数量
for(int i=1;i<=n;i++)
printf("%d\n",f[i].count());
return 0;
}