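// Approach: contract the strongly connected components with Tarjan's algorithm, then rebuild the graph on the components.
// My own approach (the first program below) is more cumbersome than necessary: it is enough to count how many vertices lie inside the contracted component whose out-degree is 0.
// If more than one contracted component has out-degree 0, the answer is 0.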
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
using namespace std;
// Size of every per-vertex array below; indices are used from 1, so slots 1..MAXN-1 are usable.
const int MAXN = 100001;
// low[v] / num[v]: Tarjan low-link value and discovery index of vertex v (num[v] == 0 means "not visited yet").
// S: array-backed vertex stack used by tarjan().
// scc[v]: id of the strongly connected component of v (0 until v has been popped into a component).
// fa[c]: for component c, the number of other components reached by dfs(c) on g.
int low[MAXN],num[MAXN],S[MAXN],scc[MAXN],fa[MAXN];
// dfn: discovery counter; n, m: vertex and edge counts read in solve(); top: index of the top of S;
// ans: value printed by solve(); cnt: number of components found so far; tot: counter incremented by dfs().
int dfn,n,m,top,ans,cnt,tot;
// vis[c]: set by dfs() when component c has been entered; cleared by solve() before each search.
bool vis[MAXN];
// G: adjacency lists of the input graph (G[f] holds every t of an input edge f -> t).
// g: component graph built in solve(); an input edge i -> v between different components is stored as scc[v] -> scc[i].
vector<int>G[MAXN],g[MAXN];
// Tarjan's strongly-connected-components search started at vertex x.
// Reads the adjacency lists in G. Writes num[] (discovery index), low[]
// (low-link), the vertex stack S/top and, when x turns out to be the root of
// a component, a fresh id (++cnt) into scc[] for every vertex popped.
void tarjan(int x){
    S[++top] = x;                       // x stays stacked until its component is closed
    num[x] = ++dfn;
    low[x] = num[x];

    const int degree = static_cast<int>(G[x].size());
    for(int k = 0; k < degree; ++k){
        const int next = G[x][k];
        if(num[next] == 0){
            // Tree edge: explore first, then inherit the child's low-link.
            tarjan(next);
            low[x] = min(low[x], low[next]);
        }
        else if(scc[next] == 0){
            // Already discovered but not yet assigned to a component,
            // i.e. still on the stack: its discovery index bounds low[x].
            low[x] = min(low[x], num[next]);
        }
    }

    // Only the root of a component (low-link equal to its own index) pops.
    if(low[x] != num[x]) return;

    ++cnt;
    int popped;
    do{
        popped = S[top--];
        scc[popped] = cnt;
    }while(popped != x);
}
// Depth-first walk over the component graph g, starting at component x.
// Marks every component it enters in vis[] and increments tot once for each
// component first reached through an edge; the start component itself is
// marked but not counted.
void dfs(int x){
    vis[x] = true;
    const int degree = static_cast<int>(g[x].size());
    for(int k = 0; k < degree; ++k){
        const int next = g[x][k];
        if(vis[next]) continue;         // already counted on this walk
        ++tot;
        dfs(next);
    }
}
void solve(){
int f,t;
cin >> n >> m;
for(int i = 1;i <= m;i ++ ){
scanf("%d%d",&f,&t);
G[f].push_back(t);
}
for(int i = 1;i <= n;i ++){
if(!num[i])
tarjan(i);
}
for(int i = 1;i <= n;i ++){
for(int j = 0;j < G[i].size();j ++){
int v = G[i][j];
if(scc[v] != scc[i]){
g[scc[v]].push_back(scc[i]);
}
}
}
for(int i = 1;i <= n;i ++){
for(int j = 1;j <= cnt;j ++) vis[j] = 0;
tot = 0;
dfs(scc[i]);
fa[scc[i]] = tot;
}
for(int i = 1;i <= n;i ++)
if(fa[scc[i]] == cnt - 1)
ans ++;
cout << ans << endl;
return;
}
// Program entry point: runs the solver once and reports success to the caller.
int main()
{
    solve();
    return 0;
}
// Improved version
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
using namespace std;
// Size of every per-vertex array below; indices are used from 1, so slots 1..MAXN-1 are usable.
const int MAXN = 100001;
// low[v] / num[v]: Tarjan low-link value and discovery index of vertex v (num[v] == 0 means "not visited yet").
// S: array-backed vertex stack used by tarjan().
// scc[v]: id of the strongly connected component of v (0 until v has been popped into a component).
// chu[c]: number of input edges that leave component c for a different component (its out-degree, counted with multiplicity).
// fa: not referenced by the functions visible below in this version.
int low[MAXN],num[MAXN],S[MAXN],scc[MAXN],fa[MAXN],chu[MAXN];
// dfn: discovery counter; n, m: vertex and edge counts read in solve(); top: index of the top of S;
// ans: value printed by solve(); cnt: number of components found so far.
// tot: not referenced by the functions visible below in this version.
int dfn,n,m,top,ans,cnt,tot;
// G: adjacency lists of the input graph (G[f] holds every t of an input edge f -> t).
vector<int>G[MAXN];
// Tarjan's strongly-connected-components search started at vertex x.
// Reads the adjacency lists in G. Writes num[] (discovery index), low[]
// (low-link), the vertex stack S/top and, when x turns out to be the root of
// a component, a fresh id (++cnt) into scc[] for every vertex popped.
void tarjan(int x){
    S[++top] = x;                       // x stays stacked until its component is closed
    num[x] = ++dfn;
    low[x] = num[x];

    const int degree = static_cast<int>(G[x].size());
    for(int k = 0; k < degree; ++k){
        const int next = G[x][k];
        if(num[next] == 0){
            // Tree edge: explore first, then inherit the child's low-link.
            tarjan(next);
            low[x] = min(low[x], low[next]);
        }
        else if(scc[next] == 0){
            // Already discovered but not yet assigned to a component,
            // i.e. still on the stack: its discovery index bounds low[x].
            low[x] = min(low[x], num[next]);
        }
    }

    // Only the root of a component (low-link equal to its own index) pops.
    if(low[x] != num[x]) return;

    ++cnt;
    int popped;
    do{
        popped = S[top--];
        scc[popped] = cnt;
    }while(popped != x);
}
void solve(){
int f,t;
cin >> n >> m;
for(int i = 1;i <= m;i ++ ){
scanf("%d%d",&f,&t);
G[f].push_back(t);
}
for(int i = 1;i <= n;i ++){
if(!num[i])
tarjan(i);
}
for(int i = 1;i <= n;i ++){
for(int j = 0;j < G[i].size();j ++){
int v = G[i][j];
if(scc[v] != scc[i]){
chu[scc[i]] ++;
}
}
}
int pos = 0,ccnt = 0;
for(int i = 1;i <= cnt;i ++){
if(!chu[i])
pos = i,ccnt ++;
}
if(!pos || ccnt >= 2) cout <<"0";
else{
for(int i = 1;i <= n;i ++){
if(scc[i] == pos) ans ++;
}
cout << ans <<endl;
}
return;
}
// Program entry point: runs the solver once and reports success to the caller.
int main()
{
    solve();
    return 0;
}