虽然校赛打的超级烂,但是还是因此学了树形dp和树状数组的。
记录蓝色节点和红色节点的数目。然后从根节点dfs,如果子树的蓝色==总蓝色并且红色==0,那么cnt++;如果子树的红色==总红色并且红色==0,那么cnt++。
就这两种情况。每次dfs之后要更新当前节点的红色蓝色数。超级暴力,不知道为啥这么大的数据还可以过。
1 #include <iostream> 2 #include <cstring> 3 #include <string> 4 #include <map> 5 #include <set> 6 #include <algorithm> 7 #include <fstream> 8 #include <cstdio> 9 #include <cmath> 10 #include <stack> 11 #include <queue> 12 using namespace std; 13 const double Pi=3.14159265358979323846; 14 typedef long long ll; 15 const int MAXN=300000+5; 16 const int dx[5]={0,0,0,1,-1}; 17 const int dy[5]={1,-1,0,0,0}; 18 const int INF = 0x3f3f3f3f; 19 const int NINF = 0xc0c0c0c0; 20 const ll mod=1e9+7; 21 int c[MAXN]; 22 int red,blue; 23 vector<int> G[MAXN]; 24 int vis[MAXN]; 25 struct node{ 26 int red,blue; 27 }dp[MAXN]; 28 int cnt=0; 29 void dfs(int v) 30 { 31 vis[v]=1; 32 int to; 33 if(c[v]==1) dp[v].red++; 34 else if(c[v]==2) dp[v].blue++; 35 for(int i=0;i<G[v].size();i++) 36 { 37 to=G[v][i];if(vis[to]) continue; 38 dfs(to); 39 dp[v].blue+=dp[to].blue; 40 dp[v].red+=dp[to].red; 41 if(dp[to].blue==blue&&dp[to].red==0) cnt++; 42 else if(dp[to].red==red&&dp[to].blue==0) cnt++; 43 } 44 } 45 int main() 46 { 47 int n;cin>>n; 48 for(int i=1;i<=n;i++) 49 { 50 cin>>c[i]; 51 if(c[i]==1) red++; 52 else if(c[i]==2) blue++; 53 } 54 for(int i=1;i<=n-1;i++) 55 { 56 int a,b;scanf("%d%d",&a,&b); 57 G[a].push_back(b); 58 G[b].push_back(a); 59 } 60 dfs(1); 61 cout <<cnt<<endl; 62 63 }