代码如下:
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
const int maxn=50000+5;
int n,m,ans,hd[maxn],to[maxn],nxt[maxn],tot;
int mp[maxn],dfn[maxn],low[maxn],cnt,C;//Tarjan,mp[]&C缩点
int stk[maxn],tail=0;
bool instk[maxn];
int du[maxn];
inline int read(){
char ch=getchar();
int f=1,x=0;
while(!(ch>='0'&&ch<='9')){
if(ch=='-')
f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
x=x*10+ch-'0',ch=getchar();
return f*x;
}
inline void add(int x,int y){
to[tot]=y;
nxt[tot]=hd[x];
hd[x]=tot++;
}
void tarjan(int root){
dfn[root]=low[root]=++cnt;
stk[++tail]=root;//root入栈
instk[root]=1;
for(int i=hd[root];i!=-1;i=nxt[i]){
if(dfn[to[i]]==0)
tarjan(to[i]),low[root]=min(low[root],low[to[i]]);
else if(instk[to[i]])
low[root]=min(low[root],dfn[to[i]]);
}
if(low[root]==dfn[root]){
C++;
int tmp;
do{
tmp=stk[tail--];
instk[tmp]=0;
mp[tmp]=C;
}while(tmp!=root);
}
}
signed main(void){
n=read(),m=read();
memset(hd,-1,sizeof(hd)),tot=0;
for(int i=1,x,y;i<=m;i++)
x=read(),y=read(),add(x,y);
for(int i=1;i<=n;i++)
if(dfn[i]==0)
tarjan(i);//如果未访问过,dfs
for(int i=1;i<=n;i++)
for(int j=hd[i];j!=-1;j=nxt[j])//遍历临接表
if(mp[i]!=mp[to[j]])
du[mp[i]]++;
int sum=0,tmp=0;
for(int i=1;i<=C;i++)
if(du[i]==0)
sum++,tmp=i;
if(sum==1)
for(int i=1;i<=n;i++)
ans+=mp[i]==tmp;
cout<<ans<<endl;
return 0;
}