详解:dsu on tree(树上启发式合并)算法总结+习题
经典例题:https://vjudge.net/problem/CodeForces-600E
题意:一棵树有n个结点,每个结点都是一种颜色,每个颜色有一个编号,求树中每个子树的最多的颜色编号的和。
dsu on tree简介:
在O(N^2)的暴力做法中,我们用cnt记录每种颜色出现的次数,对于每个结点,遍历这棵子树上的所有结点找到答案,然后清空cnt数组。dsu on tree中,当这个结点是它父亲的重儿子时,我们就先找父亲其他儿子(轻儿子)的答案,最后遍历这棵子树,并保留这颗子树cnt数组记录,那么我们回溯找父亲结点的答案时就不需要再遍历一遍它的重儿子了。
时间复杂度O(log n):
我们考虑每个结点被统计的次数,那么从它到跟结点的路径中,有一条轻边,就会被统计一次。因为是轻儿子就会被删除记录,然后找这个轻儿子的父亲结点的答案时,又会被统计一次。而由重链剖分可以保证每个结点到跟结点的路径上轻便不会超过log n条。
#include<algorithm>
#include<iostream>
using namespace std;
#define ll long long
#define ull unsigned long long
#define PII pair<int,int>
#define mid ((l + r)>>1)
#define chl (root<<1)
#define chr (root<<1|1)
#define lowbit(x) ( x&(-x) )
const int manx = 1e5 + 10;
const int INF = 2e9;
const int mod = 1e4+7;
int color[manx],hson[manx],sz[manx];
//color每个结点的颜色、hson重儿子
//sz每个子树的大小
ll cnt[manx],maxCnt,sum;
//cnt记录每种颜色出现次数,maxCnt记录每种颜色出现次数的最大值,
//sum维护每个结点的答案
ll ans[manx];//记录每个结点的答案
int cou,head[manx];
struct node
{
int e,bf;
}edge[manx<<2];
void init(int n)
{
cou=0;
for(int i=1;i<=n;i++){
head[i]=-1;
hson[i]=0;
sz[i]=1;
}
sz[0]=-1;
}
void add1(int s,int e)
{
edge[cou]=node{e,head[s]};
head[s]=cou++;
}
void dfs_list(int s,int fa)//找每个结点重儿子
{
for(int i=head[s];~i;i=edge[i].bf){
int e=edge[i].e;
if(e==fa)continue;
dfs_list(e,s);
sz[s]+=sz[e];
if(sz[e]>sz[hson[s]])
hson[s]=e;
}
}
void add(int f,int s,int fa,int k)
//k为 1时,加上 这棵子树中每个颜色出现的次数
//k为-1时,减去 这棵子树中每个颜色出现的次数
//f表示这棵子树根结点的重儿子(k=1时,不需要遍历它
{
cnt[color[s]]+=k;
if(k>0&&cnt[color[s]]>=maxCnt){//维护maxCnt和sum
if(cnt[color[s]]>maxCnt)sum=0,maxCnt=cnt[color[s]];
sum+=color[s];
}
for(int i=head[s];~i;i=edge[i].bf){
int e=edge[i].e;
if((e==f&&k>0)||e==fa)continue;
add(f,e,s,k);
}
}
void dfs_Dot(int s,int fa,int keep)
{
for(int i=head[s];~i;i=edge[i].bf){
int e=edge[i].e;
if(e!=fa&&e!=hson[s])
dfs_Dot(e,s,0);//1、先找轻儿子的答案,不记录cnt
}
if(hson[s])dfs_Dot(hson[s],s,1);//2、找重儿子的答案,记录cnt
add(hson[s],s,fa,1);//3、再遍历一遍轻儿子,找结点s的答案,记录cnt
ans[s]=sum;
if(!keep)add(0,s,fa,-1),sum=maxCnt=0;//结点s不是重儿子(keep==0)时,清空cnt,删除记录
}
int main()
{
int n,s,e;
scanf("%d",&n);
init(n);
for(int i=1;i<=n;i++)
scanf("%d",&color[i]);
for(int i=1;i<n;i++){
scanf("%d%d",&s,&e);
add1(s,e);
add1(e,s);
}
dfs_list(1,0);
dfs_Dot(1,0,1);
for(int i=1;i<=n;i++)
printf("%lld%c",ans[i],i==n?'\n':' ');
return 0;
}