理解和注意点都写在代码注释了.....
#include<bits/stdc++.h>
using namespace std;
#define int long long
typedef long long ll;
typedef pair<int,int> pii;
const int inf=0x3f3f3f3f;
const int N=1e5+10;
const int mod=1e9+7;
#define fi first
#define se second
int n,val[N];
vector<int> ve[N];
int son[N],siz[N];
void dfs1(int x,int f){
siz[x]=1;
for(auto y:ve[x]){
if(y==f)
continue;
dfs1(y,x);
siz[x]+=siz[y];
if(siz[y]>siz[son[x]])
son[x]=y;
}
}
int cnt[N];//cnt存“当前”子树某种颜色数量
int ans[N],sum;//ans存每个节点子树的答案,sum用来临时存”当前“子树答案,最后用sum来更新ans数组
int flag,maxc;//flag来标记“当前”节点的重儿子是谁,maxc来记录“当前”子树最大值
void calc(int u,int f,int value){
cnt[val[u]]+=value;//value正负表示增减贡献
if(cnt[val[u]]>maxc){//更新最大值和颜色和
maxc=cnt[val[u]];
sum=val[u];
}
else if(cnt[val[u]]==maxc)
sum+=val[u];
for(auto v:ve[u]){
if(v==f||v==flag)//不能是son[u],因为轻儿子的重儿子子树贡献是要算的,
continue;//不进去轻儿子的重儿子子树贡献会少算
calc(v,u,value);
}
}
//dsu on tree板子
void dfs2(int u,int f,bool keep){//f表示父亲,keep表示是否保留
//第一步,先解决轻儿子及其子树的贡献,再删贡献,在这里处理了轻儿子子树的所有答案了,即ans轻儿子已经更新完了
for(auto v:ve[u]){
if(v==f||v==son[u])
continue;//重儿子留到后面
dfs2(v,u,0);//轻儿子不保留,keep为0
}
//第二步,解决重儿子及其子树,不删贡献,更新重儿子的所有ans
if(son[u]){
dfs2(son[u],u,1);
flag=son[u];//一定要用flag来记录重儿子是谁,
} //如果用son[u]来表示,在calc后面的son[u]不等于当前的flag
//轻儿子子树也有轻儿子子树的重儿子,如果少算了轻儿子子树的重儿子
//就会少算了一些轻儿子子树里的贡献
//由于calc传的是当前节点u,在遍历calc里u节点的儿子节点里,
//重儿子子树不进去,且重儿子的贡献已经被之前的重儿子的重儿子保留下来了
//第三步,暴力将轻儿子的贡献都加到重儿子上,然后重儿子贡献就是总贡献,
//保留重儿子贡献,这样父节点就可以不用算这重儿子的贡献
calc(u,f,1);//再次遍历所有轻儿子,因为第一个循环已经处理好了所有轻儿子的ans,现在将这些轻儿子的ans加到重儿子的贡献里,用重儿子的贡献表示整棵树的贡献(重儿子的ans也已经算好了)
//1表示算贡献,后面-1才是删除贡献
flag=0; //如果当前的keep是0的话,整个子树的贡献都要删掉
//如果不把flag清空,后面删除的时候该点的重儿子子树贡献就不会被删了
ans[u]=sum;//更新答案
//如果keep为0表示当前这个子树的贡献算完是要删掉的
if(!keep){
calc(u,f,-1);//都是传入当前节点,然后遍历当前节点的所有儿子
//只是算贡献的时候不用算重儿子的贡献(已经被保留算过了)
//如果当前节点的贡献是要被删的话,重儿子子树也要删的
//如果keep是1,那么这个子树贡献都要保留
//包括当前节点的轻儿子,因为当前儿子是对于父类前面的重儿子子树
sum=maxc=0;//删完所有子树的贡献后sum和maxc也要恢复初始值,表示子树的贡献删完了
}
}
void solve(){
cin>>n;
for(int i=1;i<=n;i++)
cin>>val[i];
for(int i=1;i<n;i++){
int a,b;
cin>>a>>b;
ve[a].push_back(b);
ve[b].push_back(a);
}
dfs1(1,0);//找重链
dfs2(1,0,0);//clac ,这里keep0,1都可以
for(int i=1;i<=n;i++)
cout<<ans[i]<<' ';
}
signed main(){
ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
int t=1;
//cin>>t;
while(t--){
solve();
}
return 0;
}