You are given an undirected tree of n vertices.
Some vertices are colored blue, some are colored red and some are uncolored. It is guaranteed that the tree contains at least one red vertex and at least one blue vertex.
You choose an edge and remove it from the tree. The tree falls apart into two connected components. Let's call an edge nice if neither of the resulting components contains vertices of both red and blue colors.
How many nice edges are there in the given tree?
Input
The first line contains a single integer n (2≤n≤3⋅10^5), the number of vertices in the tree.
The second line contains n integers a_1,a_2,…,a_n (0≤a_i≤2), the colors of the vertices. a_i=1 means that vertex i is colored red, a_i=2 means that vertex i is colored blue and a_i=0 means that vertex i is uncolored.
The i-th of the next n−1 lines contains two integers v_i and u_i (1≤v_i,u_i≤n, v_i≠u_i), the edges of the tree. It is guaranteed that the given edges form a tree. It is guaranteed that the tree contains at least one red vertex and at least one blue vertex.
Output
Print a single integer — the number of nice edges in the given tree.
Examples
Input
5
2 0 0 1 2
1 2
2 3
2 4
2 5
Output
1
Input
5
1 0 0 0 2
1 2
2 3
3 4
4 5
Output
4
Input
3
1 1 2
2 3
1 3
Output
0
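Problem summary
In short: given an undirected tree in which each vertex is colored blue, colored red, or left uncolored, remove one edge and check that neither of the two resulting components contains both red and blue vertices. Because the tree has at least one vertex of each color, this is the same as requiring that one component holds all the red vertices and the other holds all the blue vertices. Count the edges for which this holds.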
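Approach
Run a depth-first search from vertex 1 and, for every vertex, count the red and blue vertices in its subtree. If a subtree contains every vertex of one color and none of the other, the edge joining it to its parent is nice, so ans++. The root never passes this test, because the whole tree contains at least one vertex of each color.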
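The test applied to each subtree can be written as a small standalone predicate. This is only a sketch: the name isNiceEdge and its parameter names are made up for illustration and do not appear in the solution. The four static_assert lines use the subtree counts from the first example (one red and two blue vertices in total, rooted at vertex 1).

```cpp
// Hypothetical helper, not part of the original solution.
// subRed / subBlue     : red and blue counts inside the subtree below the edge.
// totalRed / totalBlue : red and blue counts over the whole tree.
constexpr bool isNiceEdge(int subRed, int subBlue, int totalRed, int totalBlue) {
    // Either the subtree holds every blue vertex and no red one,
    // or it holds every red vertex and no blue one.
    return (subRed == 0 && subBlue == totalBlue) ||
           (subBlue == 0 && subRed == totalRed);
}

// First example: colors 2 0 0 1 2, every other vertex attached to vertex 2.
static_assert(isNiceEdge(1, 0, 1, 2), "subtree of vertex 4: the only red vertex");
static_assert(!isNiceEdge(0, 1, 1, 2), "subtree of vertex 5: one of two blue vertices");
static_assert(!isNiceEdge(0, 0, 1, 2), "subtree of vertex 3: uncolored");
static_assert(!isNiceEdge(1, 1, 1, 2), "subtree of vertex 2: both colors");
```

Only the edge above vertex 4 passes, which agrees with the expected answer of 1.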
#include<cstdio>
#include<vector>
#include<algorithm>
#include<cstring>
#define maxn 300300
using namespace std;
vector<int>ve[maxn];          // adjacency lists
int a[maxn],ans,red,blue;     // vertex colors, answer, total red and blue counts
int nred[maxn],nblue[maxn];   // red and blue counts inside each vertex's subtree
void dfs(int u,int pre)
{
    if(a[u]==1)nred[u]++;
    if(a[u]==2)nblue[u]++;
    for(int i=0;i<(int)ve[u].size();i++){
        int v=ve[u][i];
        if(v==pre)continue;// the tree is undirected, so skip the parent to avoid walking back up
        dfs(v,u);
        nred[u]+=nred[v];
        nblue[u]+=nblue[v];
    }
    // the edge (pre,u) is nice if u's subtree holds all of one color and none of the other;
    // the root never qualifies because both totals are at least 1
    if(nred[u]==0&&nblue[u]==blue)ans++;
    else if(nred[u]==red&&nblue[u]==0)ans++;
}
int main()
{
    int n;
    scanf("%d",&n);
    memset(nred,0,sizeof(nred));
    memset(nblue,0,sizeof(nblue));
    red=blue=0;
    for(int i=1;i<=n;i++){
        scanf("%d",&a[i]);
        if(a[i]==1)red++;   // total number of red vertices
        if(a[i]==2)blue++;  // total number of blue vertices
    }
    for(int i=1;i<n;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        ve[u].push_back(v);
        ve[v].push_back(u);
    }
    ans=0;
    dfs(1,-1);
    printf("%d\n",ans);
    return 0;
}
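On a path-shaped tree the recursion above goes n levels deep, which with n up to 3⋅10^5 may exceed the stack limit in some environments. The following is a sketch of the same counting done with an explicit stack instead of recursion. The function countNiceEdges and its signature are assumptions made for illustration; they are not part of the solution above. It takes the colors as a 1-indexed vector (index 0 unused) and the edges as pairs of 1-based endpoints.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: counts nice edges without recursion.
int countNiceEdges(int n, const std::vector<int>& colors,
                   const std::vector<std::pair<int, int>>& edges) {
    assert(static_cast<int>(colors.size()) == n + 1);  // colors[0] is unused
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    int totalRed = 0, totalBlue = 0;
    std::vector<int> red(n + 1, 0), blue(n + 1, 0), parent(n + 1, 0);
    for (int v = 1; v <= n; ++v) {
        red[v] = (colors[v] == 1);
        blue[v] = (colors[v] == 2);
        totalRed += red[v];
        totalBlue += blue[v];
    }
    // Visit vertices from the root with an explicit stack; in the recorded
    // order every parent appears before all of its descendants.
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> stack{1};
    while (!stack.empty()) {
        int u = stack.back();
        stack.pop_back();
        order.push_back(u);
        for (int v : adj[u]) {
            if (v != parent[u]) {  // skip the edge back to the parent
                parent[v] = u;
                stack.push_back(v);
            }
        }
    }
    int nice = 0;
    // Walk the order backwards so each subtree is complete before it is tested.
    // order[0] is the root, which has no parent edge, so stop at index 1.
    for (int i = n - 1; i >= 1; --i) {
        int u = order[i];
        if ((red[u] == 0 && blue[u] == totalBlue) ||
            (blue[u] == 0 && red[u] == totalRed)) {
            ++nice;
        }
        red[parent[u]] += red[u];
        blue[parent[u]] += blue[u];
    }
    return nice;
}
```

Called on the three examples from the statement, this sketch should return 1, 4 and 0 respectively.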