Bobo has a tree with n vertices numbered by 1,2,…,n and (n-1) edges. The i-th vertex has color c i, and the i-th edge connects vertices a i and b i.
Let C(x,y) denotes the set of colors in subtree rooted at vertex x deleting edge (x,y).
Bobo would like to know R_i which is the size of intersection of C(a i,b i) and C(b i,a i) for all 1≤i≤(n-1). (i.e. |C(a i,b i)∩C(b i,a i)|)
Input
The input contains at most 15 sets. For each set:
The first line contains an integer n (2≤n≤10 5).
The second line contains n integers c 1,c 2,…,c n (1≤c_i≤n).
The i-th of the last (n-1) lines contains 2 integers a i,b i (1≤a i,b i≤n).
Output
For each set, (n-1) integers R 1,R 2,…,R n-1.
Sample Input
4
1 2 2 1
1 2
2 3
3 4
5
1 1 2 1 2
1 3
2 3
3 5
4 5
Sample Output
1
2
1
1
1
2
1
虽然看出来需要进行树启发式合并,但是同时需要处理两个树让我感到迷惑。看了一下题解才发现,如果用一个map记录整体每种颜色有多少节点,那么我们遍历子树就可以得到剩下的子树的信息,从而判断是否两个子树都含有相同的颜色,是否将颜色加入答案。
判断有两种:
- 如果当前树中原本没有这种颜色,并且加入这个子树以后当前树中没有含有这种颜色的所有节点,则将这种颜色加入到答案中。
- 如果当前树中原本含有这种颜色,并且加入这个子树以后当前树中含有这种颜色的所有节点,则将这个颜色从答案中去掉。
AC代码:
参考的了大佬的博客(传送门),但也有我自己的一些理解:
#include<iostream>
#include<cstring>
#include<cstdio>
#include<climits>
#include<algorithm>
#include<ctime>
#include<cstdlib>
#include<queue>
#include<set>
#include<map>
#include<cmath>
using namespace std;
const int MAXN=1e5+5;
map<int,int> All,H[MAXN];
map<int,int>::iterator it;
vector<pair<int,int> > E[MAXN];
int n,c[MAXN];
int ans[MAXN],sum[MAXN];
void Merge(int x,int y,int id1,int id2)
{
if(H[x].size()<H[y].size())
{
swap(H[x],H[y]);
sum[id1]=sum[id2];
}
for(it=H[y].begin();it!=H[y].end();++it)
{
int color=it->first; int num=it->second;
if(!H[x].count(color))//不含有这种颜色
{
if(num<All[color]) ++sum[id1];
H[x][color]+=num;
}
else
{
H[x][color]+=num;
if(H[x][color]==All[color]) --sum[id1];
}
}
}
void Dfs(int cur,int pre,int eid)
{
H[cur][c[cur]]=1;
if(H[cur][c[cur]]<All[c[cur]]) sum[eid]=1;
else sum[eid]=0;
for(int i=0;i<E[cur].size();i++)
{
int to=E[cur][i].second;
if(to==pre) continue;
int id=E[cur][i].first;
Dfs(to,cur,id);
Merge(cur,to,eid,id);
}
ans[eid]=sum[eid];
}
int main()
{
while(~scanf("%d",&n))
{
All.clear();
for(int i=0;i<=n;++i)
{
H[i].clear(); E[i].clear();
ans[i]=sum[i]=0;
}
for(int i=1;i<=n;i++)
{
scanf("%d",&c[i]);
++All[c[i]];
}
for(int i=1,u,v;i<n;++i)
{
scanf("%d%d",&u,&v);
E[u].push_back(make_pair(i,v));
E[v].push_back(make_pair(i,u));
}
Dfs(1,0,0);
for(int i=1;i<n;++i)
{
printf("%d\n",ans[i]);
}
}
return 0;
}