Description
给出一张DAG G=<V,E> |V|=n,|E|=m
定义一个点x为important的当且仅当对于每个点y!=x,满足x能到达y,或者y能到达x
定义一个点x为semi-important的当且仅当删去一个点后x为important的
求important的和semi-important的点数
n,m<=300000
Solution
显然我们需要对每个点求出in[x]和out[x]表示能到x的点数和x能到的点数
但是直接求显然是不行的,我们需要一点方法
找出G中的最长路P,我们可以证明所有的important点都在P上
考虑反证:设u是important的却不在P上,那么P会分成能到u和u能到两部分,设分界点为pi
那么我们可以找到p1->p2->…->pi->u->pi+1->…->pk这一条路,P就不是最长路,与假设矛盾
对一条链上的点求in和out是很简单的
我们来考虑semi-important点,同样我们可以证明,如果点u是semi-important的,那么一定有一条p1->p2->…->pi-1->u->pi+1->…->pk的路径
且对于每个i这样的u只有一个
那么对于这样的点的in和out也很好求
复杂度O(n+m)
Code
#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define rep_0(i,a) for(int i=0;i<to[a].size();i++)
#define rep_1(i,a) for(int i=0;i<from[a].size();i++)
using namespace std;
int read() {
char ch;
for(ch=getchar();ch<'0'||ch>'9';ch=getchar());
int x=ch-'0';
for(ch=getchar();ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
return x;
}
const int N=3e5+5;
int n,m,f[N],q[N],nxt[N],deg[N],p[N],in[N],out[N],sta[N],semi[N],top;
bool vis[N];
vector<int> to[N],from[N];
#define pb(a) push_back(a)
int find_0(int x) {
if (vis[x]) return 0;
int cnt=0;vis[x]=1;sta[++top]=x;
rep_0(i,x) cnt+=find_0(to[x][i]);
return cnt+1;
}
int find_1(int x) {
if (vis[x]) return 0;
int cnt=0;vis[x]=1;sta[++top]=x;
rep_1(i,x) cnt+=find_1(from[x][i]);
return cnt+1;
}
int main() {
n=read();m=read();
if (n==1) {
puts("1");
return 0;
}
fo(i,1,m) {
int x=read(),y=read();deg[y]++;
to[x].pb(y);from[y].pb(x);
}
int hd=0,tl=0;
fo(i,1,n) if (!deg[i]) q[++tl]=i;
while (hd<tl) {
int x=q[++hd];
rep_0(i,x) {
int y=to[x][i];
if (!(--deg[y])) q[++tl]=y;
}
}
fd(i,n,1) {
int x=q[i];
rep_0(i,x) {
int y=to[x][i];
if (f[y]+1>f[x]) f[x]=f[y]+1,nxt[x]=y;
}
}
int mx=0,id=0;
fo(i,1,n) if (f[i]>mx) mx=f[i],id=i;
int x=p[p[0]=1]=id;while (nxt[x]) p[++p[0]]=x=nxt[x];
int cnt=0;
rep_1(j,p[2]) {
int x=from[p[2]][j];
if (x!=p[1]) semi[1]=x,cnt++;
}
if (cnt>1) semi[1]=0;
cnt=0;
rep_0(j,p[p[0]-1]) {
int x=to[p[p[0]-1]][j];
if(x!=p[p[0]]) semi[p[0]]=x,cnt++;
}
if (cnt>1) semi[p[0]]=0;
fo(i,2,p[0]-1) {
rep_0(j,p[i-1]) vis[to[p[i-1]][j]]=1;
int cnt=0;
rep_1(j,p[i+1]) {
int x=from[p[i+1]][j];
if (x!=p[i]&&vis[x]) cnt++,semi[i]=x;
}
if (cnt>1) semi[i]=0;
rep_0(j,p[i-1]) vis[to[p[i-1]][j]]=0;
}
memset(vis,0,sizeof(vis));top=0;
fd(i,p[0],1) {
if (semi[i]) {
out[semi[i]]=out[p[i+1]];
int now=top;out[semi[i]]+=find_0(semi[i]);
while (top>now) vis[sta[top--]]=0;
}
out[p[i]]=out[p[i+1]]+find_0(p[i]);
}
memset(vis,0,sizeof(vis));top=0;
fo(i,1,p[0]) {
if (semi[i]) {
in[semi[i]]=in[p[i-1]];
int now=top;in[semi[i]]+=find_1(semi[i]);
while (top>now) vis[sta[top--]]=0;
}
in[p[i]]=in[p[i-1]]+find_1(p[i]);
}
int ans=0;
fo(i,1,p[0]) {
if (in[p[i]]+out[p[i]]>=n) ans++;
if (in[semi[i]]+out[semi[i]]>=n) ans++;
}
printf("%d\n",ans);
return 0;
}