题解:
dsu on tree
对于结点i来说,步骤为:
递归轻儿子,不保留贡献。
递归重儿子,保留贡献。
统计当前结点及所有轻儿子的贡献。
拿个样例来解释算法的流程
比如说样例2:
如图所示,有多个结点的大小相同,重儿子应该是节点数目最多的那个子节点,那么我们假设1号结点的重儿子是4号结点。
递归轻儿子顺序与输入顺序有关系。
因为空间有限,我们无法把所有结点的状态给存起来,所以用的是1个cnt数组。
假设我们计算14号结点后要去计算15号结点,但是这两个结点用的是同一个cnt数组,那么当我们计算完14号结点的答案后,我们要清空14号结点对cnt数组造成的影响后,才可以去计算15号结点。
同样的方法当我们把2 、 14 、 3 、 15这四个结点计算完后,我们再去计算重儿子4号结点。
但是与轻儿子不同的是,当我们计算完4号重儿子结点的信息后,cnt数组可以不用清空。
为什么?
当我们把1号结点所有的子节点的答案都计算完成后,那就意味着2-15号结点的答案都已经有了,那么还差1号结点没有计算,我们计算1号结点的时候需要遍历1号结点包括的所有结点(也就是2-15号结点),那假设我们没有把4号结点(重儿子)以及他子树的信息清空,那么我们是不是就可以少遍历4 、11 、12、 13这四个结点,只需要去遍历2,5,6,7,14,3,8,9,10,15这些结点。
从而减少了时间复杂度,(玄学)
case2输入:
15
1 2 3 1 2 3 3 1 1 3 2 2 1 2 3
1 2
1 3
1 4
1 14
1 15
2 5
2 6
2 7
3 8
3 9
3 10
4 11
4 12
4 13
case2输出:
6 5 4 3 2 3 3 1 1 3 2 2 1 2 3
#include <bits/stdc++.h>
using namespace std;
const int maxn=2e5+10;
#define int long long
struct node{
int to,next;
}edge[maxn];
int head[maxn];
int tot=0;
void add(int u,int v){
edge[tot].to=v;
edge[tot].next=head[u];
head[u]=tot++;
}
int col[maxn];
int sz[maxn],son[maxn];
void dfs1(int x,int fa){ //重链剖分找重儿子
sz[x]=1;
for(int i=head[x];~i;i=edge[i].next){
int v=edge[i].to;
if(v==fa) continue;
dfs1(v,x);
sz[x]+=sz[v];
if(sz[v]>sz[son[x]]) son[x]=v;
}
}
int flag,cnt[maxn]; //cnt表示当前统计结点以及他子树颜色的数量 flag用来标记重儿子
int imax,sum; //imax用来记录当前结点最大值 sum记录答案(比如说有很多个颜色数目相同的颜色,答案就是颜色的编号和)
void count(int u,int fa,int val){
cnt[col[u]]+=val;
if(cnt[col[u]]>imax){
imax=max(cnt[col[u]],imax);
sum=col[u];
} else if(cnt[col[u]]==imax){ //有多个颜色数目相同的颜色
sum+=col[u];
}
for(int i=head[u];~i;i=edge[i].next){ //暴力统计
int v=edge[i].to;
if(v==fa||v==flag) continue;
count(v,u,val);
}
}
int ans[maxn];
void dsu(int u,int fa,bool keep){
for(int i=head[u];~i;i=edge[i].next){
int v=edge[i].to;
if(v==fa||v==son[u]) continue;
dsu(v,u, false);
}
if(son[u]){ //如果是重儿子,keep=1,这样就不会删除重儿子上的信息了
dsu(son[u],u,true);
flag=son[u];
}
count(u,fa,1); //暴力统计
flag=0;
ans[u]=sum; //记录答案
if(!keep){ //如果不是重儿子 ,暴力删除轻儿子上面的信息
count(u,fa,-1);
sum=imax=0;
}
}
signed main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
memset(head,-1,sizeof(head));
int n;
cin>>n;
for(int i=1;i<=n;i++) cin>>col[i];
for(int i=0;i<n-1;i++){
int u,v;
cin>>u>>v;
add(u,v);add(v,u);
}
dfs1(1,0);
//cout<<son[1]<<endl;
dsu(1,0,0);
for(int i=1;i<=n;i++){
cout<<ans[i]<<" ";
}
return 0;
}